根据单元格索引填充numpy数组

时间:2018-11-30 16:09:00

标签: python numpy

我正在尝试创建一个3d数组,要从单元格索引中计算其单元格条目。 具体来说,我要该单元格(i,j,k) = sqrt(i+j+k)

使用以下for循环很容易做到这一点:

N=10
A=np.zeros((N,N,N))

for i in range(N):
    for j in range(N):
        for k in range(N):
            A[i][j][k] = np.sqrt(i+j+k)

我想知道numpy是否具有使这些嵌套的for循环多余的内置函数。

3 个答案:

答案 0 :(得分:6)

最简单,最有效的方法是使用list打开网格,然后执行相关操作-

np.ogrid

或使用I,J,K = np.ogrid[:N,:N,:N] A = np.sqrt(I+J+K) 获取单线开放网格的广播汇总-

np.sum

相关:General workflow on vectorizing loops involving range iterators

答案 1 :(得分:3)

您可以使用np.arange,然后使用np.newaxis创建不同的尺寸。通过简单的sumnp.sqrt可以完成以下任务:

arr = np.arange(N)
A = np.sqrt(arr + arr[:,np.newaxis]+ arr[:,np.newaxis,np.newaxis])

您得到相同的结果:

N = 10
arr = np.arange(N)
A = np.sqrt(arr + arr[:,np.newaxis]+ arr[:,np.newaxis,np.newaxis])
B = np.sqrt(np.sum(np.ogrid[:N,:N,:N]))
print ((A==B).all())
#True

此方法比使用np.ogrid快一点:

N = 10
%timeit arr = np.arange(N); A = np.sqrt(arr + arr[:,np.newaxis]+ arr[:,np.newaxis,np.newaxis])
#18.6 µs ± 3.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit A = np.sqrt(np.sum(np.ogrid[:N,:N,:N]))
#58.5 µs ± 8.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

答案 2 :(得分:2)

这对于大型N来说更快,但可能被认为是作弊;-)

它充分利用了高度规则和重复的模式,可以节省很多平方根的求值。

def cheat(N):
    values = np.sqrt(np.arange(3*N-2))
    result = np.lib.stride_tricks.as_strided(values, (N, N, N), 3*values.strides)
    return np.ascontiguousarray(result)

如果您可以使用非连续的只读视图(通过所有实际方法),则可以大大提高速度:

def cheat_nc_view(N):
    values = np.sqrt(np.arange(3*N-2))
    return np.lib.stride_tricks.as_strided(values, (N, N, N), 3*values.strides)

供参考:

def cheek(N):
    arr = np.arange(N)
    return np.sqrt(arr + arr[:,np.newaxis] + arr[:,np.newaxis,np.newaxis])

>>> np.all(cheek(20) == cheat(20))
True
>>> np.all(cheek(200) == cheat_nc_view(200))
True

时间:

>>> timeit(lambda: cheek(20), number=1000)
0.05387042500660755
>>> timeit(lambda: cheat(20), number=1000)
0.020798540994292125
>>> timeit(lambda: cheat_nc_view(20), number=1000)
0.010791150998556986

>>> timeit(lambda: cheek(200), number=100)
6.823299437994137
>>> timeit(lambda: cheat(200), number=100)
2.0583883369981777
>>> timeit(lambda: cheat_nc_view(200), number=100)
0.0014881940151099116