从列表的元素中查找互斥集的python组合

时间:2012-11-26 20:23:46

标签: python algorithm

在我目前正在开展的一个项目中,我已经实现了我希望程序执行的大约80%,我对结果非常满意。

在剩下的20%中,我遇到了一个问题,让我有点困惑如何解决。 这是:

我想出了一个包含多个数字(任意长度)的列表 例如:

listElement[0] = [1, 2, 3]
listElement[1] = [3, 6, 8]
listElement[2] = [4, 9]
listElement[4] = [6, 11]
listElement[n] = [x, y, z...]

其中n最多可达到40,000左右。

假设每个列表元素是一组数字(在数学意义上),我想要做的是导出互斥集的所有组合;也就是说,就像上面列表元素的powerset一样,但是排除了所有非相交集元素。

所以,为了继续n = 4的例子,我想提出一个包含以下组合的列表:

newlistElement[0] = [1, 2, 3]
newlistElement[1] = [3, 6, 8]
newlistElement[2] = [4, 9]
newlistElement[4] = [6, 11] 
newlistElement[5] = [[1, 2, 3], [4, 9]]
newlistElement[6] = [[1, 2, 3], [6, 11]]
newlistElement[7] = [[1, 2, 3], [4, 9], [6, 11]]
newlistElement[8] = [[3, 6, 8], [4, 9]]
newlistElement[9] = [[4, 9], [6, 11]

无效的情况,例如组合[[1,2,3],[3,6,8]],因为3在两个元素中是常见的。 有没有优雅的方法来做到这一点?我非常感谢任何反馈。

我还必须指定我不想做powerset函数,因为初始列表可能有相当多的元素(正如我所说n可以达到40000),并且使用这么多元素的powerset永远不会完成。

6 个答案:

答案 0 :(得分:4)

我会使用发电机:

import itertools

def comb(seq):
   for n in range(1, len(seq)):
      for c in itertools.combinations(seq, n): # all combinations of length n
         if len(set.union(*map(set, c))) == sum(len(s) for s in c): # pairwise disjoint?
            yield list(c)

for c in comb([[1, 2, 3], [3, 6, 8], [4, 9], [6, 11]]):
   print c

这会产生:

[[1, 2, 3]]
[[3, 6, 8]]
[[4, 9]]
[[6, 11]]
[[1, 2, 3], [4, 9]]
[[1, 2, 3], [6, 11]]
[[3, 6, 8], [4, 9]]
[[4, 9], [6, 11]]
[[1, 2, 3], [4, 9], [6, 11]]

如果您需要将结果存储在一个列表中:

print list(comb([[1, 2, 3], [3, 6, 8], [4, 9], [6, 11]]))

答案 1 :(得分:4)

以下是递归生成器:

def comb(input, lst = [], lset = set()):
   if lst:
      yield lst
   for i, el in enumerate(input):
      if lset.isdisjoint(el):
         for out in comb(input[i+1:], lst + [el], lset | set(el)):
            yield out

for c in comb([[1, 2, 3], [3, 6, 8], [4, 9], [6, 11]]):
   print c

在许多集合具有共同元素的情况下,这可能比其他解决方案更有效(当然,在最坏的情况下,它仍然必须遍历powerset的2**n元素)

答案 2 :(得分:3)

以下程序中使用的方法类似于以前的几个答案,排除了不相交的集合,因此通常不会测试所有组合。它与以前的答案不同,贪婪地排除它可以尽可能早的所有集合。这使它的运行速度比NPE的解决方案快几倍。下面是两种方法的时间比较,使用输入数据200,400,... 1000大小 - 6组,其中元素的范围为0到20:

Set size =   6,  Number max =  20   NPE method
  0.042s  Sizes: [200, 1534, 67]
  0.281s  Sizes: [400, 6257, 618]
  0.890s  Sizes: [600, 13908, 2043]
  2.097s  Sizes: [800, 24589, 4620]
  4.387s  Sizes: [1000, 39035, 9689]

Set size =   6,  Number max =  20   jwpat7 method
  0.041s  Sizes: [200, 1534, 67]
  0.077s  Sizes: [400, 6257, 618]
  0.167s  Sizes: [600, 13908, 2043]
  0.330s  Sizes: [800, 24589, 4620]
  0.590s  Sizes: [1000, 39035, 9689]

在上面的数据中,左栏显示了以秒为单位的执行时间。数字列表显示发生了多少单,双或三联合。程序中的常量指定数据集大小和特征。

#!/usr/bin/python
from random import sample, seed
import time
nsets,   ndelta,  ncount, setsize  = 200, 200, 5, 6
topnum, ranSeed, shoSets, shoUnion = 20, 1234, 0, 0
seed(ranSeed)
print 'Set size = {:3d},  Number max = {:3d}'.format(setsize, topnum)

for casenumber in range(ncount):
    t0 = time.time()
    sets, sizes, ssum = [], [0]*nsets, [0]*(nsets+1);
    for i in range(nsets):
        sets.append(set(sample(xrange(topnum), setsize)))

    if shoSets:
        print 'sets = {},  setSize = {},  top# = {},  seed = {}'.format(
            nsets, setsize, topnum, ranSeed)
        print 'Sets:'
        for s in sets: print s

    # Method by jwpat7
    def accrue(u, bset, csets):
        for i, c in enumerate(csets):
            y = u + [c]
            yield y
            boc = bset|c
            ts = [s for s in csets[i+1:] if boc.isdisjoint(s)]
            for v in accrue (y, boc, ts):
                yield v

    # Method by NPE
    def comb(input, lst = [], lset = set()):
        if lst:
            yield lst
        for i, el in enumerate(input):
            if lset.isdisjoint(el):
                for out in comb(input[i+1:], lst + [el], lset | set(el)):
                    yield out

    # Uncomment one of the following 2 lines to select method
    #for u in comb (sets):
    for u in accrue ([], set(), sets):
        sizes[len(u)-1] += 1
        if shoUnion: print u
    t1 = time.time()
    for t in range(nsets-1, -1, -1):
        ssum[t] = sizes[t] + ssum[t+1]
    print '{:7.3f}s  Sizes:'.format(t1-t0), [s for (s,t) in zip(sizes, ssum) if t>0]
    nsets += ndelta

编辑:在函数accrue中,参数(u, bset, csets)的用法如下:
•u =当前联合组中的集合列表
•bset =“big set”= u的平坦值=已使用的元素
•csets =候选集=有资格列入的集合列表
请注意,如果accrue的第一行替换为
def accrue(csets, u=[], bset=set()):
和第七行由 for v in accrue (ts, y, boc):
(即,如果重新排序参数并为u和bset指定默认值),则可以通过accrue调用[accrue(listofsets)]以生成其兼容联合列表。

关于使用Python 2.6时评论中提到的ValueError: zero length field name in format错误,请尝试以下操作。

# change:
    print "Set size = {:3d}, Number max = {:3d}".format(setsize, topnum)
# to:
    print "Set size = {0:3d}, Number max = {1:3d}".format(setsize, topnum)

程序中的其他格式可能需要类似的更改(添加适当的字段编号)。注意,what's new in 2.6页面说“支持str.format()方法已被反向移植到Python 2.6”。虽然它没有说明是否需要字段名称或数字,但它没有显示没有它们的示例。相比之下,无论哪种方式都适用于2.7.3。

答案 3 :(得分:1)

使用itertools.combinationsset.intersectionfor-else循环:

from itertools import *
lis=[[1, 2, 3], [3, 6, 8], [4, 9], [6, 11]]
def func(lis):
    for i in range(1,len(lis)+1):
       for x in combinations(lis,i):
          s=set(x[0])
          for y in x[1:]:
              if len(s & set(y)) != 0:
                  break
              else:
                  s.update(y)    
          else:
              yield x


for item in func(lis):
    print item

<强>输出:

([1, 2, 3],)
([3, 6, 8],)
([4, 9],)
([6, 11],)
([1, 2, 3], [4, 9])
([1, 2, 3], [6, 11])
([3, 6, 8], [4, 9])
([4, 9], [6, 11])
([1, 2, 3], [4, 9], [6, 11])

答案 4 :(得分:1)

NPE's solution类似,但它没有递归,它返回一个列表:

def disjoint_combinations(seqs):
    disjoint = []
    for seq in seqs:
        disjoint.extend([(each + [seq], items.union(seq))
                            for each, items in disjoint
                                if items.isdisjoint(seq)])
        disjoint.append(([seq], set(seq)))
    return [each for each, _ in disjoint]

for each in disjoint_combinations([[1, 2, 3], [3, 6, 8], [4, 9], [6, 11]]):
    print each

结果:

[[1, 2, 3]]
[[3, 6, 8]]
[[1, 2, 3], [4, 9]]
[[3, 6, 8], [4, 9]]
[[4, 9]]
[[1, 2, 3], [6, 11]]
[[1, 2, 3], [4, 9], [6, 11]]
[[4, 9], [6, 11]]
[[6, 11]]

答案 5 :(得分:0)

不使用itertools包的单行程序。 这是您的数据:

lE={}
lE[0]=[1, 2, 3]
lE[1] = [3, 6, 8]
lE[2] = [4, 9]
lE[4] = [6, 11]

这是单线:

results=[(lE[v1],lE[v2]) for v1 in lE for v2  in lE if (set(lE[v1]).isdisjoint(set(lE[v2])) and v1>v2)]
相关问题