在Numpy / PyTorch

时间:2018-04-26 14:10:40

标签: python numpy pytorch

任务

给定numpy或pytorch矩阵,找到值大于给定阈值的单元格的索引。

我的实施

#abs_cosine is the matrix
#sim_vec is the wanted

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))

关注

速度即可。所有其他计算都是基于Pytorch构建的,使用numpy已经是一种妥协,因为它已经将计算从GPU转移到了CPU。纯python for循环将使整个过程更加糟糕(对于已经慢5倍的小数据集)。 我想知道我们是否可以在不调用任何for循环的情况下将整个计算移动到Numpy(或pytorch)?

我能想到的改进(但卡住了......)

  

bool_cosine = abs_cosine>阈值

返回TrueFalse的布尔矩阵。但我找不到快速检索True单元格索引的方法。

2 个答案:

答案 0 :(得分:3)

以下是PyTorch(完全在GPU上)

mask = 1 - np.diag(np.ones(abs_cosine.shape[0]))
sim_vec = np.nonzero((abs_cosine >= 0.2)*mask)
# sim_vec is a 2-array tuple where the first array is the row index and the second array is column index

以下适用于numpy

{{1}}

答案 1 :(得分:0)

这大约是np.where

的两倍
url = "https://www.googleapis.com/oauth2/v1/tokeninfo?fields=email&access_token
    =" + token_from_the_request_header;

第一次调用需要大约0.2秒(编译开销)。如果阵列在GPU上,则可能还有一种方法可以在GPU上进行整个计算。

尽管如此,我对性能并不满意,因为简单的布尔操作比上面显示的解决方案快5倍,比np.where快10倍。如果索引的顺序无关紧要,这个问题也可以并行化。