多处理 - 具有多维numpy数组的共享内存

时间:2018-05-10 10:03:48

标签: python numpy python-multiprocessing

我需要并行处理一个非常大的numpy数组(55x117x256x256)。尝试使用通常的多处理方法传递它会产生AssertionError,我理解这是因为数组太大而无法复制到每个进程中。因此,我想尝试使用多处理共享内存。 (我对其他方法持开放态度,只要它们不太复杂)。

我已经看到了一些关于使用python多处理共享内存方法的问题,例如

import numpy as np
import multiprocessing as mp

unsharedData = np.zeros((10,))
sharedData = mp.Array('d', unsharedData)

似乎工作正常。但是,我还没有看到一个用多维数组完成的例子。

我尝试将多维数组放入mp.Array,这样就可以得到TypeError: only size-1 arrays can be converted to Python scalars

unsharedData2 = np.zeros((10,10))
sharedData2 = mp.Array('d', unsharedData2)## Gives TypeError

我可以压扁阵列,但是如果可以避免的话,我宁愿也不要。

是否有一些技巧可以让多处理数组处理多维数据?

2 个答案:

答案 0 :(得分:0)

您可以使用 np.reshape((-1,))np.ravel 代替 np.flatten 来制作数组的一维 view,而无需进行 flatten 所做的不必要的复制:

import numpy as np
import multiprocessing as mp

unsharedData2 = np.zeros((10, 10))
ravel_copy = np.ravel(unsharedData2)
reshape_copy2 = unsharedData2.reshape((-1,))
ravel_copy[11] = 1.0       # -> saves 1.0 in unsharedData2 at [1, 1]
reshape_copy2[22] = 2.0    # -> saves 2.0 in unsharedData2 at [2, 2]
sharedData2 = mp.Array('d', ravel_copy)
sharedData2 = mp.Array('d', reshape_copy2)

答案 1 :(得分:0)

您可以使用与 get_obj() 关联的 Array 方法在共享相同内存的每个进程中创建一个新的多维 numpy 数组,该方法返回呈现缓冲区接口的 ctypes 数组。

请看下面的例子:

import ctypes as c
import numpy as np
import multiprocessing as mp


unsharedData2 = np.zeros((10, 10))
n, m = unsharedData2.shape[0], unsharedData2.shape[1]


def f1(mp_arr):
    #in each new process create a new numpy array as follows:
    arr = np.frombuffer(mp_arr.get_obj())
    b = arr.reshape((n, m))# mp_arr arr and b share the same memory
    b[2][1] = 3


def f2(mp_arr):
    #in each new process create a new numpy array as follows:
    arr = np.frombuffer(mp_arr.get_obj())
    b = arr.reshape((n, m)) # mp_arr arr and b share the same memory
    b[1][1] = 2


if __name__ == '__main__':
    mp_arr = mp.Array(c.c_double, n*m)
    p = mp.Process(target=f1, args=(mp_arr,))
    q = mp.Process(target=f2, args=(mp_arr,))
    p.start()
    q.start()
    p.join()
    q.join()
    arr = np.frombuffer(mp_arr.get_obj())
    b = arr.reshape((10, 10))
    print(b)
    '''
    [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 2. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 3. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
     [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
    '''
相关问题