Where to save checkpoints when training a neural network

Time: 2017-04-26 09:52:15

Tags: python tensorflow neural-network

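In the example below, I train on the data as many times as the value of `iterations`. At first I saved a checkpoint in every iteration, but when I moved the save below the loop, training became noticeably faster. I'm not sure whether it makes a difference. Is one save per session enough? And does each iteration use the values (weights) set in the previous iteration?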
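A framework-agnostic sketch of the two points raised above. It assumes a toy one-weight model `y = w * x` fitted with plain gradient descent. The names `gradient_step`, `save_checkpoint`, `load_checkpoint` and `fit` are hypothetical, and a JSON file stands in for TensorFlow's `.ckpt` files. The weight `w` carries over from one iteration to the next, the way variables persist inside a `tf.Session` across `sess.run` calls. Saving only reads the weight, so saving every step or once at the end yields the same result.

```python
import json
import os
import tempfile


def gradient_step(w, xs, ys, lr):
    """One gradient-descent step on the mean squared error of the toy model y = w * x."""
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    return w - lr * grad


def save_checkpoint(ckpt_path, step, w):
    """Write the step counter and weight; saving only reads w, it never changes it."""
    tmp_path = ckpt_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"step": step, "w": w}, f)
    # Swap the finished file in with one rename, so a crash mid-write
    # leaves the previous checkpoint intact.
    os.replace(tmp_path, ckpt_path)


def load_checkpoint(ckpt_path):
    with open(ckpt_path) as f:
        state = json.load(f)
    return state["step"], state["w"]


def fit(xs, ys, iterations, ckpt_path, lr=0.01, checkpoint_every=None):
    # Resume from the last checkpoint if there is one (the role of saver.restore).
    if os.path.isfile(ckpt_path):
        start, w = load_checkpoint(ckpt_path)
    else:
        start, w = 0, 0.0
    for step in range(start, start + iterations):
        # Each iteration starts from the weight produced by the previous one.
        w = gradient_step(w, xs, ys, lr)
        # Optional periodic saves only guard against losing progress on a crash.
        if checkpoint_every and (step + 1) % checkpoint_every == 0:
            save_checkpoint(ckpt_path, step + 1, w)
    # A single save after the loop is enough to continue training later.
    save_checkpoint(ckpt_path, start + iterations, w)
    return w


if __name__ == "__main__":
    xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # generated by y = 2 * x
    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, "model.json")
        print(fit(xs, ys, 100, ckpt))  # first run starts from w = 0.0
        print(fit(xs, ys, 100, ckpt))  # second run resumes from the saved weight
```

Under these assumptions, two runs of 100 iterations with a restore in between end at exactly the same weight as one run of 200. Saving inside the loop only adds file I/O; it is worth doing periodically (not every step) when a long run might be interrupted.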

import re
from os import path

import numpy as np
import tensorflow as tf

# get_labels, randomize_output, neural_network_model, the placeholders
# p_inputdata / p_known_labels and the settings used below (training_data_file,
# verify_data_file, csvseperatorRE, learning_rate, iterations, temp_dir, label)
# are defined elsewhere in the script.


def train():
    """Trains the neural network
    :returns: 0 for now. Why? 42
    """
    unique_labels, labels = get_labels(training_data_file, savelabels=True)  # get the labels from the training data
    randomize_output(unique_labels)  # create the correct amount of output neurons

    prediction = neural_network_model(p_inputdata)

    # calculates the difference between the known labels and the results
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=p_known_labels))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    # Create the saver after minimize() so it also covers Adam's slot variables.
    saver = tf.train.Saver()
    checkpoint = path.join(temp_dir, label + ".ckpt")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize all variables in the session
        with open(training_data_file, "r") as file:
            inputdata = [re.split(csvseperatorRE, line.rstrip("\n"))[:-1] for line in file]
        if path.isfile(checkpoint + ".index"):
            saver.restore(sess, checkpoint)  # continue from the last saved weights
        # build the feed once instead of converting the data in every iteration
        feed = {p_inputdata: np.array(inputdata), p_known_labels: np.array(labels)}
        for i in range(iterations):
            # each run updates the variables in place, so the next iteration starts from these weights
            _, cost_of_iteration = sess.run([optimizer, cost], feed_dict=feed)
        saver.save(sess, checkpoint)  # one save after the loop stores the final weights

        # compare the predicted class with the true class (argmax along axis 1 on both sides)
        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(p_known_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        with open(verify_data_file, "r") as file:
            verifydata = [re.split(csvseperatorRE, line.rstrip("\n"))[:-1] for line in file]

        npverifydata = np.array(verifydata)
        nplabels = np.array(get_labels(verify_data_file)[1])
        print("Accuracy: ", accuracy.eval({p_inputdata: npverifydata, p_known_labels: nplabels}))
    return 0
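The accuracy step is meant to compare two things for every example: the class with the highest logit and the class marked in its one-hot label. That means `tf.argmax(..., 1)` on both tensors. Nesting two calls, as in `tf.argmax(tf.argmax(p_known_labels, 1))`, collapses the labels to a single index, and every prediction is then compared against that one value. Below is a plain-Python sketch of the intended computation, using made-up logits and labels. The helpers `argmax` and `accuracy` are hypothetical and are not TensorFlow functions.

```python
def argmax(row):
    """Index of the largest value in a row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits, one_hot_labels):
    """Fraction of examples whose predicted class equals the true class."""
    # argmax per row on BOTH sides: predicted class vs. true class of each example
    correct = [argmax(p) == argmax(t) for p, t in zip(logits, one_hot_labels)]
    return sum(correct) / len(correct)


if __name__ == "__main__":
    logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.1, 0.0], [0.1, 0.2, 3.0]]
    labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]
    print(accuracy(logits, labels))  # 3 of 4 rows match -> 0.75
```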

0 answers
