神经网络的准确性正在下降

时间:2017-03-09 08:45:35

标签: machine-learning neural-network computer-vision artificial-intelligence conv-neural-network

我在Neuralnetworksanddeeplearning.com的帮助下在python中制作了神经网络程序。其中我随机初始化了hiddenLayer权重(784,100)和outputLayer权重(100,10)。算法正在研究基于minibatch的理论和正则化过度拟合mnist.pkl.gz数据集。我正在使用大小为10的小批量,学习率(η)= 3,正则化参数= 2.5。运行程序后,它的精度会增加然后减少......所以请帮助我如何让它更好地获得更高的准确性。以下是算法的结果。在此先感谢..

>>> stochastic(training_data,10,20,hiddenW,outW,hiddenB,outB,3,test_data,2.5)
    Epoch 0 correct data: 9100.0/10000
    Total cost of test data [ 307.75991542]
    Epoch 1 correct data: 9136.0/10000
    Total cost of test data [ 260.61199829]
    Epoch 2 correct data: 9233.0/10000
    Total cost of test data [ 244.9429907]
    Epoch 3 correct data: 9149.0/10000
    Total cost of test data [ 237.08391208]
    Epoch 4 correct data: 9012.0/10000
    Total cost of test data [ 227.14709858]
    Epoch 5 correct data: 8714.0/10000
    Total cost of test data [ 215.23668711]
    Epoch 6 correct data: 8694.0/10000
    Total cost of test data [ 201.79958056]
    Epoch 7 correct data: 8224.0/10000
    Total cost of test data [ 193.37639124]
    Epoch 8 correct data: 7915.0/10000
    Total cost of test data [ 183.83249811]
    Epoch 9 correct data: 7615.0/10000
    Total cost of test data [ 166.59631548]
    # forward proppagation with with bais 3 para
def forward(weight,inp,b):
    val=np.dot(weight.T,inp)+b
    return val

# sigmoid function 
def sigmoid(x):
    val=1.0/(1.0+np.exp(-x))
    return val

# Backpropagation for gradient check
def backpropagation(x,weight1,weight2,bais1,bais2,yTarget):
    hh=forward(weight1,x,bais1)
    hhout=sigmoid(hh)
    oo=forward(weight2,hhout,bais2)
    oout=sigmoid(oo)
    ooe=-(yTarget-oout)*(oout*(1-oout))
    hhe=np.dot(weight2,ooe)*(hhout*(1-hhout))
    a2=np.dot(hhout,ooe.T)
    a1=np.dot(x,hhe.T)
    b1=hhe
    b2=ooe
    return a1,a2,b1,b2
def totalCost(data,weight1,weight2,bais1,bais2,lmbda):
    m=len(data)
    cost=0.0
    for x,y in data:
        hh=forward(weight1,x,bais1)
        hhout=sigmoid(hh)
        oo=forward(weight2,hhout,bais2)
        oout=sigmoid(oo)
        c=sum(-y*np.log(oout)-(1-y)*np.log(1-oout))
        cost=cost+c/m
    cost=cost+0.5*(lmbda/m)*(sum(map(sum,(weight1**2)))+sum(map(sum,(weight2**2))))
    return cost

def stochastic(tdata,batch_size,epoch,w1,w2,b1,b2,eta,testdata,lmbda):
    n=len(tdata)
    for j in xrange(epoch):
        random.shuffle(tdata)
        mini_batches = [tdata[k:k+batch_size]for k in xrange(0, n, batch_size)]
        for minibatch in mini_batches:
            w1,w2,b1,b2=updateminibatch(minibatch,w1,w2,b1,b2,eta,lmbda)
        print 'Epoch {0} correct data: {1}/{2}'.format(j,evaluate(testdata,w1,w2,b1,b2),len(testdata))
        print 'Total cost of test data {0}'.format(totalCost(testdata,w1,w2,b1,b2,lmbda))
    return w1,w2,b1,b2


def updateminibatch(data,w1,w2,b1,b2,eta,lmbda):
    n=len(training_data)
    q1=np.zeros(w1.shape)
    q2=np.zeros(w2.shape)
    q3=np.zeros(b1.shape)
    q4=np.zeros(b2.shape)
    for xin,yout in data:
        delW1,delW2,delB1,delB2=backpropagation(xin,w1,w2,b1,b2,yout)
        q1=q1+delW1
        q2=q2+delW2
        q3=q3+delB1
        q4=q4+delB2
    w1=(1-eta*(lmbda/n))*w1-(eta/len(data))*q1
    w2=(1-eta*(lmbda/n))*w2-(eta/len(data))*q2
    b1=b1-(eta/len(data))*q3
    b2=b2-(eta/len(data))*q4
    return w1,w2,b1,b2

def evaluate(testdata,w1,w2,b1,b2):
    i=0
    z=np.zeros(len(testdata))
    for x,y in testdata:
        h=forward(w1,x,b1)
        hout=sigmoid(h)
        o=forward(w2,hout,b2)
        out=sigmoid(o)
        p=np.argmax(out)
        if (p==y):
            a=int(p==y)
            z[i]=a
        i=i+1
    return sum(z)

1 个答案:

答案 0 :(得分:2)

培训机器学习模型时,必须注意不要过度训练训练数据。

要了解您是否过度拟合数据,在培训期间使用3组不同的数据非常有用:

  • 训练集,您应该用它训练模型
  • 验证集,您可以在培训期间使用它来检查您是否准确拟合数据(显然您不必使用此集来训练模型,而且还需要在培训期间进行测试)。
  • 和测试集作为模型的最终测试。

特别是验证集非常有用。实际上,如果您过度拟合数据,则可能在训练集上具有非常好的性能,但在此集合上的准确度较低。 ( - >在这种情况下,您的模型对训练数据过于专业化,但预测新数据的准确性可能较低。) 因此,当验证集的准确性开始下降时,是停止训练的时刻,因为您已达到最佳准确度。

如果您想提高模型的准确性,可以使用更多数据进行培训,或者,如果您还没有,或者准确度不会提高,您应该更改模型,例如添加更多图层神经网络。

相关问题