我有以下.pyx代码:
import cython
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def f(m):
cdef int n = len(m)/2
cdef int j, k
z = [[0]*(n+1) for _ in range(n*(2*n-1))]
for j in range(1, 2*n):
for k in range(j):
z[j*(j-1)/2+k][0] = m[j][k]
return solve(z, 2*n, 1, [1] + [0]*n, n)
cdef solve(b, int s, int w, g, int n):
cdef complex h
cdef int u,v,j,k
if s == 0:
return w*g[n]
c = [b[(j+1)*(j+2)/2+k+2][:] for j in range(1, s-2) for k in range(j)]
h = solve(c, s-2, -w, g, n)
e = g[:]
for u in range(n):
for v in range(n-u):
e[u+v+1] += g[u]*b[0][v]
for j in range(1, s-2):
for k in range(j):
for u in range(n):
for v in range(n-u):
c[j*(j-1)/2+k][u+v+1] += b[(j+1)*(j+2)/2][u]*b[(k+1)*(k+2)/2+1][v] + b[(k+1)*(k+2)/2][u]*b[(j+1)*(j+2)/2+1][v]
return h + solve(c, s-2, w, e, n)
我不知道如何在cython中声明列表列表来加速代码。
例如,变量m
是表示为浮点数列表的列表。变量z
也是浮点数列表的列表。例如,def f(m)
行应该是什么样的?
根据@DavidW的回答中的建议,这是我的最新版本。
import cython
import numpy as np
def f(complex[:,:] m):
cdef int n = len(m)/2
cdef int j, k
cdef complex[:,:] z = np.zeros((n*(2*n-1), n+1), dtype = complex)
for j in range(1, 2*n):
for k in range(j):
z[j*(j-1)/2+k, 0] = m[j, k]
return solve(z, 2*n, 1, [1] + [0]*n, n)
cdef solve(complex[:,:] b, int s, int w, g, int n):
cdef complex h
cdef int u,v,j,k
cdef complex[:,:] c
if s == 0:
return w*g[n]
c = [b[(j+1)*(j+2)/2+k+2][:] for j in range(1, s-2) for k in range(j)]
print("c stats:", len(c), [len(c[i]) for i in len(c)])
h = solve(c, s-2, -w, g, n)
e = g[:]
for u in range(n):
for v in range(n-u):
e[u+v+1] = e[u+v+1] + g[u]*b[0][v]
for j in range(1, s-2):
for k in range(j):
for u in range(n):
for v in range(n-u):
c[j*(j-1)/2+k][u+v+1] = c[j*(j-1)/2+k][u+v+1] + b[(j+1)*(j+2)/2][u]*b[(k+1)*(k+2)/2+1][v] + b[(k+1)*(k+2)/2][u]*b[(j+1)*(j+2)/2+1][v]
return h + solve(c, s-2, w, e, n)
现在的主要问题是如何声明c,因为它目前是一个列表列表。
答案 0 :(得分:3)
列表列表不是一个可以从Cython中获得更多加速的结构。您应该使用的结构是2D typed memoryview:
def f(double[:,:] m):
# ...
这些索引为m[j,k]
而不是m[j][k]
。您可以向它们传递任何适当形状的对象,该对象公开Python缓冲区协议。大多数频率都是Numpy阵列。
您还应该避免使用@cython.boundscheck(False)
和@cython.wraparound(False)
之类的装饰器,除非您了解它们的作用并考虑它们是否适合您的功能。对于您当前的版本(您正在使用list
s),他们实际上什么都不做,并建议您在不理解的情况下复制它们。它们确实加快了记忆视图的索引(以某些安全为代价)。
编辑:在初始化c
方面,您有两种选择。
使用列表列表初始化numpy数组。这可能不是非常快(但如果其他步骤较慢则可能无关紧要):
c = np.array([b[(j+1)*(j+2)/2+k+2,:] for j in range(1, s-2) for k in range(j)], dtype=complex)
# note that I've changed the indexing of b slightly
使用适当大小的c
数组设置np.zeros
。将列表理解交换为两个循环。对我来说,这并不是100%显而易见,但这就像
c = np.zeros("some size you'll have to work out",dtype=complex)
for k in range(j):
for j in range(1,s-2):
c["some function of j and k",:] = b["some function of j and k",:]
您还希望将len(c)
替换为c.shape[0]
等
答案 1 :(得分:2)