找到一种更好的计算矩阵的方法

时间:2014-01-15 21:52:56

标签: python performance algorithm math numpy

我想计算只有1和0条目的2d数组的数量,这些条目具有相等的一对不相交的行,它们具有相等的向量和。对于4乘4矩阵,下面的代码通过迭代所有这些并依次测试每个代码来实现这一目的。

import numpy as np
from itertools import combinations
n = 4
nxn = np.arange(n*n).reshape(n, -1)
count = 0
for i in xrange(2**(n*n)):
   A = (i >> nxn) %2
   p = 1
   for firstpair in combinations(range(n), 2):
       for secondpair in combinations(range(n), 2):
           if firstpair < secondpair and not set(firstpair) & set(secondpair):
              if (np.array_equal(A[firstpair[0]] + A[firstpair[1]], A[secondpair[0]] + A[secondpair[1]] )):
                  if (p):
                      count +=1
                      p = 0
print count

输出为3136。

这个问题是它使用了2 ^(4 ^ 2)次迭代,我想将它运行为n到8次。是否有更聪明的方法来计算这些而不迭代所有矩阵?例如,一遍又一遍地创建相同矩阵的排列似乎毫无意义。

3 个答案:

答案 0 :(得分:8)

使用CPython 3.3在我的机器上大约一分钟计算:

4 3136
5 3053312
6 7247819776
7 53875134036992
8 1372451668676509696

代码,基于记忆包含 - 排除:

#!/usr/bin/env python3
import collections
import itertools

def pairs_of_pairs(n):
    for (i, j, k, m) in itertools.combinations(range(n), 4):
        (yield ((i, j), (k, m)))
        (yield ((i, k), (j, m)))
        (yield ((i, m), (j, k)))

def columns(n):
    return itertools.product(range(2), repeat=n)

def satisfied(pair_of_pairs, column):
    ((i, j), (k, m)) = pair_of_pairs
    return ((column[i] + column[j]) == (column[k] + column[m]))

def pop_count(valid_columns):
    return bin(valid_columns).count('1')

def main(n):
    pairs_of_pairs_n = list(pairs_of_pairs(n))
    columns_n = list(columns(n))
    universe = ((1 << len(columns_n)) - 1)
    counter = collections.defaultdict(int)
    counter[universe] = (- 1)
    for pair_of_pairs in pairs_of_pairs_n:
        mask = 0
        for (i, column) in enumerate(columns_n):
            mask |= (int(satisfied(pair_of_pairs, column)) << i)
        for (valid_columns, count) in list(counter.items()):
            counter[(valid_columns & mask)] -= count
    counter[universe] += 1
    return sum(((count * (pop_count(valid_columns) ** n)) for (valid_columns, count) in counter.items()))
if (__name__ == '__main__'):
    for n in range(4, 9):
        print(n, main(n))

答案 1 :(得分:3)

