如何从CNN张量流中的已保存模型(检查点,元文件)中预测图像?

时间:2019-04-19 09:29:10

标签: python tensorflow

我已经训练并保存了模型。我有检查点和元文件。我想还原模型并使用该模型预测图像。

我尝试使用sess.restore恢复模型,但是它具有一些权重。如何将这些权重用于实际预测?

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./tmp1/my_model.meta', clear_devices=True)

    graph = tf.get_default_graph()
    sess = tf.Session()
    saver.restore(sess, "./tmp1/my_model")
    input_graph_def = graph.as_graph_def()

    output_graph_def = graph_util.convert_variables_to_constants(
            sess, # The session
            input_graph_def, # input_graph_def is useful for retrieving the nodes 
            output_node_names=['output']
    )


output_graph="./my_model.pb"
with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())


#use pb file

path="./my_model.pb"

def load_pb(path):
    with tf.gfile.GFile(path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
graph=load_pb(path)
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())

    input = graph.get_tensor_by_name('input:0')
    out = graph.get_tensor_by_name('output:0')
    sess.run(out, feed_dict={input: test_images})
    print(sess.run(out, feed_dict={input: test_images}))

这是我从打印语句中得到的。

[[558.4395 ]
 [498.31738]
 [528.15173]
 ...
 [724.5902 ]
 [508.516  ]
 [542.25244]]

我想要的是对我的test_images的预测

0 个答案:

没有答案