使用tf.tile复制张量N次

时间:2017-07-10 17:10:22

标签: tensorflow

我当前的张量具有(3,2)的形状,例如,  [[ 1. 2.] [ 2. 1.] [-2. -1.]]

我想扩展为(1,3,2)的形状,每个第二维度是整个张量的复制品,例如,

[[[ 1.  2.]
  [ 2.  1.]
  [ -2.  -1.]]
 [[ 1.  2.]
  [ 2.  1.]
  [ -2.  -1.]]
[[ 1.  2.]
  [ 2.  1.]
  [ -2. -1.]]]

我尝试了以下代码,但它只复制了每一行。

tiled_vecs = tf.tile(tf.expand_dims(input_vecs, 1),
                      [1, 3, 1])

结果

[[[ 1.  2.]
[ 1.  2.]
[ 1.  2.]]
[[ 2.  1.]
 [ 2.  1.]
 [ 2.  1.]]
[[-2. -1.]
 [-2. -1.]
 [-2. -1.]]]

1 个答案:

答案 0 :(得分:11)

这应该有效,

  (pf.shape(A)[0],1,1] * A

# Achieved by creating a 3d matrix as shown below 
# and multiplying it with A, which is `broadcast` to obtain the desired result.
 [[[1.]],
  [[1.]],   * A
  [[1.]]]

代码示例:

 #input 
 A = tf.constant([[ 1.,  2.], [ 2. , 1.],[-2., -1.]])
 B = tf.ones([tf.shape(A)[0], 1, 1]) * A

 #output
 array([[[ 1.,  2.],
    [ 2.,  1.],
    [-2., -1.]],

   [[ 1.,  2.],
    [ 2.,  1.],
    [-2., -1.]],

   [[ 1.,  2.],
    [ 2.,  1.],
    [-2., -1.]]], dtype=float32)

同样使用tf.tile,我们可以获得相同的内容:

  

tf.tile(tf.expand_dims(A,0),[tf.shape(A)[0],1,1]

相关问题