我有一个名为tensor
形状为[batch_size, axis_1, axis_2]
的等级3张量,并希望沿着第一个轴将其拆分为batch_size
个切片,如下所示:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
不幸的是,这并不起作用,因为在构建图表时batch_size
的值尚未知晓。
我该如何解决这个问题?
我收到此错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
奇怪的是,尝试在其他TensorFlow函数中使用batch_size
似乎有效:
tensor = tf.reshape(tensor, [batch_size, -1])
尽管在图形构建过程中batch_size
的值未知,但仍能正常工作。
问题特别针对tf.split()
?