如何使用TF2将加载的h5模型正确保存到pb

时间:2019-06-18 10:21:02

标签: tensorflow keras protocol-buffers tensorflow2.0 keras-2

我加载了一个已保存的h5模型,并希望将该模型另存为pb。 该模型是在训练期间使用tf.keras.callbacks.ModelCheckpoint回调函数保存的。

TF版本:2.0.0a
 编辑:2.0.0-beta1

也存在相同的问题

我保存铅的步骤:

  1. 我首先设置了K.set_learning_phase(0)
  2. 然后我用tf.keras.models.load_model加载模型
  3. 然后,我定义freeze_session()函数。
  4. (可选,我可以编译模型)
  5. 然后将freeze_session()函数与tf.keras.backend.get_session一起使用

错误,无论是否编译,我都会得到该错误:

  

AttributeError:模块'tensorflow.python.keras.api._v2.keras.backend'   没有属性“ get_session”

我的问题:

  1. TF2不再有get_session吗? (我知道tf.contrib.saved_model.save_keras_model不再存在,并且我也尝试了tf.saved_model.save确实没有用)

  2. 或者get_session仅在我实际训练模型并且仅加载h5不起作用时才起作用 编辑:另外,通过重新培训的课程,没有get_session可用。

    • 如果是这样,我将如何在未经培训的情况下将h5转换为pb?有很好的教程吗?

谢谢您的帮助

3 个答案:

答案 0 :(得分:1)

我想知道同一件事,因为我试图使用get_session()和set_session()释放GPU内存。这些功能似乎丢失了,aren't in the TF2.0 Keras documentation。我想这与Tensorflow切换到急切执行有关,因为不再需要直接会话访问。

答案 1 :(得分:1)

我确实将模型从pb模型保存到h5

import logging
import tensorflow as tf
from tensorflow.compat.v1 import graph_util
from tensorflow.python.keras import backend as K
from tensorflow import keras

# necessary !!!
tf.compat.v1.disable_eager_execution()

h5_path = '/path/to/model.h5'
model = keras.models.load_model(h5_path)
model.summary()
# save pb
with K.get_session() as sess:
    output_names = [out.op.name for out in model.outputs]
    input_graph_def = sess.graph.as_graph_def()
    for node in input_graph_def.node:
        node.device = ""
    graph = graph_util.remove_training_nodes(input_graph_def)
    graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)
    tf.io.write_graph(graph_frozen, '/path/to/pb/model.pb', as_text=False)
logging.info("save pb successfully!")

我使用TF2转换模型,例如:

  1. 在训练期间将keras.callbacks.ModelCheckpoint(save_weights_only=True)传递到model.fit并保存checkpoint
  2. 训练后,self.model.load_weights(self.checkpoint_path)加载checkpoint;
  3. self.model.save(h5_path, overwrite=True, include_optimizer=False)另存为h5
  4. 像上面一样将h5转换为pb

答案 2 :(得分:1)

使用

from tensorflow.compat.v1.keras.backend import get_session

在keras 2和tensorflow 2.2中

然后打电话

import logging
import tensorflow as tf
from tensorflow.compat.v1 import graph_util
from tensorflow.python.keras import backend as K
from tensorflow import keras
from tensorflow.compat.v1.keras.backend import get_session

# necessary !!!
tf.compat.v1.disable_eager_execution()

h5_path = '/path/to/model.h5'
model = keras.models.load_model(h5_path)
model.summary()
# save pb
with get_session() as sess:
    output_names = [out.op.name for out in model.outputs]
    input_graph_def = sess.graph.as_graph_def()
    for node in input_graph_def.node:
        node.device = ""
    graph = graph_util.remove_training_nodes(input_graph_def)
    graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)
    tf.io.write_graph(graph_frozen, '/path/to/pb/model.pb', as_text=False)
logging.info("save pb successfully!")
相关问题