矩阵计算的有效方法

时间:2019-12-25 13:02:22

标签: python pytorch

现在我有一个张量farthest_idxs的大小(批处理,特征)=(24,32)。我还有一个大小(点,批处理,特征)的张量Nearest_idxs =(1024、24、1023)。对于一个点p和一个样本s(即,大小为1x1023的near_idxs [p,s ,:]),我想在此向量中找到在farthest_idxs [s,:](大小为1x32)中的第一个元素,并返回一个记录结果的矩阵(大小为24x1024)。有什么有效的方法可以实现吗?

这是我的代码,这是一种无效的实现方式。

def nearest_indices(self, relation, farthest_idxs):
    '''Generate the nearest indices
        return:
            [B, N] matrix
    '''
    device = relation.device
    nearest_value, nearest_idxs = torch.topk(relation, k=1023, dim=2, largest=False, sorted=True)
    print('nearest idxs', nearest_idxs)
    nearest_idxs = nearest_idxs.transpose(0,1) # 1024x24x1023
    print('transposed nearest_idxs', nearest_idxs.shape)
    N, B, P = nearest_idxs.shape
    upsample_idxs = torch.zeros((B, N), dtype=torch.long).to(device)
    for n in range(N):
        for b in range(B):
            for p in range(P):
                if nearest_idxs[n, b, p] in farthest_idxs[b,:]:
                    upsample_idxs[b, n] = nearest_idxs[n, b, p]
                    break
    print(upsample_idxs.shape)

0 个答案:

没有答案