如何检查numpy数组的所有元素是否在另一个numpy数组中

时间:2018-07-09 18:34:03

标签: python arrays numpy vectorization

我有两个2D numpy数组,例如:

A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])

B = numpy.array([[1, 2], [32, 32]])

我想拥有A的所有行,在这里我可以找到B的任何行的所有元素。 B行中有2个相同元素的地方,A中的行也必须至少包含2个。在我的示例中,我想实现以下目标:

A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]

我可以控制值的表示形式,因此我选择了数字表示形式,其中二进制表示形式仅以1占据一个位置(例如:0b000000010b00000010,等等。)我可以使用np.logical_or.reduce()函数轻松检查所有类型的值是否都在行中,但是我无法检查A行中相同元素的数量是否大于或等于。我真的希望我可以避免简单的for循环和数组的深拷贝,因为性能对我来说是非常重要的一个方面。

如何以有效的方式在numpy中做到这一点?


更新

here中的解决方案可能有效,但我认为性能对我来说是一个很大的问题,A可能非常大(> 300000行),而B可能中等(> 30):

[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

更新2:

set()解决方案不起作用,因为set()删除了所有重复的值。

2 个答案:

答案 0 :(得分:1)

我认为这应该可行:

首先,按以下方式对数据进行编码(假设您的二进制方案也暗示了“令牌”的数量有限):

制作一个形状[n_rows,n_tokens],dtype int8,其中每个元素都计算令牌的数量。以相同的方式编码B,形状为[n_hands,n_tokens]

这允许您输出的单个向量化表达式; matchs =(A [None,:,:]> = B [:, None,:])。all(axis = -1)。 (确切地说,如何将这个匹配数组映射到所需的输出格式作为练习的内容留给了读者,因为问题使它在多个匹配中都未定义)。

但是我们在这里谈论的是每个令牌10Mbyte的内存。即使有了32个令牌,这也不应该是不可想象的。但是在这种情况下,最好不要对n_tokens或n_hands或两者上的循环进行矢量化处理; for循环对于小n来说很好,或者如果主体中有足够的工作要做,则循环开销微不足道。

只要n_tokens和n_hands保持适度,我认为这将是最快的解决方案,如果不使用纯python和numpy的话。

答案 1 :(得分:1)

希望您的问题正确。至少它可以解决您在问题中描述的问题。如果输出的顺序应与输入的顺序相同,请更改就地排序。

该代码看起来很丑陋,但是应该表现良好,并且不难理解。

代码

import time
import numba as nb
import numpy as np

@nb.njit(fastmath=True,parallel=True)
def filter(A,B):
  iFilter=np.zeros(A.shape[0],dtype=nb.bool_)

  for i in nb.prange(A.shape[0]):
    break_loop=False

    for j in range(B.shape[0]):
      ind_to_B=0
      for k in range(A.shape[1]):
        if A[i,k]==B[j,ind_to_B]:
          ind_to_B+=1

        if ind_to_B==B.shape[1]:
          iFilter[i]=True
          break_loop=True
          break

      if break_loop==True:
        break

  return A[iFilter,:]

衡量效果

####First call has some compilation overhead####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

####Let's measure the second call too####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

结果

46ms after the first run on a dual-core Notebook (sorting included)
32ms (sorting excluded)