在MXNet NDArray的列上使用索引数组

时间:2017-09-01 20:02:50

标签: mxnet

给定一个索引数组index,比如一个矩阵A我希望矩阵B具有A列的相应排列。

在Numpy,我会做以下事情,

>>> A = np.arange(6).reshape(2,3); A
array([[0, 1, 2],
       [3, 4, 5]])
>>> index = [2,0,1]
>>> A[:,index]
array([[2, 0, 1],
       [5, 3, 4]])

在MXNet中有 自然 高效 方式吗?函数pick()take()似乎不会以这种方式工作。我设法提出以下但是它并不优雅。

>>> mx.nd.take(A.T, mx.nd.array([[2],[0],[1]])).T.reshape((2,3))

[[ 2.  0.  1.]
 [ 5.  3.  4.]]
<NDArray 2x3 @cpu(0)>

最后,为了解决这个问题,有没有办法在这里进行这项工作?

更新这是一个稍微优雅,但可能不那么有效(由于换位),上面的版本:

>>> mx.nd.take(A.T, mx.nd.array([2,0,1])).T
[[ 2.  0.  1.]
 [ 5.  3.  4.]]
<NDArray 2x3 @cpu(0)>

1 个答案:

答案 0 :(得分:2)

您需要的是MXNet中所谓的高级索引。提交了一个PR,用于通过MXNet NDArray的高级索引获取元素,并且还将设置元素的功能添加到NDArray。预计将在1.0版本中发布。

https://github.com/apache/incubator-mxnet/pull/8246

相关问题