How to convert a Keras H5 model to a TensorFlow PB file

Posted: 2019-08-19 11:44:27

Tags: tensorflow keras

I am trying to save a Keras model to a .pb file so that I can run it with the TensorFlow C++ API.

I am freezing the model with a solution I found on the internet. I can save the model, but when I load it back in TensorFlow I cannot get usable predictions.

import os
from keras import backend as K
from keras.models import model_from_json
from keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt

def load_image(img_path):
    img = image.load_img(img_path, target_size=(128, 128))
    img_tensor = image.img_to_array(img)                    # (height, width, channels)
    img_tensor = np.expand_dims(img_tensor, axis=0)         # (1, height, width, channels), add a dimension because the model expects this shape: (batch_size, height, width, channels)
    img_tensor /= 255.                                      # scale pixel values to the range [0, 1]

    return img_tensor

# This line must be executed before loading the Keras model (it puts the graph in inference mode).
K.set_learning_phase(0)

with open('/home/antonio/keras_to_tensorflow/modelAlex/output_model2.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)

model.load_weights("/home/antonio/keras_to_tensorflow/modelAlex/output_model2.h5")
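
A quick way to get a reference prediction from the un-frozen Keras model, to compare against the frozen graph's output later (this assumes the same test image that is used at the bottom of the script):

# Reference prediction from the live Keras model, for later comparison.
x_ref = load_image('data/images/image.png')
print(model.predict(x_ref))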

import tensorflow as tf

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        # Graph -> GraphDef ProtoBuf
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph


frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])

tf.train.write_graph(frozen_graph, "./model", "tf_model.pb", as_text=False)  # writes ./model/tf_model.pb
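
To double-check which node names actually ended up in the frozen GraphDef (for example that 'input_1' and 'dense_4/Softmax' are present), the Keras-side names and the frozen nodes can be listed; a minimal sketch:

# Print the Keras input/output op names and every node in the frozen
# GraphDef, to confirm the exact tensor names to feed and fetch later.
print([inp.op.name for inp in model.inputs])
print([out.op.name for out in model.outputs])
for node in frozen_graph.node:
    print(node.name)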

from tensorflow.python.platform import gfile

with gfile.FastGFile("./model/tf_model.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    # Parse the serialized binary GraphDef into a GraphDef message.
    graph_def.ParseFromString(f.read())

init_op = tf.global_variables_initializer()  # tf.initialize_all_variables() is deprecated

with tf.Session() as sess:
    sess.run(init_op)
    sess.graph.as_default()
    tf.import_graph_def(graph_def)

    files = ['image.png']

    for fileName in files:
        img_path = 'data/images/' + fileName
        x_test = load_image(img_path)

        softmax_tensor = sess.graph.get_tensor_by_name('dense_4/Softmax:0')
        predictions = sess.run(softmax_tensor, {'input_1:0': x_test})

        print(predictions)

The model does produce predictions, but they are wrong or appear random.
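
For comparison, one pattern that avoids any interference from the Keras graph still living in the same process is to import the frozen GraphDef into a completely fresh graph with an empty name prefix (tf.import_graph_def prepends 'import/' by default); a sketch, assuming the same tensor names as above:

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile("./model/tf_model.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    # name='' keeps the original node names instead of adding an 'import/' prefix
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # No variable initialization is needed: the frozen graph contains only constants.
    x_test = load_image('data/images/image.png')
    softmax_tensor = graph.get_tensor_by_name('dense_4/Softmax:0')
    predictions = sess.run(softmax_tensor, {'input_1:0': x_test})
    print(predictions)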
