有没有更好的方法来重写这个NumPy Snippet

时间:2014-01-17 11:16:08

标签: python numpy refactoring

我有以下Python(NumPy),我想重构它更干净(可能更快):

temp = max(value for (x, y), value in np.ndenumerate(cm) if x * y < 100 and (x, y) != (0, 0) and not np.isnan(value))

我认为很清楚我想做什么。总而言之,我尝试根据它的值和索引的某些条件来过滤2D数组的一些元素。

感谢任何帮助。

1 个答案:

答案 0 :(得分:5)

import numpy as np
from numpy.random import rand, randint 

cm = rand(50, 100)
cm[randint(0, 50, 4000), randint(0, 100, 4000)] = np.nan

temp1 = max(value for (x, y), value in np.ndenumerate(cm) if x * y < 100 and (x, y) != (0, 0) and not np.isnan(value))

x, y = np.indices(cm.shape)
mask = (x * y < 100) & (x + y != 0) & (~np.isnan(cm))
temp2 = np.max(cm[mask])

assert temp1 == temp2

修改

代表max(x+y * value)

np.max((x + y * cm)[mask])

np.max(x[mask] + y[mask] * cm[mask])