Python平衡列表/ numpy数组中的项目

时间:2019-08-13 09:27:18

标签: python list numpy

我有一个令牌数组,每个令牌对应于从1n的不同类。我需要 balance tokens数组/列表,以便每个类有相等数量的标记。我想通过删除tokens的元素来做到这一点。

在下面的示例中,令牌数量最少的类是class 2,它只有2个令牌。因此,我想从其他类中删除元素,直到它们的数量也为2

例如

tokens  = array(['a','b','c','d','e','f','g','h','l'])

classes = array([ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3])

在此示例中,类以升序排列(为了清楚起见),但实际上,这些类没有特定的顺序。

例如

sol = array(['c','d','e','f','g','h'])

sol = array(['a','b','e','f','g','h'])

很明显,因为您可以选择要删除的多余元素,所以可以有不同的解决方案(如上)。我需要一个可以使用tokensclasses并输出sol的函数。

4 个答案:

答案 0 :(得分:2)

使用Counter的解决方案:

tokens = ['a','b','c','d','e','f','g','h','l']
lst    = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]

from collections import Counter

c = Counter(lst)
min_cnt = min(c.values())
new_lst = list( zip(tokens, lst) )

while True:
    tmp = []
    should_break = True
    for t, i in new_lst:
        if c[i] > min_cnt:
            c[i] -= 1
            should_break = False
        else:
            tmp.append( (t, i) )

    new_lst = tmp

    if should_break:
        break

print([t for t, _ in new_lst])

打印:

['c', 'd', 'e', 'f', 'h', 'l']

使用groupby的其他可能解决方案:

tokens = ['a','b','c','d','e','f','g','h','l']
lst    = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]

from collections import Counter
from itertools import groupby, islice

c = Counter(lst)
min_cnt = min(c.values())

out = []
for v, g in groupby(sorted(enumerate(zip(tokens, lst)), key=lambda k: k[1][1]), lambda k: k[1][1]):
    out.extend(islice(g, 0, min_cnt))

print( [val for _, (val, _) in sorted(out, key=lambda k: k[0])] )

打印:

['a', 'b', 'e', 'f', 'g', 'h']

答案 1 :(得分:1)

这是使用NumPy做到这一点的一种方法。这将始终选择每个类的第一个外观。

import numpy as np

def balance(tokens, classes):
    # Count appearances of each class
    c = np.bincount(classes - 1)
    n = c.min()
    # Accumulated counts for each class shifted one position
    cs = np.roll(np.cumsum(c), 1)
    cs[0] = 0
    # Compute appearance index for each class
    i = np.arange(len(classes)) - cs[classes - 1]
    # Mask excessive appearances
    m = i < n
    # Return corresponding tokens
    return tokens[m]

tokens  = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'l'])
classes = np.array([  1,   1,   1,   1,   2,   2,   3,   3,   3])
print(balance(tokens, classes))
# ['a' 'b' 'e' 'f' 'g' 'h']

就目前而言,当某些类完全丢失时(因为最小出现次数为零,因此解决方案中不会出现类),该函数将返回一个空数组,但是您可以根据需要进行调整。

>

答案 2 :(得分:1)

又一个简短的解决方案:

import random
from itertools import chain
from operator import itemgetter
import toolz

tokens  = ['a','b','c','d','e','f','g','h','l']
classes = [ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3]

groups = toolz.groupby(itemgetter(1), zip(tokens, classes))
max_size = len(min(groups.values(), key=len))
random_samples = chain.from_iterable(map(lambda x: random.sample(x, k=max_size), list(groups.values())))

chosen_tokens, corresponding_classes = list(zip(*random_samples))

或完全使用buildins个模块

import random
from itertools import chain, groupby, tee
from operator import itemgetter

tokens = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'l']
classes = [1, 1, 1, 1, 2, 2, 3, 3, 3]

groups_for_max_size, groups = tee(groupby(zip(tokens, classes), itemgetter(1)), 2)
max_size = len(min(groups_for_max_size, key = len))

random_samples = chain.from_iterable(map(lambda x: random.sample(list(x[1]), k = max_size), groups))
chosen_tokens, corresponding_classes = list(zip(*random_samples))

编辑:我认为还有一个更短的解决方案:

from itertools import chain, groupby
from operator import itemgetter

groups = (sorted(tokens, key=lambda x: random.random()) 
          for _, tokens in groupby(zip(tokens, classes), itemgetter(1)))
chosen_tokens, corresponding_classes = zip(*chain.from_iterable(zip(*groups)))

只有两个步骤:1.确保每个组的列表都是随机的(这在sorted(tokens, key=lambda x: random.random())中神奇地发生了,因为排序键始终是一个随机值)。 2.同样重要的是要知道zip对元素进行采样,直到用尽最短的生成器为止(这使该解决方案变得如此之短)。 zip(*groups)是一个迭代器,它在每次迭代中检索三元组(3个类)。由于我们事先对列表进行了混洗,因此对它们进行了随机采样。如果我们要再次分隔标记和类,则将三元组连接起来并再次解压缩。

答案 3 :(得分:1)

使用Counter的另一种解决方案:

import random
from collections import Counter

tokens  = np.array(['a','b','c','d','e','f','g','h','l'])
classes = np.array([ 1 , 1 , 1 , 1 , 2 , 2 , 3 , 3 , 3])

def sampling(tokens, classes):
    dc = {}
    sol = []
    for i in range(len(classes)):
        if classes[i] in dc:
            dc[classes[i]].append(tokens[i])
        else:
            dc[classes[i]] = [tokens[i]]
    sample_counts = Counter(classes)
    min_sample = min(sample_counts.values())
    for i in dc:
        sol += (random.sample(dc[i],min_sample))
    return sol

print(sampling(tokens, classes))

>>> ['d', 'a', 'f', 'e', 'g', 'h']
相关问题