批量收集/收集

时间:2018-07-02 20:19:01

标签: python tensorflow

我想知道TensorFlow中是否可以使用gather_nd或类似方法来完成以下操作。

我有两个张量:

  • values,形状为[128, 100]
  • indices,形状为[128, 3]

其中indices的每一行都包含沿values的第二维的索引(对于同一行)。我想使用valuesindices编制索引。例如,我想要执行此操作的东西(使用松散符号表示张量):

values  = [[0, 0, 0, 1, 1, 0, 1], 
           [1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6], 
           [0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]

此操作将遍历valuesindices的每一行,并使用values行在indices行上进行收集。

在TensorFlow中有一种简单的方法吗?

谢谢!

1 个答案:

答案 0 :(得分:1)

不确定这是否等同于“简单”,但是您可以为此使用def batched_gather(values, indices): row_indices = tf.range(0, tf.shape(values)[0])[:, tf.newaxis] row_indices = tf.tile(row_indices, [1, tf.shape(indices)[-1]]) indices = tf.stack([row_indices, indices], axis=-1) return tf.gather_nd(values, indices)

[0, 1]

说明:其构想是构建索引向量,例如indices,意思是“第0行和第1列中的值”。
该功能的tf.shape自变量中已经给出了列索引。
行索引是从0到例如128(在您的示例中),但根据每行的列索引数重复(平铺)(在您的示例中为3;如果此数字固定,则可以对其进行硬编码,而不使用array([[[0, 2], [0, 3], [0, 6]], [[1, 0], [1, 2], [1, 3]]]) )。
然后,将行和列索引堆叠起来以生成索引向量。在您的示例中,结果索引为

gather_nd

和{{1}}会产生所需的结果。

相关问题