你可以在“比什么都没有”的情况下提交这个;-)这是一个简单的Python3代码,可以重新思考一下这个问题。也许笨拙的技巧可以大大提高它,但很难看出如何。

  1. 这里的“一行”是range(2**n)中的整数。所以数组只是一个整数元组。
  2. 因此,通过combinations_with_replacement()生成在行排列下唯一的所有数组都很容易。这会将外部循环上的行程计数从2**(n**2)减少到(2**n+n-1)-choose-n)。一个巨大的减少,但仍然......
  3. 预先计算的dict将行对(这里表示整数对!)映射到它们的向量和作为元组。因此,测试时不需要数组操作,除了测试元组是否相等。有了更多的技巧,元组可以编码为(比如)base-3整数,减少内循环测试以比较从一对dict查找中检索到的两个整数。
  4. 该预先计算的字典所需的时间和空间相对微不足道,因此没有尝试加速该部分。
  5. 内部循环一次选取行索引4,而不是每次选择两个索引的一对循环。在一口气中做所有4个更快,很大程度上是因为没有必要用重复的索引来清除对。
  6. 以下是代码:

    def calc_row_pairs(n):
        fmt = "0%db" % n
        rowpair2sum = dict()
        for i in range(2**n):
            row1 = list(map(int, format(i, fmt)))
            for j in range(2**n):
                row2 = map(int, format(j, fmt))
                total = tuple(a+b for a, b in zip(row1, row2))
                rowpair2sum[i, j] = total
        return rowpair2sum
    
    def multinomial(n, ks):
        from math import factorial as f
        assert n == sum(ks)
        result = f(n)
        for k in ks:
            result //= f(k)
        return result
    
    def count(n):
        from itertools import combinations_with_replacement as cwr
        from itertools import combinations
        from collections import Counter
        rowpair2sum = calc_row_pairs(n)
        total = 0
        class NextPlease(Exception):
            pass
        for a in cwr(range(2**n), n):
            try:
                for ix in combinations(range(n), 4):
                    for ix1, ix2, ix3, ix4 in (
                           ix,
                           (ix[0], ix[2], ix[1], ix[3]),
                           (ix[0], ix[3], ix[1], ix[2])):
                        if rowpair2sum[a[ix1], a[ix2]] == \
                           rowpair2sum[a[ix3], a[ix4]]:
                            total += multinomial(n, Counter(a).values())
                            raise NextPlease
            except NextPlease:
                pass
        return total
    

    这足以通过n = 6找到结果,虽然花了很长时间才能完成最后一次(多长时间?不知道 - 没有时间 - 大约一个小时,但是 - 很长时间“是相对的;-)):

    >>> count(4)
    3136
    >>> count(5)
    3053312
    >>> count(6)
    7247819776
    

    编辑 - 删除一些不必要的索引

    通过将主要功能更改为此来获得一个很好的加速:

    def count(n):
        from itertools import combinations_with_replacement as cwr
        from itertools import combinations
        from collections import Counter
        rowpair2sum = calc_row_pairs(n)
        total = 0
        for a in cwr(range(2**n), n):
            for r0, r1, r2, r3 in combinations(a, 4):
                if rowpair2sum[r0, r1] == rowpair2sum[r2, r3] or \
                   rowpair2sum[r0, r2] == rowpair2sum[r1, r3] or \
                   rowpair2sum[r0, r3] == rowpair2sum[r1, r2]:
                    total += multinomial(n, Counter(a).values())
                    break
        return total
    

    编辑 - 加快总和测试

    这是次要的,但由于这似乎是迄今为止表中最精确的方法,不妨从中榨取更多。如前所述,由于每个和都在range(3)中,因此每个元组可以用整数替换(将元组视为给出基数为3的整数的数字)。像这样替换calc_row_pairs()

    def calc_row_pairs(n):
        fmt = "0%db" % n
        rowpair2sum = dict()
        for i in range(2**n):
            row1 = list(map(int, format(i, fmt)))
            for j in range(2**n):
                row2 = map(int, format(j, fmt))
                total = 0
                for a, b in zip(row1, row2):
                    t = a+b
                    assert 0 <= t <= 2
                    total = total * 3 + t
                rowpair2sum[i, j] = total
        return rowpair2sum
    

    我确定numpy有更快的方法可以做到这一点,但是calc_row_pairs()花费的时间微不足道,为什么还要费心呢?顺便说一下,这样做的好处是内循环==测试从需要比较元组变为仅仅比较小整数。简单的Python从中受益,但我敢打赌,pypy可能会受益更多。

答案 2 :(得分:2)

不是你问题的直接答案,但正如我所指出的,我认为你可以安全地忘记对任何重要的n进行彻底测试所有矩阵。但问题很适合随机表征。有趣的是,在某些条件下,三次总和比双次总和更常见!获得命中的可能性似乎是n和m的相当简单(单调)函数,但没有惊喜。

double sums; n,m = 4..20

triple sums; n,m = 4..20

相关问题