Convert a frozen model (.pb) to a SavedModel

Date: 2020-01-09 04:16:59

Tags: tensorflow

Recently I tried to convert a model (tf1.x) to saved_model, following the official migration document. However, in my use case most of the models I have on hand, or those from the TensorFlow model zoo, are pb files, and according to the official document:

There is no direct way to upgrade a raw Graph.pb file to TensorFlow 2.0, but if you have a "frozen graph" (a tf.Graph whose variables have been converted to constants), it can be converted to a concrete_function using v1.wrap_function:
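
The pattern the document refers to looks roughly like this (a sketch adapted from the migration guide; the file path and the tensor names 'input:0' / 'output:0' are placeholders for my model):

import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs):
    # Import the frozen GraphDef into a new graph and prune it down to a
    # ConcreteFunction with the requested input/output tensors.
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('/tmp/frozen_graph.pb', 'rb') as f:  # placeholder path
    graph_def.ParseFromString(f.read())
concrete_func = wrap_frozen_graph(graph_def, inputs='input:0', outputs='output:0')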

But I still don't understand how to convert it to the saved_model format.

2 Answers:

Answer 0 (score: 2):

In TF1 mode:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
    graph_def = read_pb_model(pb_model_path)
    convert_pb_saved_model(graph_def, export_dir, input_name, output_name)


def read_pb_model(pb_model_path):
    with tf.gfile.GFile(pb_model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        return graph_def


def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    sigs = {}
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the frozen GraphDef and look up the input/output tensors by name
        # (the names must include the output index, e.g. 'input:0').
        tf.import_graph_def(graph_def, name="")
        g = tf.get_default_graph()
        inp = g.get_tensor_by_name(input_name)
        out = g.get_tensor_by_name(output_name)

        # Register the default serving signature mapping "input" -> "output".
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.saved_model.signature_def_utils.predict_signature_def(
                {"input": inp}, {"output": out})

        # Write saved_model.pb (and variables/) to export_dir.
        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
        builder.save()
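
For example (hypothetical paths and tensor names; use the real input/output tensor names of your frozen graph, including the ':0' suffix):

convert_pb_to_server_model('/tmp/frozen_graph.pb', '/tmp/saved_model',
                           input_name='input:0', output_name='output:0')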

In TF2 mode:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
import numpy as np
def frozen_keras_graph(func_model):
    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func_model)

    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != tf.resource
    ]
    output_tensors = frozen_func.outputs
    graph_def = run_graph_optimizations(
        graph_def,
        input_tensors,
        output_tensors,
        config=get_grappler_config(["constfold", "function"]),
        graph=frozen_func.graph)

    return graph_def


def convert_keras_model_to_pb():

    keras_model = train_model()  # train_model() is assumed to return a (trained) tf.keras model
    func_model = tf.function(keras_model).get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
    graph_def = frozen_keras_graph(func_model)
    tf.io.write_graph(graph_def, '/tmp/tf_model3', 'frozen_graph.pb')

def convert_saved_model_to_pb():
    model_dir = '/tmp/saved_model'
    model = tf.saved_model.load(model_dir)
    func_model = model.signatures["serving_default"]
    graph_def = frozen_keras_graph(func_model)
    tf.io.write_graph(graph_def, '/tmp/tf_model3', 'frozen_graph.pb')

Alternatively:

def convert_saved_model_to_pb(output_node_names, input_saved_model_dir, output_graph_dir):
    from tensorflow.python.tools import freeze_graph

    output_node_names = ','.join(output_node_names)

    freeze_graph.freeze_graph(input_graph=None, input_saver=None,
                              input_binary=None,
                              input_checkpoint=None,
                              output_node_names=output_node_names,
                              restore_op_name=None,
                              filename_tensor_name=None,
                              output_graph=output_graph_dir,
                              clear_devices=None,
                              initializer_nodes=None,
                              input_saved_model_dir=input_saved_model_dir)


def save_output_tensor_to_pb():
    output_names = ['StatefulPartitionedCall']
    save_pb_model_path = '/tmp/pb_model/freeze_graph.pb'
    model_dir = '/tmp/saved_model'
    convert_saved_model_to_pb(output_names, model_dir, save_pb_model_path)
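
Note that the output node name ('StatefulPartitionedCall' above) depends on the model. One way to find candidate names is to print the nodes of the resulting GraphDef (a small sketch; the path follows the example above):

import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('/tmp/pb_model/freeze_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    print(node.name, node.op)  # pick the node(s) that produce the model outputs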

Answer 1 (score: 0):

To make sure my understanding is correct, I'm also posting what I've learned here:

If you want to migrate tf1.x code to tf2.x, follow the official post first.

In TensorFlow 2.0, tf.train.Saver and freeze_graph have been replaced by saved_model.
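
For example, in native TF2 a model can be exported and reloaded without a Saver or freeze_graph (a minimal sketch with a hypothetical Keras model and path):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])  # hypothetical model
tf.saved_model.save(model, '/tmp/saved_model')  # replaces tf.train.Saver / freeze_graph
reloaded = tf.saved_model.load('/tmp/saved_model')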

If you want to convert a tf1.x pb model to a saved_model, you can follow @Boluoyu's answer above. But if your runtime environment is above tf2.0, you can use the following code:

import tensorflow.compat.v1 as tf 
tf.disable_v2_behavior()
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
    graph_def = read_pb_model(pb_model_path)
    convert_pb_saved_model(graph_def, export_dir, input_name, output_name)


def read_pb_model(pb_model_path):
    with tf.gfile.GFile(pb_model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        return graph_def


def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    sigs = {}
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the frozen GraphDef and look up the input/output tensors
        # (names must include the output index, e.g. 'input:0').
        tf.import_graph_def(graph_def, name="")
        g = tf.get_default_graph()
        inp = g.get_tensor_by_name(input_name)
        out = g.get_tensor_by_name(output_name)

        # Register the default serving signature and write the SavedModel.
        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.saved_model.signature_def_utils.predict_signature_def(
                {"input": inp}, {"output": out})

        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
        builder.save()
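
Usage is the same as in the answer above (hypothetical paths and tensor names):

convert_pb_to_server_model('/tmp/frozen_graph.pb', '/tmp/saved_model',
                           input_name='input:0', output_name='output:0')

# Optional check that the export loads back with the v1 loader.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], '/tmp/saved_model')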