我有一个numpy数组,想要"望远镜"基于顶行的值。一个例子是描述它的最佳方式
启动数组:
9 9 8 7 7 7 6
1 2 3 4 5 6 3
3 4 5 6 7 6 3
5 6 7 8 9 6 4
所需的输出数组:
9 8 7 6
3 3 15 3
7 5 19 3
11 7 23 4
这个想法是唯一的 - 在顶行中按值分组的后续行中的顶行和总和值。顶行将被排序,阵列将是大约2000个单元宽和200,000个单元长。顶行中可以有任意数量的连续相同数字。我当前的黑客就是这个(示例中的顶行标签略有不同,我打印到屏幕而不是创建最终数组以检查输出。计划是堆叠输出以生成输出数组)
import numpy as N
kk=N.array([[90,90,85,80,80,80,70],[1,2,3,4,5,6,3],[3,4,5,6,7,6,3],[5,6,7,8,9,6,4]])
ll=kk[:,0]
for i in range(1,len(kk[0])):
if kk[0][i]==kk[0][i-1]:
ll=ll+kk[:,i]
elif kk[0][i]!=kk[0][i-1]:
print "sum=", ll, i,kk[0][i],kk[0][i-1]
ll=kk[:,i]
有两个缺点。主要的一点是,它没有处理最后一栏,我不明白为什么。次要的是,它也是排在第一行的。很明显,为什么这个小问题正在发生。我怀疑我可以在那个方面找到方法,但是未能处理最后一栏一直让我感到沮丧,我真的很感激任何处理它的建议。
感谢您的帮助
答案 0 :(得分:4)
如果你有200,000
行,那么Python循环可能会非常慢。使用NumPy,您可以使用np.add.reduceat
向量化该操作,但首先需要创建一个数组,其中包含第一行中每组重复条目的第一项的索引:
mask = np.concatenate(([True], kk[0, 1:] != kk[0, :-1]))
indices, = np.nonzero(mask)
然后,您可以通过使用mask
布尔数组索引它来获取第一行:
>>> kk[0, mask]
array([90, 85, 80, 70])
以及使用reduceat
indices
的数组的其余部分:
>>> np.add.reduceat(kk[1:], indices, axis=1)
array([[ 3, 3, 15, 3],
[ 7, 5, 19, 3],
[11, 7, 23, 4]])
假设您的原始数组是默认的整数类型,您可以通过执行以下操作来组装数组:
out = np.empty((kk.shape[0], len(indices)), dtype=kk.dtype)
out[0] = kk[0, mask]
np.add.reduceat(kk[1:], indices, axis=1, out=out[1:])
>>> out
array([[90, 85, 80, 70],
[ 3, 3, 15, 3],
[ 7, 5, 19, 3],
[11, 7, 23, 4]])
答案 1 :(得分:2)
你应该使用numpy中的独特功能
import numpy as np
a = np.array([[90,90,85,80,80,80,70],[1,2,3,4,5,6,3],[3,4,5,6,7,6,3],[5,6,7,8,9,6,4]])
u, v = np.unique(a[0], return_inverse=True)
output = np.zeros((a.shape[0], u.shape[0]))
output[0] = u.copy()
for i in xrange(u.shape[0]):
pos = np.where(v==i)[0]
output[1:,i] = np.sum(a[1:,pos], axis=1)
您应该注意到u
将从最低到最高排序。如果你想要它从最高到最低,你必须做
output = output[:,::-1]
最后。
答案 2 :(得分:1)
您可以使用groupby
:
from itertools import groupby
import numpy as N
kk=N.array([[90,90,85,80,80,80,70],[1,2,3,4,5,6,3],[3,4,5,6,7,6,3],[5,6,7,8,9,6,4]])
keys = kk[0]
vals = kk[1:]
uniq = map(lambda x: x[0], groupby(keys))
new = [uniq]
for row in vals:
new.append([sum(map(lambda x: x[1], group)) for _, group in groupby(zip(keys, row), lambda x: x[0])])
print N.array(new)
提供输出:
[[90 85 80 70]
[ 3 3 15 3]
[ 7 5 19 3]
[11 7 23 4]]