Cython:如何声明numpy.argwhere()

时间:2018-01-15 06:59:47

标签: python performance numpy cython

我尝试将我的部分代码Cython化为以下内容,以期获得一些速度:

# cython: boundscheck=False
import numpy as np
cimport numpy as np
import time

cpdef object my_function(np.ndarray[np.double_t, ndim = 1] array_a,
                     np.ndarray[np.double_t, ndim = 1] array_b,
                     int n_rows,
                     int n_columns):
    cdef double minimum_of_neighbours, difference, change
    cdef int i
    cdef np.ndarray[np.int_t, ndim =1] locations
    locations = np.argwhere(array_a > 0)

    for i in locations:
        minimum_of_neighbours = min(array_a[i - n_columns], array_a[i+1], array_a[i + n_columns], array_a[i-1])
        if array_a[i] - minimum_of_neighbours < 0:
            difference = minimum_of_neighbours - array_a[i]
            change = min(difference, array_a[i] / 5.)
            array_a[i] += change
            array_b[i] -= change * 5.
        print time.time()

return array_a, array_b

我可以编译它而没有错误但是当我使用该函数时我得到了这个错误:

from cythonized_code import my_function
import numpy as np

array_a = np.random.uniform(low=-100, high=100, size = 100).astype(np.double)
array_b = np.random.uniform(low=0, high=20, size = 100).astype(np.double)

a, b = my_function(array_a,array_b,5,20)

# which gives me this error:    
# locations = np.argwhere(array_a > 0)
# ValueError: Buffer has wrong number of dimensions (expected 1, got 2)

我是否需要在此处声明locations类型?我想宣布它的原因是它在通过编译代码生成的带注释的HTML文件中有黄色。

1 个答案:

答案 0 :(得分:1)

不使用python-functions locations[i]是个好主意,因为它太贵了:Python会从低c整数创建一个完整的Python整数*(这是什么存储在locations - numpy数组中),将其注册到垃圾收集器中,然后将其强制转换回int,销毁Python对象 - 这是一个很大的开销。

要直接访问低位整数,需要将locations绑定到一个类型。正常的行动过程会过于查找,locations具有哪些属性:

>>> locations.ndim
2
>>> locations.dtype
dtype('int64')

转换为cdef np.ndarray[np.int64_t, ndim =2] locations

然而,由于Cython-quirk,这将(可能无法立即检查)不足以摆脱Python开销:

for i in locations:
    ...

不会被解释为原始数组访问,但会调用Python机器。请参阅示例here

所以你必须把它改成:

for index in range(len(locations)):
      i=locations[index][0]

然后Cython&#34;理解&#34;,您希望访问原始c-int64数组。

  • 实际上,并非完全正确:在这种情况下,首先创建nd.array(例如locations[0]locations[1])然后__Pyx_PyInt_As_int(或多或少)调用[PyLong_AsLongAndOverflow][2]的别名,创建PyLongObject,在临时intPyLongObject被破坏之前从中获取C - nd.array值。

这里我们很幸运,因为length-1 numpy-arrays可以转换为Python标量。如果locations的第二个维度为>1,则代码将无效。

相关问题