TensorFlow MNIST example: code to predict from a SavedModel

Date: 2018-04-13 18:32:53

Tags: python tensorflow mnist

I am building a CNN by following the example in this tutorial: https://www.tensorflow.org/tutorials/layers

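However, I could not find an example that shows how to make a prediction by feeding in a sample image. Any help would be greatly appreciated.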

Below is what I tried. I could not find the name of the output tensor:

import tensorflow as tf

img = ...  # load from file
sess = tf.Session()
saver = tf.train.import_meta_graph('/tmp/mnist_convnet_model/model.ckpt-2000.meta')
saver.restore(sess, tf.train.latest_checkpoint('/tmp/mnist_convnet_model/'))

input_place_holder = sess.graph.get_tensor_by_name("enqueue_input/Placeholder:0")
out_put = ...  # not sure what the output tensor's name in the graph is
current_input = img

result = sess.run(out_put, feed_dict={input_place_holder: current_input})
print(result)

1 answer:

Answer 0 (score: 0)

You can use TensorFlow's inspect_checkpoint tool to list the tensors stored in a checkpoint file.

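from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# Pass the checkpoint prefix (not the .meta file). With an empty tensor_name and
# all_tensors=False, this prints the name, dtype and shape of every saved variable.
print_tensors_in_checkpoint_file(file_name="/tmp/mnist_convnet_model/model.ckpt-2000", tensor_name='', all_tensors=False)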
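Keep in mind that a checkpoint only stores variable values (kernels, biases), so the inspect tool lists those but not operations such as the softmax output. One way to find those names is to import the .meta graph and list its operations. The sketch below is a minimal, self-contained illustration under stated assumptions: it uses the tf.compat.v1 graph API (TensorFlow 1.15 or 2.x), and the tiny model with the names x, w, logits and softmax_tensor is hypothetical, standing in for the tutorial's much larger graph.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF 1.x graph/session semantics

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save a tiny stand-in model with explicitly named ops
# (hypothetical names; the real MNIST graph is much larger).
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf.get_variable("w", shape=[3, 2], initializer=tf.ones_initializer())
    logits = tf.matmul(x, w, name="logits")
    tf.nn.softmax(logits, name="softmax_tensor")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)  # writes model.ckpt.meta, .index, .data-*

# Import only the graph structure from the .meta file and list every op name.
with tf.Graph().as_default() as graph:
    tf.train.import_meta_graph(ckpt_prefix + ".meta")
    op_names = [op.name for op in graph.get_operations()]

# Filter the (long) list down to candidate output ops.
print([name for name in op_names if "softmax" in name])
```

On the real model you would point import_meta_graph at /tmp/mnist_convnet_model/model.ckpt-2000.meta and filter op_names the same way; the tutorial's model function names its softmax op softmax_tensor, but it is worth confirming against your own graph.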

TensorFlow's programming guide has a good explanation of how to save and restore models. Here is a small example inspired by that guide. Just make sure the ./tmp directory exists.

import os

import tensorflow as tf

# Create some variables.
variable = tf.get_variable("variable_1", shape=[3], initializer=tf.zeros_initializer())
inc_v1 = variable.assign(variable + 1)

# Operation to initialize variables if we do not restore from a checkpoint
init_op = tf.global_variables_initializer()

# Create the saver
saver = tf.train.Saver()

# Where to save the data file; Saver.save() requires the directory to exist
save_path = "./tmp/model.ckpt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with tf.Session() as sess:
    # Decide whether or not to restore: only restore if a previous run saved
    # a checkpoint (on the very first run there is nothing to restore yet)
    DO_RESTORE = tf.train.checkpoint_exists(save_path)
    if DO_RESTORE:
        # Load the variables from the saved file
        saver.restore(sess, save_path)
    else:
        # Initialize variables using their specified initializers
        sess.run(init_op)

    # Print the initial values of variable
    initial_var_value = sess.run(variable)
    print("Initial:", initial_var_value)
    # Do some work with the model.
    incremented = sess.run(inc_v1)
    print("Incremented:", incremented)
    # Save the variables to disk.
    save_path = saver.save(sess, save_path)
    print("Model saved in path: %s" % save_path)
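To tie this back to the question, here is a minimal sketch of the same import_meta_graph / get_tensor_by_name pattern the question uses, on a model whose input and output ops were given explicit names when the graph was built. The names input_x, w and predicted_class are hypothetical, and the sketch assumes the tf.compat.v1 graph API (TensorFlow 1.15 or 2.x). A tensor name is the op name plus an output index, so the op predicted_class produces the tensor predicted_class:0.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF 1.x graph/session semantics

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# 1) Training side: give the input and output ops names you can look up later.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="input_x")
    w = tf.get_variable("w", initializer=tf.constant([[1.0, 0.0], [0.0, 2.0]]))
    tf.argmax(tf.matmul(x, w), axis=1, name="predicted_class")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# 2) Inference side: rebuild the graph from the .meta file, restore the weights,
#    then fetch the tensors by "<op name>:<output index>" and feed new data.
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        saver.restore(sess, ckpt_prefix)
        input_x = graph.get_tensor_by_name("input_x:0")
        predicted_class = graph.get_tensor_by_name("predicted_class:0")
        result = sess.run(predicted_class,
                          feed_dict={input_x: [[3.0, 1.0], [1.0, 3.0]]})

print(result)  # one predicted class index per input row
```

The design choice here is to name the ops at build time rather than hunt for auto-generated names afterwards; for the tutorial's Estimator-based model that would mean adding name= arguments inside the model function.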