火炬损失信息

时间:2018-06-25 23:37:32

标签: regression pytorch

我正在尝试使用1个特征进行简单的线性回归。这是一个简单的“根据多年经验预测薪水”问题。 NN会训练年经验(X)和薪水(Y)。 由于某种原因,损失激增,最终返回infnan

这是我的代码:

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

dataset = pd.read_csv('./salaries.csv')

x_temp = dataset.iloc[:, :-1].values
y_temp = dataset.iloc[:, 1:].values

X_train = torch.FloatTensor(x_temp)
Y_train = torch.FloatTensor(y_temp)

class Model(torch.nn.Module): 
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1,1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = Model()

loss_func = torch.nn.MSELoss(size_average=False)
optim = torch.optim.SGD(model.parameters(), lr=0.01)

#training 
for epoch in range(200):
    #calculate y_pred
    y_pred = model(X_train)

    #calculate loss
    loss = loss_func(y_pred, Y_train)
    print(epoch, "{:.2f}".format(loss.data))

    #backward pass + update weights
    optim.zero_grad()
    loss.backward()
    optim.step()


test_exp = torch.FloatTensor([[8.0]])
print("8 years experience --> ", model(test_exp).data[0][0].item())

正如我提到的,一旦开始训练,损失就会变得非常大,最终在第10个时期之后显示inf

我怀疑这可能与我如何加载数据有关?这就是salaries.csv文件中的内容:

Years Salary
1.1 39343
1.3 46205
1.5 37731
2   43525
2.2 39891
2.9 56642
3   60150
3.2 54445
3.2 64445
3.7 57189
3.9 63218
4   55794
4   56957
4.1 57081
4.5 61111
4.9 67938
5.1 66029
5.3 83088

谢谢您的帮助

4 个答案:

答案 0 :(得分:5)

一旦经过一定的损失后,损失变为inf,则模型在反向传播后会损坏。这可能是因为“薪水”列中的值太大。尝试使工资标准化。

或者,您可以尝试手动初始化参数(而不是让它随机初始化),让偏差项为薪水的平均值,而直线的斜率为0(例如)。这样,初始模型将足够接近最佳解,因此损失不会增加。

答案 1 :(得分:0)

以下是所有情况的示例。您可以尝试运行该程序,该程序基本上表示r-深层网络。

import torch
import math
import matplotlib.pyplot as plt
def stat(t, p=True):
    m = t.mean()
    s = t.std()
    if p==True:
        print(f"MEAN: {m}, STD: {s}")
    return(m,s)

_m = []
_s = []

c = 100
r = 50# repeat steps
x = torch.randn(c)
m = torch.randn(c,c)#/math.sqrt(n)
stat(x)

for _ in range (0,r):
    x = m@x    
    _1, _2 = stat(x, False)
    _m.append(_1)
    _s.append(_2)


stat(x)

plt.plot(_m)
plt.plot(_s)
plt.legend(["mean","std"])
plt.show()

enter image description here

答案 2 :(得分:0)

请将学习率“lr”降低到 0.001 或 0.0001。较大的 lr 值会使梯度爆炸并导致 inf。我已经尝试过 lr=0.001 和 lr=0.0001 它对我来说很好用。请尝试一次并告诉我。

答案 3 :(得分:0)

获得 import pyspark.sql.functions as F df1 = df.withColumn( 'CONV_ID', F.split('URL', '(?<=conversations/)')[1] # just using 'conversations/' should also be enough ) df1.show() +---+--------------------+----------+ | ID| URL| CONV_ID| +---+--------------------+----------+ | 1|https://app.xyz.c...|2686735685| | 2|https://app.xyz.c...|2938415796| | 3|https://app.drift...|2938419189| +---+--------------------+----------+ 损失的另一种可能性是包含 nan 值的模型的输入张量。尝试从模型输入中过滤 nan 值。

相关问题