Tensorflow恢复模型

时间:2017-04-14 22:05:32

标签: tensorflow

我试图恢复我的Tensorflow模型 - 它是一个线性回归网络。我确定我做错了,因为我的预测并不好。当我训练时,我有一套测试仪。我的测试集预测看起来很棒,但是当我尝试恢复相同的模型时,预测看起来很差。

以下是我保存模型的方法:

with tf.Session() as sess:
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    sess.run(init)
    training_data, ground_truth = d.get_training_data()
    testing_data, testing_ground_truth = d.get_testing_data()

    for iteration in range(config["training_iterations"]):
        start_pos = np.random.randint(len(training_data) - config["batch_size"])
        batch_x = training_data[start_pos:start_pos+config["batch_size"],:,:]
        batch_y = ground_truth[start_pos:start_pos+config["batch_size"]]
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        train_acc, train_loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y})

        sess.run(optimizer, feed_dict={x: testing_data, y: testing_ground_truth})
        test_acc, test_loss = sess.run([accuracy, cost], feed_dict={x: testing_data, y: testing_ground_truth})
        samples = sess.run(pred, feed_dict={x: testing_data})
        # print samples
        data.compute_acc(samples, testing_ground_truth)

        print("Training\tAcc: {}\tLoss: {}".format(train_acc, train_loss))
        print("Testing\t\tAcc: {}\tLoss: {}".format(test_acc, test_loss))
        print("Iteration: {}".format(iteration))

        if iteration % config["save_step"] == 0:
            saver.save(sess, config["save_model_path"]+str(iteration)+".ckpt")

以下是我的测试集中的一些示例。您会注意到My prediction相对接近Actual

My prediction: -12.705  Actual : -10.0
My prediction: 0.000    Actual : 8.0
My prediction: -14.313  Actual : -23.0
My prediction: 17.879   Actual : 13.0
My prediction: 17.452   Actual : 24.0
My prediction: 22.886   Actual : 29.0
Custom accuracy: 5.0159861487
Training    Acc: 5.63836860657  Loss: 25.6545143127
Testing     Acc: 4.238052845    Loss: 22.2736053467
Iteration: 6297

然后我将如何恢复模型:

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, config["retore_model_path"]+"3000.ckpt")

    init = tf.global_variables_initializer()
    sess.run(init)

    pred = sess.run(pred, feed_dict={x: predict_data})[0]
    print("Prediction: {:.3f}\tGround truth: {:.3f}".format(pred, ground_truth))

但这是预测的样子。您会注意到Prediction总是在0左右:

Prediction: 0.355       Ground truth: -22.000
Prediction: -0.035      Ground truth: 3.000
Prediction: -1.005      Ground truth: -3.000
Prediction: -0.184      Ground truth: 1.000
Prediction: 1.300       Ground truth: 5.000
Prediction: 0.133       Ground truth: -5.000

这是我的张量流版本(是的,我需要更新):

Python 2.7.6 (default, Oct 26 2016, 20:30:19) 
[GCC 4.8.4] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> print(tf.__version__)
0.12.0-rc1

不确定这是否有帮助,但我尝试在saver.restore()之后放置sess.run(init)来电并获得完全相同的预测。我认为这是因为sess.run(init)初始化变量。

更改顺序如下:

sess.run(init)
saver.restore(sess, config["retore_model_path"]+"6000.ckpt")

但是预测看起来像这样:

Prediction: -15.840     Ground truth: 2.000
Prediction: -15.840     Ground truth: -7.000
Prediction: -0.000      Ground truth: 12.000
Prediction: -15.840     Ground truth: -9.000
Prediction: -15.175     Ground truth: -27.000

1 个答案:

答案 0 :(得分:2)

从检查点恢复时,不会初始化变量。正如你在问题的最后提到的那样。

init = tf.global_variables_initializer()
sess.run(init)

覆盖刚恢复的变量。哎呀! :)

评论这两行,我怀疑你会好起来。