Keras模型到Tensorflow可训练的PB文件和检查点

时间:2019-05-16 19:09:03

标签: python tensorflow keras

我想在Tensorflow中从Keras训练pb和检查点。

我已经成功地将Keras模型转换为Tensorflow pb和检查点。 而且我已经成功推断了。 但是问题是,我不知道该怎么做。 这种Keras模型似乎没有培训内容,或者我只是不知道在培训中应该输入什么信息。

此代码将Keras模型转换为Tensorflow pb和检查点。

from keras import backend as K
from keras.models import load_model
import tensorflow as tf

model = load_model('model/my_model.h5')

K.set_learning_phase(0) #0 : test, 1 : train

sess = K.get_session()

saver = tf.train.Saver()
saver.save(sess, 'keras/keras.ckpt')

sess.graph.as_default()
graph = sess.graph

with open('keras/keras.pb', 'wb') as f:
    f.write(graph.as_graph_def().SerializeToString())

这是读取pb和检查点的代码

def keras_model():
    sess = tf.Session()
    saver = tf.train.import_meta_graph('keras/keras.ckpt.meta')
    saver.restore(sess, "keras/keras.ckpt")

    sess.graph.as_default()
    graph = tf.get_default_graph()

    a = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
    #print(a)

    img = cv2.imread("data/wqds_backbead_0_3.png", cv2.IMREAD_COLOR)
    img = img[...,::-1] # bgr to rgb
    img = img.astype('float32')
    img = np.expand_dims(img, axis=0)

    INPUT1 = graph.get_tensor_by_name("input_1:0")
    OUTPUT1 = graph.get_tensor_by_name("softmax/Softmax:0")
    TARGET1 = graph.get_tensor_by_name("softmax_target:0")

    print(TARGET1)

    pred = sess.run(OUTPUT1, feed_dict={INPUT1: img})
    print(pred, pred.shape, pred.dtype)

0 个答案:

没有答案