有没有办法检查NumPy阵列是否共享相同的数据?

时间:2012-07-02 01:09:33

标签: python numpy

我的印象是,在NumPy中,两个阵列可以共享相同的内存。请看以下示例:

import numpy as np
a=np.arange(27)
b=a.reshape((3,3,3))
a[0]=5000
print (b[0,0,0]) #5000

#Some tests:
a.data is b.data #False
a.data == b.data #True

c=np.arange(27)
c[0]=5000
a.data == c.data #True ( Same data, not same memory storage ), False positive

很明显b没有制作a的副本;它只是创建了一些新的元数据并将其附加到a正在使用的相同内存缓冲区。有没有办法检查两个数组是否引用相同的内存缓冲区?

我的第一印象是使用a.data is b.data,但返回false。我可以做a.data == b.data返回True,但我不认为检查以确保ab共享相同的内存缓冲区,只有a引用的内存块1}}和b引用的那个具有相同的字节。

4 个答案:

答案 0 :(得分:28)

您可以使用base属性检查阵列是否与另一个阵列共享内存:

>>> import numpy as np
>>> a = np.arange(27)
>>> b = a.reshape((3,3,3))
>>> b.base is a
True
>>> a.base is b
False

不确定是否能解决您的问题。如果阵列拥有自己的内存,则base属性为None。请注意,数组的基数将是另一个数组,即使它是一个子集:

>>> c = a[2:]
>>> c.base is a
True

答案 1 :(得分:8)

我认为jterrace的回答可能是最好的方法,但这是另一种可能性。

def byte_offset(a):
    """Returns a 1-d array of the byte offset of every element in `a`.
    Note that these will not in general be in order."""
    stride_offset = np.ix_(*map(range,a.shape))
    element_offset = sum(i*s for i, s in zip(stride_offset,a.strides))
    element_offset = np.asarray(element_offset).ravel()
    return np.concatenate([element_offset + x for x in range(a.itemsize)])

def share_memory(a, b):
    """Returns the number of shared bytes between arrays `a` and `b`."""
    a_low, a_high = np.byte_bounds(a)
    b_low, b_high = np.byte_bounds(b)

    beg, end = max(a_low,b_low), min(a_high,b_high)

    if end - beg > 0:
        # memory overlaps
        amem = a_low + byte_offset(a)
        bmem = b_low + byte_offset(b)

        return np.intersect1d(amem,bmem).size
    else:
        return 0

示例:

>>> a = np.arange(10)
>>> b = a.reshape((5,2))
>>> c = a[::2]
>>> d = a[1::2]
>>> e = a[0:1]
>>> f = a[0:1]
>>> f = f.reshape(())
>>> share_memory(a,b)
80
>>> share_memory(a,c)
40
>>> share_memory(a,d)
40
>>> share_memory(c,d)
0
>>> share_memory(a,e)
8
>>> share_memory(a,f)
8

以下是一个图表,显示每个share_memory(a,a[::2])来电的时间与我计算机上a中元素数量的函数关系。

share_memory function

答案 2 :(得分:5)

要完全解决问题,可以使用

import numpy as np

a=np.arange(27)
b=a.reshape((3,3,3))

# Checks exactly by default
np.shares_memory(a, b)

# Checks bounds only
np.may_share_memory(a, b)

np.may_share_memorynp.shares_memory都采用可选的max_work参数,使您可以决定投入多少精力来确保没有误报。这个问题是NP完全的,因此始终找到正确的答案在计算上会非常昂贵。

答案 3 :(得分:4)

只是做:

a = np.arange(27)
a.__array_interface__['data']

第二行将返回一个元组,其中第一个条目是内存地址,第二个元素是数组是否为只读。结合形状和数据类型,您可以计算出数组所覆盖的内存地址的确切范围,因此当一个数组是另一个数组的子集时,您也可以从中解决这个问题。

相关问题