TensorFlow串联/堆叠N张量交错最后一个维度

时间:2019-07-01 07:59:48

标签: python tensorflow

假设我们有四个张量abcd,它们都共享(batch_size, T, C)的相同尺寸,我们想创建一个新的张量X的形状为(batch_size, T*4, C),其中T*4在所有张量之间交错循环。

例如,如果abcd分别是所有1、2、3和4的张量,则我们期望{{1} }就像

X

2 个答案:

答案 0 :(得分:2)

在我看来,您的示例数组实际上具有形状(batch_size, T, C*4)而不是(batch_size, T*4, C)。无论如何,您可以使用tf.concat,tf.reshape和tf.transpose获得所需的内容。 2d中的一个简单示例如下:

A = tf.ones([2,3])
B = tf.ones([2,3]) * 2
AB = tf.concat([A,B], axis=1)
AB = tf.reshape(AB, [-1, 3])
AB.eval() #array([[1., 1., 1.],
   # [2., 2., 2.],
   # [1., 1., 1.],
   # [2., 2., 2.]], dtype=float32)

您将A和B连接起来得到形状为(2,6)的矩阵。然后,您可以对它进行整形,使其与行交错。为此,在3d中,要乘以4的尺寸必须是最后一个尺寸。因此,您可能需要使用tf.transpose,使用concat进行交织并整形,然后再次进行转置以重新排列尺寸。

答案 1 :(得分:1)

我认为另一种选择是使用tf.tile

import tensorflow as tf

tf.enable_eager_execution()

A = tf.ones((2, 1, 4))
B = tf.ones((2, 1, 4)) * 2
C = tf.ones((2, 1, 4)) * 3
ABC = tf.concat([A, B, C], axis=1)

print(ABC)
#tf.Tensor(
#[[[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]], shape=(2, 3, 4), dtype=float32)

X = tf.tile(ABC, multiples=[1, 3, 1])

print(X)
#tf.Tensor(
#[[[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]
#
# [[1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]
#  [1. 1. 1. 1.]
#  [2. 2. 2. 2.]
#  [3. 3. 3. 3.]]], shape=(2, 9, 4), dtype=float32)