如何在张量流中使用索引数组?

时间:2017-03-06 10:10:41

标签: python numpy tensorflow

如果给定形状为a的矩阵(5,3)和形状为b的索引数组(5,),我们可以轻松获取相应的向量c,< / p>

c = a[np.arange(5), b]

但是,我不能用tensorflow做同样的事情,

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]
  

Traceback(最近一次调用最后一次):文件“”,第1行,in      文件   “〜/ anaconda2 / lib中/ python2.7 /站点包/ tensorflow /蟒蛇/ OPS / array_ops.py”   第513行,在_SliceHelper中       名称=名)

     

文件   “〜/ anaconda2 / lib中/ python2.7 /站点包/ tensorflow /蟒蛇/ OPS / array_ops.py”   第671行,strided_slice       shrink_axis_mask = shrink_axis_mask)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / ops / gen_array_ops.py”,   第3688行,strided_slice       shrink_axis_mask = shrink_axis_mask,name = name)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / op_def_library.py”,   第763行,在apply_op中       op_def = op_def)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”,   第2397行,在create_op中       set_shapes_for_outputs(ret)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”,   第1757行,在set_shapes_for_outputs中       shapes = shape_func(op)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”,   第1707行,在call_with_requiring中       return call_cpp_shape_fn(op,require_shape_fn = True)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes.py”,   第610行,在call_cpp_shape_fn中       debug_python_shape_fn,require_shape_fn)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes.py”,   第675行,在_call_cpp_shape_fn_impl中       raise ValueError(err.message)ValueError:Shape必须为1,但对于'strided_slice_14'(op:'StridedSlice'),输入为2   形状:[5,3],[2,5],[2,5],[2]。

我的问题是,如果我不能使用上述方法在numpy中产生tensorflow的预期结果,我该怎么办?

1 个答案:

答案 0 :(得分:6)

此功能目前尚未在TensorFlow中实施。 GitHub issue #4638正在跟踪NumPy风格的“高级”索引的实现。但是,您可以使用tf.gather_nd()运算符来实现您的程序:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)