在元组列表中分组和查找相同的元素?

时间:2017-02-16 19:26:26

标签: python list grouping

我有几个这样的元组列表(它们代表坐标):

a = [(100,100), (50,60)]
b = [(100,50), (50,60)]
c = [(100,100), (20,60)]
d = [(70,100), (50,10)]
e = [(100,80), (70,100)]

我想知道如何有效地管理它们以查找相同的值,然后将整个列表存储在单独的列表中。

由于它们是坐标,因此每个元组的X和Y在另一个元组中不能相同(即相同的X,但是不同的Y)。

对于上面的例子,我想最终得到这样的东西(作为列表,但如果可能的话,也会以更有效的方式):

new_list1 = [a, b, c]
new_list2 = [d, e]

如果没有在多个列表之间进行一对一解析,是否有更有效的方法来获得此结果?

1 个答案:

答案 0 :(得分:1)

好的,这是一个numpy tastic vectorised方法。似乎相当快。虽然没有彻底测试。它显式假设每个列表有两个元组,每个元组2坐标。

import time
import numpy as np

def find_chains_nmp(lists):
    lists = np.asanyarray(lists)
    lists.shape = -1,2
    dtype = np.rec.fromrecords(lists[:1, :]).dtype
    plists = lists.view(dtype)
    lists.shape = -1, 2, 2
    uniq, inv = np.unique(plists, return_inverse=True)
    uniqf = uniq.view(lists.dtype).reshape(-1, 2)
    inv.shape = -1, 2
    to_flip = inv[:, 0] > inv[:, 1]
    inv[to_flip, :] = inv[to_flip, ::-1].copy()
    sl = np.lexsort(inv.T[::-1])
    sr = np.lexsort(inv.T)
    lj = inv[sl, 0].searchsorted(np.arange(len(uniq)+1))
    rj = inv[sr, 1].searchsorted(np.arange(len(uniq)+1))
    mask = np.ones(uniq.shape, bool)
    mask[0] = False
    rooted = np.zeros(uniq.shape, int)
    l, r = 0, 1
    blocks = [0]
    rblocks = [0]
    reco = np.empty_like(lists)
    reci = 0
    while l < len(uniq):
        while l < r:
            ll = r
            for c in rooted[l:r]:
                if (rj[c]==rj[c+1]) and (lj[c]==lj[c+1]):
                    continue
                connected = np.r_[inv[sr[rj[c]:rj[c+1]], 0],
                                  inv[sl[lj[c]:lj[c+1]], 1]]
                reco[reci:reci+lj[c+1]-lj[c]] = uniqf[inv[sl[lj[c]:lj[c+1]], :]]
                reci += lj[c+1]-lj[c]
                connected = np.unique(connected[mask[connected]])
                mask[connected] = False
                rr = ll + len(connected)
                rooted[ll:rr] = connected
                ll = rr
            l, r = r, rr
        blocks.append(l)
        rblocks.append(reci)
        if l == len(uniq):
            break
        r = l + 1
        rooted[l] = np.where(mask)[0][0]
        mask[rooted[l]] = 0
    return blocks, rblocks, reco, uniqf[rooted]


# obsolete
def find_chains(lists):
    outlist = []
    outinds = []
    outset = set()
    for j, l in enumerate(lists):
        as_set = set(l)
        inds = []
        for k in outset.copy():
            if outlist[k] & as_set:
                outset.remove(k)
                as_set |= outlist[k]
                inds.extend(outinds[k])
        outset.add(j)
        outlist.append(as_set)
        outinds.append(inds + [j])
    outinds = [outinds[j] for j in outset]
    del outset, outlist
    result = [[lists[j] for j in k] for k in outinds]
    return result, outinds


if __name__ == '__main__':
    a = [(100,100), (50,60)]
    b = [(100,50), (50,60)]
    c = [(100,100), (20,60)]
    d = [(70,100), (50,10)]
    e = [(100,80), (70,100)]

    lists = [a, b, c, d, e]
    print(find_chains(lists))



    lists = np.array(lists)
    tblocks, lblocks, lreco, treco = find_chains_nmp(lists)


    coords = np.random.random((12_000, 2))
    pairs = np.random.randint(0, len(coords), (12_000, 2))
    pairs = np.delete(pairs, np.where(pairs[:, 0] == pairs[:, 1]), axis=0)
    pairs = coords[pairs, :]
    t0 = time.time()
    tblocks, lblocks, lreco, treco = find_chains_nmp(pairs)
    t0 = time.time() - t0
    print('\n\nproblem:')
    print('\n\ntuples {}, lists {}'.format(len(coords), len(pairs)))
    if len(pairs) < 40:
        for k, l in enumerate(pairs):
            print('[({:0.6f}, {:0.6f}), ({:0.6f}, {:0.6f})]    '
                  .format(*l.ravel()), end='' if k % 2 != 1 else '\n')
    print('\n\nsolution:')
    for j, (lists, tuples) in enumerate(zip(
            np.split(lreco, lblocks[1:-1], axis=0),
            np.split(treco, tblocks[1:-1], axis=0))):
        print('\n\ngroup #{}: {} tuples, {} list{}'.format(
            j + 1, len(tuples), len(lists),
            's' if len(lists) != 1 else ''))
        if len(pairs) < 40:
            print('\ntuples:')
            for k, t in enumerate(tuples):
                print('({:0.6f}, {:0.6f})    '.format(*t),
                      end='' if k % 4 != 3 else '\n')
            print('\nlists:')
            for k, l in enumerate(lists):
                print('[({:0.6f}, {:0.6f}), ({:0.6f}, {:0.6f})]    '
                      .format(*l.ravel()), end='' if k % 2 != 1 else '\n')
    print('\n\ncomputation time', t0)
相关问题