如何将.ckpt文件转换为.pb

时间:2017-08-17 03:57:53

标签: tensorflow ssd

我在对象检测API中使用ssd_mobilenets来训练我自己的模型,并获取.ckpt文件。它在我的电脑上运行良好,但现在我想在手机上使用该型号。所以,我需要将其转换为.pb文件。我不知道怎么做,有人可以帮忙吗?顺便说一下,ssd_mobilenets的图表很复杂,我找不到哪个是模型的输出。有没有人知道输出的名称?

2 个答案:

答案 0 :(得分:2)

使用export_inference_graph.py将模型检查点文件转换为.pb文件。

python tensorflow_models/object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path architecture_used_while_training.config \
--trained path_to_saved_ckpt/model.ckpt-NUMBER \
--output_directory model/

答案 1 :(得分:0)

  • 这是此链接中object_detection_tutorial.ipynb中的第4个代码单元格 - https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

    # What model to download.
    MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
    MODEL_FILE = MODEL_NAME + '.tar.gz'
    DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
    
    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
    
    # List of the strings that is used to add correct label for each box.
    PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
    
    NUM_CLASSES = 90
    
  • 现在,该单元格清楚地显示.pb <{1}}

  • /frozen_inference_graph.pb文件名
  • 所以你已经拥有.pb文件为什么要转换?
  • 无论如何,你可以参考这个链接冻结图表:https://github.com/jayshah19949596/Tensorboard-Visualization-Freezing-Graph
  • 您需要使用tensorflow.python.tools.freeze_graph()函数将.ckpt文件转换为.pb文件
  • 以下代码行显示了您的操作方式

    freeze_graph.freeze_graph(input_graph_path,
                              input_saver_def_path,
                              input_binary,
                              input_checkpoint_path,
                              output_node_names,
                              restore_op_name,
                              filename_tensor_name,
                              output_graph_path,
                              clear_devices,
                              initializer_nodes)
    
    • input_graph_path:是.pb文件的路径,您将在其中编写图表,并且此.pb文件未被冻结。您将使用tf.train.write_graph()来编写图表
    • input_saver_def_path:您可以将其保留为空字符串
    • input_binary:它是一个布尔值,保持为假,因此创建的文件不是二进制和人类可读的
    • input_checkpoint_path:.ckpt file
    • 的路径
    • output_graph_path:您要写入pb文件
    • 的路径
    • clear_devices:布尔值...保持错误
    • output_node_names:要保存的显式张量节点名称
    • restore_op_name:字符串值应为&#34; save / restore_all&#34;
    • filename_tensor_name =&#34; save / Const:0&#34;
    • initializer_nodes =&#34;&#34;