PyTorch - batches not updating

Time: 2018-12-04 04:42:15

Tags: python-3.x pytorch

I'm training my model, but something is going wrong. The DataLoader should be moving on to new batches, but this is what I get:

Output

  • [0, 0] loss: 0.009
  • [1, 0] loss: 0.009
  • [2, 0] loss: 0.009
  • [3, 0] loss: 0.009
  • [4, 0] loss: 0.009
  • [5, 0] loss: 0.009
  • [6, 0] loss: 0.009
  • [7, 0] loss: 0.009
  • [8, 0] loss: 0.009
  • [9, 0] loss: 0.009

Does anyone know what's wrong? My code is below.

DataLoader

import json

import torch
from torch.utils.data import Dataset

class MSourceDataSet(Dataset):

    def __init__(self, clean_dir, mix_dir, clean_label_dir, mix_label_dir):

        with open(clean_dir + 'clean1.json') as f:
            clean0 = torch.Tensor(json.load(f))

        with open(mix_dir + 'mix1.json') as f:
            mix0 = torch.Tensor(json.load(f))

        with open(clean_label_dir + 'clean_label1.json') as f:
            clean_label0 = torch.Tensor(json.load(f))


        with open(mix_label_dir + 'mix_label1.json') as f:
            mix_label0 = torch.Tensor(json.load(f))


        self.spec = torch.cat([clean0, mix0], 0)
        self.label = torch.cat([clean_label0, mix_label0], 0)

    def __len__(self):
        return self.spec.shape[0]


    def __getitem__(self, index): 

        spec = self.spec[index]
        label = self.label[index]
        return spec, label
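To check that the Dataset itself is not the problem, here is a minimal sketch of the same class fed with in-memory tensors instead of the JSON files. The class name `InMemorySpecDataset`, the 2x3 "spectrogram" shape and the sample counts are hypothetical and chosen only for illustration. If a loader built this way yields a different set of rows in each batch, the Dataset/DataLoader part is working and the problem lies elsewhere.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class InMemorySpecDataset(Dataset):
    # Hypothetical stand-in for MSourceDataSet: takes already-loaded tensors
    # instead of reading clean1.json / mix1.json / *_label1.json from disk.
    def __init__(self, clean, mix, clean_label, mix_label):
        # Stack clean and mixed samples along dim 0, as in the original class
        self.spec = torch.cat([clean, mix], 0)
        self.label = torch.cat([clean_label, mix_label], 0)
        # Every spectrogram needs exactly one label row
        assert self.spec.shape[0] == self.label.shape[0]

    def __len__(self):
        return self.spec.shape[0]

    def __getitem__(self, index):
        return self.spec[index], self.label[index]

# Hypothetical data: 6 clean + 6 mixed samples of shape 2x3
clean = torch.arange(6 * 2 * 3, dtype=torch.float32).reshape(6, 2, 3)
mix = clean + 1000
clean_label = torch.zeros(6, 1)
mix_label = torch.ones(6, 1)

ds = InMemorySpecDataset(clean, mix, clean_label, mix_label)
loader = DataLoader(ds, batch_size=4, shuffle=True)

# With shuffle=True the label mix in each batch changes from run to run
for i, (x, y) in enumerate(loader):
    print(i, tuple(x.shape), y.view(-1).tolist())
```

With 12 samples and `batch_size=4` this gives 3 batches per epoch, each holding different rows.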

trainset = MSourceDataSet(clean_dir, mix_dir, clean_label_dir, mix_label_dir)

trainloader = torch.utils.data.DataLoader(dataset=trainset,
                                          batch_size=4,
                                          shuffle=True)

# testloader = torch.utils.data.DataLoader(dataset=testset,
#                                          batch_size=4,
#                                          shuffle=True)
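One likely reading of the output. The question does not say how many samples the dataset has, so assume for now that it has at most 4000. The training loop below prints only when `i % 1000 == 0`. With `batch_size = 4`, an epoch has more than 1000 batches only if there are more than 4000 samples, so `i = 0` is the only index that ever gets printed. At `i = 0`, `running_loss` holds the loss of a single batch, and that value is divided by 1000. The printed number is therefore one batch loss scaled down 1000x and rounded to 3 decimals, which can easily look identical from epoch to epoch. The small pure-Python sketch below works through the arithmetic. The helper names are hypothetical, and the sample counts in the usage are made up.

```python
import math
from typing import List

def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Number of batches a DataLoader yields per epoch for a map-style dataset
    # with automatic batching: the last partial batch is kept unless drop_last=True
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

def logged_indices(num_batches: int, log_every: int) -> List[int]:
    # Batch indices i within one epoch for which `i % log_every == 0` holds
    return [i for i in range(num_batches) if i % log_every == 0]

# Hypothetical dataset sizes, batch_size=4 as in the question
for n in (2000, 5000):
    nb = batches_per_epoch(n, 4)
    print(n, "samples ->", nb, "batches, printed at", logged_indices(nb, 1000))
```

For 2000 samples this reports 500 batches printed only at index 0. For 5000 samples it reports 1250 batches printed at 0 and 1000.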

Training

model.train()

for epoch in range(10):
    running_loss = 0

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 1000 == 0:
            print ('[%d, %5d] loss: %.3f' % (epoch, i, running_loss/ 1000))
            running_loss = 0

torch.save(model, 'FeatureNet.pkl')
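If the diagnosis above is right, the fix is in the logging, not the DataLoader. Below is a sketch of the same loop that prints every `log_every` batches and at the end of each epoch, and divides by the number of batches actually summed rather than a fixed 1000. The question's `model`, `criterion`, `optimizer` and data are not shown, so a hypothetical toy setup stands in for them: a 3-feature linear regression on 40 random samples, with `nn.Linear`, `MSELoss` and SGD. The function name `train` and the `log_every` parameter are also illustrative.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, criterion, optimizer, epochs, log_every):
    """Train and print the mean loss over each logging window.

    Returns the printed (epoch, batch_index, mean_loss) tuples.
    """
    history = []
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        window = 0  # batches accumulated since the last print
        for i, (inputs, labels) in enumerate(loader):
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            window += 1
            # Print every `log_every` batches and on the last batch of the epoch,
            # averaging over the batches actually summed
            if window == log_every or i == len(loader) - 1:
                mean = running_loss / window
                print('[%d, %5d] loss: %.3f' % (epoch, i, mean))
                history.append((epoch, i, mean))
                running_loss = 0.0
                window = 0
    return history

# Hypothetical toy data and model standing in for the question's setup
torch.manual_seed(0)
x = torch.randn(40, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
history = train(model, loader, nn.MSELoss(), optimizer, epochs=5, log_every=5)
```

With 40 samples and `batch_size=4` there are 10 batches per epoch, so this prints at batch indices 4 and 9 in every epoch, and the printed mean loss should go down over the epochs.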
