我在哪里可以看到pytorch的MSELoss的源代码?

时间:2018-01-12 03:49:37

标签: python deep-learning pytorch loss-function

我使用U-NET网络来训练我的数据。 但我需要修改其损失函数以减少低于1的像素损失,以减少负面情况对网络权重的影响。但是我在pycharm MSELOSS中打开了源代码,请看:

SELECT 3 > 0 AS C1 FROM T WHERE T.C2 = 'text'

我无法获得任何有用的东西。

1 个答案:

答案 0 :(得分:1)

你去了:https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L1423然而,它调用了C api

def mse_loss(input, target, size_average=True, reduce=True):
    """
    mse_loss(input, target, size_average=True, reduce=True) -> Variable
    Measures the element-wise mean squared error.
    See :class:`~torch.nn.MSELoss` for details.
    """
    return _pointwise_loss(lambda a, b: (a - b) ** 2, torch._C._nn.mse_loss,
input, target, size_average, reduce)

def own_mse_loss(input, target, size_average=True):
    L = (input - target) ** 2
    return torch.mean(L) if size_average else torch.sum(L)
相关问题