Numpy:二维数组的元素绝对最大值的有符号值

时间:2017-06-06 05:15:47

标签: python performance python-3.x numpy max

让我们假设我有一个名为arr形状(4,3)的2D数组,如下所示:

>>> arr
array([[ nan,   1., -18.],
   [ -1.,  -1.,  -1.],
   [  1.,   1.,   5.],
   [  1.,  -1.,   0.]])

这就是说,我想将元素明确的绝对最大值(1.0, 1.0, -15.0)和行arr[[0, 2], :]的签名值分配回arr。这意味着,我正在寻找输出:

>>> arr
array([[ 1.,   1.,  -18.],
   [ -1.,  -1.,  -1.],
   [  1.,   1., -15.],
   [  1.,  -1.,   0.]])

我在API参考中找到的最接近的是numpy.fmax,但它没有做绝对值。如果我用过:

arr[index_list, :] = np.fmax(arr[index_list, :], new_tuple)

我的数组终于看起来像了:

>>> arr
array([[ 1.,   1., -15.],
   [ -1.,  -1.,  -1.],
   [  1.,   1.,   5.],
   [  1.,  -1.,   0.]])

现在,API说这个函数是

  当x1和x2都不是NaN时,

相当于np.where(x1 >= x2, x1, x2),但它更快并且正确播放

我尝试使用以下内容:

arr[index_list, :] = np.where(np.absolute(arr[index_list, :]) >= np.absolute(new_tuple), 
                              arr[index_list, :], new_tuple)

虽然这产生了所需的输出,但我得到了警告:

  

/ Applications / PyCharm CE.app/Contents/helpers/pydev/pydevconsole.py:1:RuntimeWarning:在greater_equal中遇到无效值

我相信这个警告是因为NaN在这里没有得到优雅处理,与np.fmax函数不同。此外,API文档提到np.fmax更快并且正确广播(不确定np.where版本中缺少广播的哪个部分)

总之,我要找的东西类似于:

arr[index_list, :] = np.fmax(arr[index_list, :], new_tuple, key=abs)

遗憾的是,此功能没有可用的key属性。

仅仅是为了上下文,我对尽可能快的解决方案感兴趣,因为arr数组的实际形状是(100000,50)的平均值,我循环了近1000 new_tuple元组(当然,每个元组的形状与arr中的列数相等)。每个index_list的{​​{1}}更改。

编辑1:

一种可能的解决方案是,首先用new_tuple替换arr中的所有NaN。即0。在此之后,我可以使用原始文本中提到的arr[np.isnan(arr)] = 0np.where技巧。但是,根据API的建议,这可能比np.absolute慢很多。

编辑2:

np.fmax可能在后续循环中重复索引。每个index_list都附带相应的规则,并根据该规则选择new_tuple。没有什么能阻止不同规则与它们匹配的重叠索引。 @Divakar对于index_list没有重复的情况有一个很好的答案。然而,欢迎其他解决方案涵盖两种情况。

1 个答案:

答案 0 :(得分:1)

假设所有index_list的列表没有重复的索引:

方法#1

一旦我们将所有index_listsnew_tuples存储在一个地方,我会提出更多的矢量化解决方案,最好是作为列表。因此,如果我们处理许多这样的元组和列表,这可能是首选的。

所以,让我们说它们存储如下:

new_tuples = [(1.0, 1.0, -15.0), (6.0, 3.0, -4.0)] # list of all new_tuple
index_lists =[[0,2],[4,1,6]]  # list of all index_list

此后的解决方案是手动重复,替换广播,然后使用np.where,如问题中稍后所示。如果np.where有非NaN值,我们可以忽略对所述警告的关注new_tuples。因此,解决方案是 -

idx = np.concatenate(index_lists)
lens = list(map(len,index_lists))

a = arr[idx]
b = np.repeat(new_tuples,lens,axis=0)
arr[idx] = np.where(np.abs(a) > np.abs(b), a, b)

方法#2

另一种方法是在arr之前存储abs_arr = np.abs(arr)的绝对值:np.where并使用arr[index_list, :] = np.where(abs_arr[index_list, :] > np.abs(b), a, new_tuple) 内的绝对值。这应该在循环中节省很多时间。因此,相关计算将减少到:

replace into tracks
  (rowid, s3_url, track_id, cluster_id, rank, group_id, artist_name, track_name, set_name, file_size)
select 
  t.rowid, t.s3_url, t.track_id, t.cluster_id, t.rank, t.group_id, n.artist_name, n.track_name, t.set_name, t.file_size
from names n
inner join tracks t on n.track_id = t.track_id
;
相关问题