Using mxnet.ndarray.UpSampling

Asked: 2017-12-20 02:24:19

Tags: python mxnet

When I use the UpSampling function (Python, MXNet version 1.0.0) with nearest-neighbor interpolation, everything works (it prints the upsampled output shape):

import mxnet as mx
import mxnet.ndarray as nd

nfilters = 16
xx = nd.random_normal(shape=[2,nfilters,64,64],ctx=mx.cpu())
print(xx.asnumpy().shape)
temp = nd.UpSampling(xx,scale=2,sample_type='nearest')
print(temp.asnumpy().shape)

When I try the same thing with sample_type='bilinear', I get an error:

nfilters = 16
xx = nd.random_normal(shape=[2,nfilters,64,64],ctx=mx.cpu())
print(xx.asnumpy().shape)
temp = nd.UpSampling(xx,scale=2,sample_type='bilinear')
print(temp.asnumpy().shape)

Any pointers or ideas on what I am doing wrong? I need this to work for both ndarray and mx.sym (though I assume the two behave the same).

Error message:

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-57-7b8d60ea54bb> in <module>()
      3 xx = nd.random_normal(shape=[2,nfilters,64,64],ctx=mx.cpu())
      4 print xx.asnumpy().shape
----> 5 temp = mx.nd.UpSampling(xx,scale=2,sample_type='bilinear')
      6 print temp.asnumpy().shape

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/ndarray/register.pyc in UpSampling(*data, **kwargs)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/_ctypes/ndarray.pyc in _imperative_invoke(handle, ndargs, keys, vals, out)
     90         c_str_array(keys),
     91         c_str_array([str(s) for s in vals]),
---> 92         ctypes.byref(out_stypes)))
     93 
     94     if original_output is not None:

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/base.pyc in check_call(ret)
    144     """
    145     if ret != 0:
--> 146         raise MXNetError(py_str(_LIB.MXGetLastError()))
    147 
    148 

MXNetError: [17:20:11] src/c_api/../imperative/imperative_utils.h:303: Check failed: num_inputs == infered_num_inputs (1 vs. 2) Operator UpSampling expects 2 inputs, but got 1 instead.

Stack trace returned 10 entries:
[bt] (0) /home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x289a1c) [0x7fe0ed9d6a1c]
[bt] (1) /home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x240538f) [0x7fe0efb5238f]
[bt] (2) /home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(+0x24029a2) [0x7fe0efb4f9a2]
[bt] (3) /home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x63) [0x7fe0efb4ffb3]
[bt] (4) /home/dia021/anaconda2/lib/python2.7/lib-dynload/_ctypes.so(ffi_call_unix64+0x4c) [0x7fe12e6dd57c]
[bt] (5) /home/dia021/anaconda2/lib/python2.7/lib-dynload/_ctypes.so(ffi_call+0x1f5) [0x7fe12e6dccd5]
[bt] (6) /home/dia021/anaconda2/lib/python2.7/lib-dynload/_ctypes.so(_ctypes_callproc+0x3e6) [0x7fe12e6d4376]
[bt] (7) /home/dia021/anaconda2/lib/python2.7/lib-dynload/_ctypes.so(+0x9db3) [0x7fe12e6cbdb3]
[bt] (8) /home/dia021/anaconda2/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53) [0x7fe13375de93]
[bt] (9) /home/dia021/anaconda2/bin/../lib/libpython2.7.so.1.0(PyEval_EvalFrameEx+0x715d) [0x7fe13381080d]

1 answer:

Answer 0 (score: 4)

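mxnet.ndarray.UpSampling appears to expect 2 inputs (1 data input and 1 weight) when sample_type is 'bilinear'.

Also, I think the documentation for the num_args parameter is missing; you can see the parameter here: https://github.com/apache/incubator-mxnet/blob/master/src/operator/nn/upsampling-inl.h#L78

This should work: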

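import mxnet as mx
import mxnet.ndarray as nd

# Data input: batch of 1, 1 channel, 256x256
xx = nd.random_normal(shape=[1,1,256,256],ctx=mx.cpu())
# Second input: the weight that bilinear mode requires (random values here)
xx1 = nd.random_normal(shape=[1,1,4,4],ctx=mx.cpu())
# num_args=2 tells the operator that two inputs are passed
temp = nd.UpSampling(xx,xx1, num_filter=1, scale=2, sample_type='bilinear', num_args=2)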
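
Note that a random weight like xx1 makes the call succeed but does not produce true bilinear interpolation. The following is a minimal pure-Python sketch of how a matching weight could be built. It rests on two assumptions that are not confirmed in this thread: that the operator uses a kernel of size 2*scale - scale % 2 with weight shape (num_filter, 1, kernel, kernel), which is consistent with the [1,1,4,4] weight above for scale=2, and that the values follow the usual bilinear-filter formula. The helper names are hypothetical, not part of MXNet.

```python
import math


def bilinear_kernel_size(scale):
    # Assumption: kernel size read from the operator source, not from the docs.
    return 2 * scale - scale % 2


def bilinear_weight_shape(num_filter, scale):
    # Assumption: one single-channel kernel per filter (depthwise layout).
    k = bilinear_kernel_size(scale)
    return (num_filter, 1, k, k)


def bilinear_kernel_2d(kernel):
    # Standard separable bilinear filter: outer product of a 1D triangle filter.
    f = math.ceil(kernel / 2.0)
    c = (2 * f - 1 - f % 2) / (2.0 * f)
    row = [1 - abs(x / float(f) - c) for x in range(kernel)]
    return [[wy * wx for wx in row] for wy in row]


if __name__ == "__main__":
    print(bilinear_weight_shape(num_filter=1, scale=2))
    for r in bilinear_kernel_2d(bilinear_kernel_size(2)):
        print(r)
```

With MXNet installed, the nested list could be wrapped to the shape above and passed as the second input via nd.array(...). MXNet also ships a Bilinear initializer (mx.init.Bilinear) that is probably the more convenient route when working through mx.sym.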