Why does my trained neural network produce the same output?

Time: 2018-09-24 13:02:53

Tags: java neural-network encog

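I am using Encog 3.3 with an MLP trained by ResilientPropagation (I switched to it because the learning rate and momentum for Backpropagation were hard to tune). The setup is 10 inputs (including the ideal value), 1 hidden layer with 7 neurons, 1 output neuron, sigmoid activation, a training set of about 80k rows and a test set of about 96 rows. I trained 2 models to error rates of 0.01 and 0.007; the error rate is the only difference between them, and all other settings above are the same. I also min-max normalized the data. Maybe my evaluation code is wrong, or some other part of the code? Any help would be greatly appreciated.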

Full code:

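import java.io.File;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.platformspecific.j2se.data.SQLNeuralDataSet;

public class ANN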
{   
//training
//public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2012,4,1,1) and (2014,9,29, 96) ORDER BY ID";
//testing
public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2014,9,30,1) and (2014,9,30, 92) ORDER BY ID";
//validation
//public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2014,9,30,93) and (2014,9,30, 96) ORDER BY ID";
public final static int INPUT_SIZE = 9;
public final static int IDEAL_SIZE = 1;
public final static String SQL_DRIVER = "org.postgresql.Driver";
public final static String SQL_URL = "jdbc:postgresql://localhost/ANN";
public final static String SQL_UID = "postgres";
public final static String SQL_PWD = "";

public static void main(String args[])
{   
    Mynetwork();
    //train network. will add customizable params later.
    //train(trainingData());
    //evaluate network
    evaluate(trainingData());
    Encog.getInstance().shutdown();
}
public static void evaluate(MLDataSet testSet)
{
    BasicNetwork network = (BasicNetwork)EncogDirectoryPersistence.loadObject(new File("directory"));

    // test the neural network
    System.out.println("Neural Network Results:");
    for(MLDataPair pair: testSet ) {
        final MLData output = network.compute(pair.getInput());
        System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1) + "," + pair.getInput().getData(2) + "," + pair.getInput().getData(3) + "," + pair.getInput().getData(4) + "," + pair.getInput().getData(5) + "," + pair.getInput().getData(6) + "," + pair.getInput().getData(7) + "," + pair.getInput().getData(8) + "," + "Predicted=" + output.getData(0) + ", Actual=" + pair.getIdeal().getData(0));
    }
}
public static BasicNetwork Mynetwork()
{
    // Basic neural network template. The input layer has no activation function,
    // because an activation transforms data coming from the previous layer and there is no layer before the input.
    BasicNetwork network = new BasicNetwork();
    // Input layer with 9 neurons (INPUT_SIZE).
    // The 'true' parameter adds a bias neuron, which feeds into the next layer.
    network.addLayer(new BasicLayer(null , true, 9));
    // Hidden layer with 5 neurons
    network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 5));
    //output layer with 1 neuron
    network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
    network.getStructure().finalizeStructure() ;
    network.reset();

    return network;
}
public static void train(MLDataSet trainingSet)
{
    // Keep a reference to the network being trained, so that the trained weights
    // (and not a freshly reset network) are the ones written to disk.
    final BasicNetwork network = Mynetwork();
    //Backpropagation(network, dataset, learning rate, momentum)
    //final Backpropagation train = new Backpropagation(network, trainingSet, 0.1, 0.9);
    final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
    //final QuickPropagation train = new QuickPropagation(network, trainingSet, 0.9);

    int epoch = 1;

    // Iterate until the training error drops to the target of 0.01.
    do {
        train.iteration();
        System.out.println("Epoch #" + epoch + " Error:" + train.getError());
        epoch++;
    } while((train.getError() > 0.01));
    System.out.println("Saving network");
    EncogDirectoryPersistence.saveObject(new File("directory"), network);
    System.out.println("Saving Done");
}
public static MLDataSet trainingData()
{
    MLDataSet trainingSet = new SQLNeuralDataSet(
            ANN.SQL,
            ANN.INPUT_SIZE,
            ANN.IDEAL_SIZE,
            ANN.SQL_DRIVER,
            ANN.SQL_URL,
            ANN.SQL_UID,
            ANN.SQL_PWD);

    return trainingSet;
}

}

Here are my results:

Predicted=0.4451817588640455, Actual=0.5260616667545941
Predicted=0.4451817588640455, Actual=0.5196499668339777
Predicted=0.4451817588640455, Actual=0.5083828048375548
Predicted=0.4451817588640455, Actual=0.49985462144799725
Predicted=0.4451817588640455, Actual=0.49085956670499675
Predicted=0.4451817588640455, Actual=0.485008112408512
Predicted=0.4451817588640455, Actual=0.47800504210686795
Predicted=0.4451817588640455, Actual=0.4693212349328293
(...and so on with the same "predicted")

Desired results (for demonstration purposes I changed "Predicted" to random values, to show a network that is actually predicting):

Predicted=0.4451817588640455, Actual=0.5260616667545941
Predicted=0.5123312331212122, Actual=0.5196499668339777
Predicted=0.435234234234254365, Actual=0.5083828048375548
Predicted=0.673424556563455, Actual=0.49985462144799725
Predicted=0.2344673345345544235, Actual=0.49085956670499675
Predicted=0.123346457544324, Actual=0.485008112408512
Predicted=0.5673452342342342, Actual=0.47800504210686795
Predicted=0.678435234423423423, Actual=0.4693212349328293

Update:

Printed the inputs + predicted + actual (ideal), separated by commas:

0.5386671932975533,1100000.0,0.0,1.0,40.0,1.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5260616667545941
0.5260616667545941,1100000.0,0.0,1.0,40.0,2.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5196499668339777
0.5196499668339777,1100000.0,0.0,1.0,40.0,3.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5083828048375548
0.5083828048375548,1100000.0,0.0,1.0,40.0,4.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.49985462144799725
0.49985462144799725,1100000.0,0.0,1.0,40.0,5.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.49085956670499675
0.49085956670499675,1100000.0,0.0,1.0,40.0,6.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.485008112408512
0.485008112408512,1100000.0,0.0,1.0,40.0,7.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.47800504210686795

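1 Answer:

Answer 0 (score: 0)

Fixed this by fully normalizing all of the input features. I had thought it was enough to normalize only the main input I wanted to predict and to leave the factors that affect it as they were.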
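
To make that concrete: the printed rows in the question still contain raw values such as 1100000.0 and 2014.0 next to the normalized load. Inputs of that size most likely drive the sigmoid hidden neurons into saturation, so the small changes in the other columns no longer reach the output. Below is a minimal sketch of min-max scaling applied to every column. It is plain Java and does not use Encog's own normalization classes; `MinMaxNormalizer`, `fit` and `transform` are names made up for this illustration, and the sample values in `main` are invented.

```java
import java.util.Arrays;

// Hypothetical helper (not part of Encog): min-max scales every input column to [0, 1].
final class MinMaxNormalizer {
    private final double[] min;
    private final double[] max;

    private MinMaxNormalizer(double[] min, double[] max) {
        this.min = min;
        this.max = max;
    }

    /** Learns the per-column minimum and maximum from the training rows. */
    static MinMaxNormalizer fit(double[][] trainingRows) {
        if (trainingRows.length == 0) {
            throw new IllegalArgumentException("trainingRows must not be empty");
        }
        int cols = trainingRows[0].length;
        double[] min = new double[cols];
        double[] max = new double[cols];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : trainingRows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
        return new MinMaxNormalizer(min, max);
    }

    /** Scales one row using the ranges learned in fit(). */
    double[] transform(double[] row) {
        double[] scaled = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            double range = max[c] - min[c];
            // A column that was constant in the training data carries no information; map it to 0.
            scaled[c] = range == 0.0 ? 0.0 : (row[c] - min[c]) / range;
        }
        return scaled;
    }

    public static void main(String[] args) {
        // Invented sample rows: load_input, time, year.
        double[][] training = {
            {0.50, 1.0, 2012.0},
            {0.52, 48.0, 2013.0},
            {0.54, 96.0, 2014.0},
        };
        MinMaxNormalizer normalizer = MinMaxNormalizer.fit(training);
        System.out.println(Arrays.toString(normalizer.transform(new double[] {0.52, 48.0, 2014.0})));
    }
}
```

The ranges are learned from the training rows only and then reused for the test rows. Otherwise a one-day test set like the one in the question, where year, month and day are constant, would be scaled differently from the data the network was trained on.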
