我在训练模型时遇到了一个问题:数据加载器(DataLoader)本应在每次迭代时返回新的批次,但我却得到了以下结果:(此处应附上的错误信息或输出在原文中缺失。)
有人知道是哪里出了问题吗?我的代码如下:
class MSourceDataSet(Dataset):
    """Map-style dataset pairing spectrogram rows with their labels.

    Reads four JSON arrays and stacks the "clean" and "mix" parts along
    dimension 0, clean rows first:

    * ``clean1.json`` from ``clean_dir`` and ``mix1.json`` from ``mix_dir``
      form ``self.spec``;
    * ``clean_label1.json`` from ``clean_label_dir`` and ``mix_label1.json``
      from ``mix_label_dir`` form ``self.label``.

    Item ``index`` is the tuple ``(self.spec[index], self.label[index])``.

    Args:
        clean_dir: Directory holding ``clean1.json``. A trailing path
            separator is optional.
        mix_dir: Directory holding ``mix1.json``.
        clean_label_dir: Directory holding ``clean_label1.json``.
        mix_label_dir: Directory holding ``mix_label1.json``.

    Raises:
        ValueError: If the stacked spectra and labels have a different
            number of rows, because indices would no longer line up.
    """

    def __init__(self, clean_dir, mix_dir, clean_label_dir, mix_label_dir):
        clean0 = self._load_tensor(clean_dir, 'clean1.json')
        mix0 = self._load_tensor(mix_dir, 'mix1.json')
        clean_label0 = self._load_tensor(clean_label_dir, 'clean_label1.json')
        mix_label0 = self._load_tensor(mix_label_dir, 'mix_label1.json')

        # torch.cat along dim 0 requires the clean and mix parts to agree on
        # every trailing dimension.
        self.spec = torch.cat([clean0, mix0], 0)
        self.label = torch.cat([clean_label0, mix_label0], 0)

        # Fail fast here rather than with an IndexError (or silently ignored
        # extra labels) once the DataLoader starts sampling.
        if self.spec.shape[0] != self.label.shape[0]:
            raise ValueError(
                f'spec has {self.spec.shape[0]} rows but label has '
                f'{self.label.shape[0]}; they must match'
            )

    @staticmethod
    def _load_tensor(directory, filename):
        """Read the JSON array at ``directory/filename`` into a tensor.

        ``torch.Tensor`` produces the default floating dtype (float32 unless
        changed globally). NOTE(review): if ``criterion`` expects integer
        class indices (e.g. CrossEntropyLoss), labels may need ``.long()``.
        Confirm against the loss in use.
        """
        import os  # local import: the file's import header is not visible here

        with open(os.path.join(directory, filename), encoding='utf-8') as f:
            return torch.Tensor(json.load(f))

    def __len__(self):
        """Return the total number of rows (clean + mix)."""
        return self.spec.shape[0]

    def __getitem__(self, index):
        """Return the ``(spec, label)`` pair stored at ``index``."""
        return self.spec[index], self.label[index]
# Training script. clean_dir, mix_dir, clean_label_dir, mix_label_dir,
# model, optimizer and criterion are expected to be defined earlier in the
# file (not visible here).
trainset = MSourceDataSet(clean_dir, mix_dir, clean_label_dir, mix_label_dir)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
)

# Number of batches whose losses are averaged into one progress line.
LOG_INTERVAL = 1000

model.train()
for epoch in range(10):
    # Reset per epoch; losses from a final partial window are not reported.
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # i is 0-based, so test i + 1: each report then averages exactly
        # LOG_INTERVAL batch losses instead of firing after the first batch.
        if (i + 1) % LOG_INTERVAL == 0:
            print('[%d, %5d] loss: %.3f' % (epoch, i + 1, running_loss / LOG_INTERVAL))
            running_loss = 0.0

# NOTE(review): this pickles the whole model object, so loading it requires
# the model's class to be importable. Saving model.state_dict() is the
# approach PyTorch documents as recommended; consider switching if nothing
# depends on this file format.
torch.save(model, 'FeatureNet.pkl')