将Audioset ckpt转换为pb文件

时间:2020-02-03 01:53:16

标签: python tensorflow ckpt

我正在使用Audioset/VGGish,并尝试将它们提供的checkpoint file转换为.pb文件。问题在于为训练模型提供的仅有两项是ckpt文件(上面链接)和一个npz file

这是我第三次尝试这种方法,现在已经花了几个小时试图找到最好的工具来做到这一点。到目前为止,我已经找到了几种解决方案,但它们似乎都需要更多信息,而不仅仅是ckpt文件。请记住,通常ckpt文件和Audioset需要使用TensorFlow <2。


示例:

freeze_graph:无论我输入什么值,我总是以ValueError: You need to supply the name of a node to --output_node_names的错误结束。该示例使用softmax,但是问题是我似乎无法弄清楚如何从ckpt文件中提取节点名称,因此似乎无法在不知道节点名称的情况下添加有效值。

Logged GitHub Issues:按照OP的代码,我收到错误ValueError: No variables to save

Stack Overflow questions:这似乎是一个可靠的答案,但是GitHub存储库中没有提供.ckpt.meta文件。我认为在某些情况下通常将需要元信息吗?我抬头查看是否有任何方法可以从ckpt文件中提取元数据以创建一个元文件,然后运行该信息,因为看起来元文件是ckpt文件的结构或图形,没有任何值(来自此答案) :Tensorflow : What is the relationship between .ckpt file and .ckpt.meta and .ckpt.index , and .pb file),但我可能对此有误解。

我认为有一种方法可以提取图元文件,其原因之一是由于此问题,有人在MMdnn GitHub上登录了Convert Audioset VGG from tensorflow to pytorch。尽管没有转换为.pb,但它们的命令中有一个ckpt.meta文件。该文件未在其描述中链接,并且Google搜索“ vggish_model.ckpt.meta”仅显示GitHub问题。我已向OP传达了有关该问题的消息,以了解他们是否可以阐明该文件的来源。

Previous article (2018) with a conversion script:这是一篇比较老的文章。我可以运行脚本,但也可以得到错误ValueError: No variables to save


如果有人能指出我正确的方向,那就太好了;我已经开始用尽所有选择。似乎有一些我正在尝试的好的解决方案,但是为了使此方法成功转换,我可能只是缺少了一个或两个步骤(或一个或两个文件)。

感谢您的帮助!

1 个答案:

答案 0 :(得分:1)

我希望对于这个答复还为时不晚,但是我设法通过使用repository中提供的推断代码来生成.pb文件。

Obs:我是因为我的GPU而使用了tensorflow 1.4.1,所以这可能不适用于更新的版本,或者应该进行一些更改。

推理演示将图形和检查点数据加载到会话中。从那里我可以使用一个函数来保存会话和图形。这是我的代码示例:

import vggish_input
from tensorflow.python.tools import freeze_graph
def save(sess, directory, filename, saver):
    """
    This function saves a checkpoint, based on the current session
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
    filepath = os.path.join(directory, filename)
    saver.save(sess, filepath)
    return filepath

def save_as_pb(sess, directory, filename, saver):
    """
    This function saves a checkpoint, then writes the graph in a pbtxt, and then              makes a frozen graph with the chekpoint and the pbtxt
    """

    # Save checkpoint to freeze graph later
    ckpt_filepath = save(sess, directory=directory, filename=filename, saver=saver)
    pbtxt_filename = filename + '.pbtxt'
    pbtxt_filepath = os.path.join(directory, pbtxt_filename)
    pb_filepath = os.path.join(directory, filename + '.pb')

    # This will only save the graph but the variables will not be saved.
    tf.train.write_graph(graph_or_graph_def=sess.graph_def, logdir=directory, name=pbtxt_filename, as_text=True)

    # Freeze graph, combining the checkpoint and 
    freeze_graph.freeze_graph(input_graph=pbtxt_filepath, input_saver='', input_binary=False, input_checkpoint=ckpt_filepath, output_node_names=vggish_params.OUTPUT_TENSOR_NAME.split(':')[0], restore_op_name='save/restore_all', filename_tensor_name='save/Const:0', output_graph=pb_filepath, clear_devices=True, initializer_nodes='')

    return pb_filepath

然后,我在从vggish_inference_demo.py文件的检查点加载模型之后,立即插入了save_as_pb:

  config = tf.ConfigProto()
  config.gpu_options.allow_growth=True
  with tf.Graph().as_default(), tf.Session(config=config) as sess:
    # Define the model in inference mode, load the checkpoint, and
    # locate input and output tensors.
    vggish_slim.define_vggish_slim(training=False)
    vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint)
    features_tensor = sess.graph.get_tensor_by_name(
        vggish_params.INPUT_TENSOR_NAME)
    embedding_tensor = sess.graph.get_tensor_by_name(
        vggish_params.OUTPUT_TENSOR_NAME)
    saver = tf.train.Saver()
    save_as_pb(sess, './saved_vggish/', 'vggish', saver)
相关问题