TensorFlow:将张量分割为`batch_size`切片

时间:2018-03-03 20:15:27

标签: python tensorflow

我有一个名为tensor形状为[batch_size, axis_1, axis_2]的等级3张量,并希望沿着第一个轴将其拆分为batch_size个切片,如下所示:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)

不幸的是,这并不起作用,因为在构建图表时batch_size的值尚未知晓。

我该如何解决这个问题?

我收到此错误:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.

奇怪的是,尝试在其他TensorFlow函数中使用batch_size似乎有效:

tensor = tf.reshape(tensor, [batch_size, -1])
尽管在图形构建过程中batch_size的值未知,但

仍能正常工作。

问题特别针对tf.split()

0 个答案:

没有答案