张量的中间维度的分散操作

时间:2017-11-21 21:02:34

标签: tensorflow

我有一个3d张量,我需要在第二维中的某些位置保留向量,并将剩余的向量归零。位置指定为1d数组。我认为最好的方法是将张量乘以二进制掩码。

这是一个简单的Numpy版本:

A.shape: (b, n, m) 
indices.shape: (b)

mask = np.zeros(A.shape)
for i in range(b):
  mask[i][indices[i]] = 1
result = A*mask

因此对于A中的每个nxm矩阵,我需要保留由indices指定的行,并将其余部分归零。

我尝试使用tf.scatter_nd操作在TensorFlow中执行此操作,但我无法找出正确的索引形状:

shape = tf.constant([3,5,4])
A = tf.random_normal(shape)       
indices = tf.constant([2,1,4])   #???   
updates = tf.ones((3,4))           
mask = tf.scatter_nd(indices, updates, shape) 
result = A*mask

0 个答案:

没有答案
相关问题