在TensorFlow中,我如何知道对哪些行建立索引?

时间:2018-07-07 15:15:06

标签: python tensorflow

在这里,我有一个索引矩阵index,它是数组a的索引。张量如下。

import tensorflow as tf
import numpy as np

index = tf.constant([
            [ 1, 2, 3,-1,-1],
            [ 6, 1, 3,-1,-1],
            [ 1, 3,-1, 5, 6],
            [-1,-1,-1,-1,-1],
            [ 6,-1, 9,-1,-1]
        ])

a = tf.constant([0,0,0,0,
                 0,0,0,0,
                 0,0,0,0,
                 0,0,0,0,], dtype=np.int32)

我想获得一个数组indexed,该数组指示那些已被index索引的数组,如下所示。

indexed = [ 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
#           0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15

我知道tf.scatter_nd_updatetf.scatter_update可能会有所帮助。但是,我不知道如何处理-1,它代表无效的索引(仅用于填充长度)。那么,如何获得如上所述的indexed数组?

0 个答案:

没有答案