pytorch没有给出预期的输出

时间:2018-04-08 21:40:24

标签: pytorch

首先,CNN模型对一堆数据进行分类。然后,我试图从第一步对正确分类的数据进行预测,预计准确度为100%。但是,我发现结果不稳定,有时为99 +%,但不是100%。有人知道我的代码有什么问题吗?非常感谢你提前几天,它困扰了我几天~~

火炬。版本

'0.3.1.post2'

import numpy as np
import torch 
import torch.nn as nn
from torch.autograd import Variable

n = 2000
data = np.random.randn(n, 1, 10, 10)
label = np.random.randint(2, size=(n, ))

def test_pred(model, data_test, label_test):

    data_batch = data_test
    labels_batch = label_test

    images = torch.autograd.Variable(torch.FloatTensor(data_batch))
    labels = torch.autograd.Variable(torch.FloatTensor(labels_batch))

    outputs = model(images)

    _, predicted = torch.max(outputs.data, 1)

    correct = (np.array(predicted) == labels_batch).sum()

    label_pred = np.array(predicted)

    acc = correct/len(label_test)
    print(" acc:", acc)

    return acc, label_pred

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(128, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

cnn = CNN()

[_, label_pred] = test_pred(cnn, data, label)

print("Acc:", np.mean(label_pred==label))
# Given the correctly classified data in previous step, expect to get 100% accuracy
# Why it sometimes doesn't give a 100% accuracy ?
print("Using selected data size {}:".format(data[label_pred==label].shape))
_, _ = test_pred(cnn, data[label_pred==label], label[label_pred==label])

输出:

acc:0.482

Acc:0.482

使用所选数据大小(964,1,10,10):

acc:0.9979253112033195

1 个答案:

答案 0 :(得分:1)

好像你没有将网络设置为评估模式,这可能会导致一些问题,特别是BatchNorm图层。做

cnn = CNN()
cnn.eval()

它应该有用。

相关问题