Pytorch自定义损失函数(降噪损失函数)

时间:2019-10-23 09:13:42

标签: python pytorch loss-function noise-reduction

我正在尝试在PyTorch中编写自定义损失函数(降噪损失)。它与交叉熵损失非常相似,不同之处在于,它假定预测的答案中的某些标签不正确,从而使它对所预测的答案具有一定的置信度(预测矩阵中的最高概率)。这里pred表示预测的[m * L]矩阵,其中m是示例数,L是标签数,y_true是实际标签的[m * 1]矩阵,“ ro”是决定每个标签的影响的超参数所使用的两个标准中的一个。

def lossNR(pred, y_true, ro):
    outputs = torch.log(pred)   # compute the log of softmax values
    out1 = outputs.gather(1, y_true.view([-1,1])) # pick the values corresponding to the labels
    l1 = -((ro)* torch.mean(out1))
    l2 = -(1-ro) * torch.mean((torch.max(outputs,1)[0]))
    print("l1=", l1)
    print("l2 = ", l2)
    return (l1+l2)

我在各种数据集上尝试了损失函数,但效果不好。请提供建议。

0 个答案:

没有答案