如何有效地计算3d numpy数组中的相邻元素

时间:2014-08-06 20:53:59

标签: python numpy unique counting multidimensional-array

我有一个3d numpy数组填充,整数从1到7 我想计算每个单元的相邻单元中的唯一元素的数量。例如,在二维数组中:

a=[[1,1,1,7,4],
   [1,1,1,3,2],
   [1,1,1,2,2],
   [1,3,1,4,2],
   [1,1,1,4,2]]  

会产生以下结果:

[[1,1,2,3,2],
 [1,1,2,3,3],
 [1,2,2,4,1],
 [2,1,3,3,2],
 [1,2,2,3,2]]  

我目前正在遍历数组中的每个单元格并逐个检查其邻居。

temp = np.zeros(6)
if (x>0):
    temp[0] = model[x-1,y,z]
if (x<x_len-1):
    temp[1] = model[x+1,y,z]
if (y>0):
    temp[2] = model[x,y-1,z]
if (y<y_len-1):
    temp[3] = model[x,y+1,z]
if (z>0):
    temp[4] = model[x,y,z-1]
if (z<z_len-1):
    temp[5] = model[x,y,z+1]
result[x,y,z] = np.count_nonzero(np.unique(temp))  

我发现这很慢且效率低下。有没有更有效/更快的方法来做到这一点?

感谢。

2 个答案:

答案 0 :(得分:2)

嗯,可能有办法:

  • 创建6个偏移数组(左,右,上,下,前,后)
  • 将这些阵列组合成(R-2,C-2,D-2,6)4D阵列
  • 按最后一个维度(尺寸为6的维度)对4D数组进行排序

现在您有一个4D数组,您可以在其中为每个单元格选择一个已排序的邻居向量。之后,您可以通过以下方式计算不同的邻居:

  • diff用于第4轴(已排序的数组)
  • 计算沿第4轴的非零差值之和

这将为您提供不同邻居的数量 - 1。

第一部分可能相当清楚。如果一个小区有邻居(1,2,4,2,2,3),则邻居向量被分类为(1,2,2,2,3,4)。然后差分向量为(1,0,0,1,1),非零元素之和((diff(v) != 0).sum(axis=4))给出3.因此,有4个唯一的邻居。

当然,这种方法不考虑边缘。您可以通过numpy.pad模式reflect将初始数组按1个单元格填充到每个方向来解决。 (该模式实际上是唯一一个保证不会向邻域引入任何新值的模式,尝试用二维数组来理解原因。)

例如:

import numpy as np

# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))

# pad the data by 1
datp = np.pad(dat, 1, mode='reflect')

# create the neighbouring 4D array
neigh = np.concatenate((
    datp[2:,1:-1,1:-1,None], datp[:-2,1:-1,1:-1,None], 
    datp[1:-1,2:,1:-1,None], datp[1:-1,:-2,1:-1,None],
    datp[1:-1,1:-1,2:,None], datp[1:-1,1:-1,:-2,None]), axis=3)

# sort the 4D array
neigh.sort(axis=3)

# calculate the number of unique samples
usamples = (diff(neigh, axis=3) != 0).sum(axis=3) + 1

上述解决方案非常普遍,适用于任何可排序的解决方案。但是,它消耗了大量内存(阵列的6个副本)并且不是高性能解决方案。如果我们对只适用于这种特殊情况的解决方案感到满意(值是非常小的整数),我们可以做一些魔术。

  • 创建一个数组,其中每个项目都表示为位掩码(1 = 00000001,2 = 00000010,3 = 00000100等)
  • 或相邻阵列在一起
  • 使用查找表计算ORed结果中的位数

import numpy as np

# create a "number of ones" lookup table
no_ones = np.array([bin(i).count("1") for i in range(256)], dtype='uint8')

# create some fictional data
dat = np.random.randint(1, 8, (6, 7, 8))

# create a bit mask of the cells
datb = 1 << dat.astype('uint8')

# pad the data by 1
datb = np.pad(datb, 1, mode='reflect')

# or the padded data together
ored = (datb[ 2:, 1:-1, 1:-1] |
        datb[:-2, 1:-1, 1:-1] |
        datb[1:-1,  2:, 1:-1] |
        datb[1:-1, :-2, 1:-1] |
        datb[1:-1, 1:-1,  2:] |
        datb[1:-1, 1:-1, :-2])

# get the number of neighbours from the LUT
usamples = no_ones[ored]

性能影响相当显着。第一个版本需要2.57秒,第二个版本需要283毫秒,我的机器上有一个384 x 384 x 100表(不包括创建随机数据)。这分别转换为19 ns和174 ns / cell。

然而,该解决方案仅限于存在合理数量的不同(和已知)值的情况。如果不同可能值的数量增长到64以上,那么魔法就失去了它的魅力。 (此外,在大约20个不同的值处,查找部分必须分成多个操作来执行LUT的内存消耗.LUT应该适合CPU缓存,否则会变慢。)

另一方面,扩展解决方案以使用完整的26邻域非常简单且非常快。

答案 1 :(得分:1)

您可以尝试以下操作,但不一定是最佳的,如果您的数据太大会导致问题,但这里有

import numpy as np
from sklearn.feature_extraction.image import extract_patches

a = np.array([[1,1,1,7,4],
              [1,1,1,3,2],
              [1,1,1,2,2],
              [1,3,1,4,2],
              [1,1,1,4,2]])

patches = extract_patches(a, patch_shape=(3, 3), extraction_step=(1, 1))

neighbor_template = np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]]).astype(np.bool)
centers = patches[:, :, 1, 1]
neighbors = patches[:, :, neighbor_template]

possible_values = np.arange(1, 8)
counts = (neighbors[..., np.newaxis] ==
          possible_values[np.newaxis, np.newaxis, np.newaxis]).sum(2)

nonzero_counts = counts > 0
unique_counter = nonzero_counts.sum(-1)

print unique_counter

产量

[[1 2 3]
 [2 2 4]
 [1 3 3]]

结果是你期望的数组的中间位置。为了获得带边框的完整数组,边框需要单独处理。使用numpy 1.8,您可以使用np.pad模式 median reflect来填充一个像素。这也可以正确完成边框。

现在让我们转向3D并确保我们不会使用太多内存。

# first we generate a neighbors template
from scipy.ndimage import generate_binary_structure

neighbors = generate_binary_structure(3, 1)
neighbors[1, 1, 1] = False
neighbor_coords = np.array(np.where(neighbors)).T

data = np.random.randint(1, 8, (384, 384, 100))
data_neighbors = np.zeros((neighbors.sum(),) + tuple(np.array(data.shape) - 2), dtype=np.uint8)

# extract_patches only generates a strided view
data_view = extract_patches(data, patch_shape=(3, 3, 3), extraction_step=(1, 1, 1))

for neigh_coord, data_neigh in zip(neighbor_coords, data_neighbors):
    sl = [slice(None)] * 3 + list(neigh_coord)
    data_neigh[:] = data_view[sl]

indicator = (data_neigh[np.newaxis] == possible_values[:, np.newaxis, np.newaxis, np.newaxis]).sum(1) > 0

uniques = indicator.sum(0)

和以前一样,您可以在uniques中找到唯一条目的数量。使用scipy中的generate_binary_structureextract_patches的滑动窗口等方法会使这种方法变得通用:如果你想要一个26邻域而不是6邻域,那么你只需要改变{{1}到generate_binary_structure(3, 1)。如果生成的数据量适合您机器的内存,它还可以直接推广到额外的尺寸。

相关问题