用矩阵进行张量索引

时间:2019-04-24 16:45:35

标签: python tensorflow

我有矩阵(3 x 15)dummies,其中标记序列作为行:

[[ 1 66 67 68  0  0  0  0  0  0  0  0  0  0  0]
[ 1 66 67 66 68 66 67 66  0  0  0  0  0  0  0]
[ 1 66 67 68 18 19 20 21 22 23 24 25 26 17  0]]

此外,还有一个张量probs,形状为(3 x 15 x n_tokens),具有令牌概率。

probs中,我只需要选择dummies中令牌的概率。

我认为,可以将矩阵用作张量的索引,但是我还没有找到如何做的。

1 个答案:

答案 0 :(得分:1)

您可以这样做:

import tensorflow as tf

dummies = ...
probs = ...
s = tf.shape(dummies)
i = tf.range(s[0])
j = tf.range(s[1])
ii, jj = tf.meshgrid(i, j, indexing='ij')
idx = tf.stack([ii, jj, dummies], axis=-1)
result = tf.gather_nd(probs, idx)