How do I read the weights saved in a TensorFlow checkpoint file?

Time: 2016-10-18 21:05:47

Tags: tensorflow

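I want to read the weights and view them as images, but I can't find any documentation on the model format or on how to read the trained weights.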

3 answers:

Answer 0 (score: 8)

The inspect_checkpoint.py utility provides a print_tensors_in_checkpoint_file function: http://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py

Alternatively, you can restore the model with a Saver and call session.run on the variable tensors to get their values as numpy arrays.
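As a rough illustration of the Saver route, here is a minimal sketch. It assumes a TensorFlow install that ships tf.compat.v1 (so the graph-mode Saver and Session are available), and it first writes a throwaway checkpoint so it can run on its own. The helper names and the variable name "demo/kernel" are made up for the example; with a real model you would only need the restore step.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API; assumed to be available


def save_demo_checkpoint(prefix):
    """Write a tiny checkpoint so the example is self-contained (hypothetical helper)."""
    with tf.Graph().as_default():
        tf1.get_variable(
            "demo/kernel",  # made-up variable name
            initializer=np.arange(6, dtype=np.float32).reshape(2, 3),
        )
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, prefix)  # also writes prefix + ".meta"


def restore_weights(prefix):
    """Restore every variable with a Saver and return {tensor name: numpy array}."""
    with tf.Graph().as_default():
        # import_meta_graph rebuilds the graph from the .meta file and returns a Saver.
        saver = tf1.train.import_meta_graph(prefix + ".meta")
        with tf1.Session() as sess:
            saver.restore(sess, prefix)
            variables = tf1.global_variables()
            values = sess.run(variables)  # list of numpy arrays
    return {var.name: value for var, value in zip(variables, values)}


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
save_demo_checkpoint(ckpt_prefix)
weights = restore_weights(ckpt_prefix)
print(weights["demo/kernel:0"])
```

Note that the dictionary keys are tensor names, so each variable name carries a ":0" suffix.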

Answer 1 (score: 3)

I wrote this snippet in Python:

import fnmatch
import os
import sys

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.import_meta_graph)


def extracting(meta_dir):
    var_name = ['2-convolutional/kernel']
    # Collect every .meta file under meta_dir
    configfiles = [os.path.join(dirpath, f)
                   for dirpath, dirnames, files in os.walk(meta_dir)
                   for f in fnmatch.filter(files, '*.meta')]

    with tf.Session() as sess:
        try:
            # A MetaGraph contains both a TensorFlow GraphDef
            # as well as associated metadata necessary
            # for running computation in a graph when crossing a process boundary.
            saver = tf.train.import_meta_graph(configfiles[0])
        except Exception:
            print("Unexpected error:", sys.exc_info()[0])
        else:
            # Restore from the checkpoint prefix of the last .meta file found
            # (its path with ".meta" stripped). os.walk does not guarantee
            # that the newest checkpoint comes last.
            saver.restore(sess, os.path.splitext(configfiles[-1])[0])

            graph = tf.get_default_graph()

            print('Step: ', configfiles[-1])

            print('Tensor:', var_name[0] + ':0')
            w2 = graph.get_tensor_by_name(var_name[0] + ':0')
            print('Tensor shape: ', w2.get_shape())
            w2_saved = sess.run(w2)  # weights as a numpy array
            print('Tensor value: ', w2_saved)
            return w2_saved

You can run it by passing the directory of the pre-trained model as meta_dir.
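The snippet hardcodes one variable name, which will differ from model to model. If you don't know which names your checkpoint contains, one way to find out is tf.train.list_variables, which returns (name, shape) pairs. The sketch below assumes a TensorFlow version that provides tf.compat.v1 and writes a throwaway checkpoint first so it runs on its own; "conv1/kernel" is a made-up name used only for the demo.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API; assumed to be available

# Write a tiny TF 1.x-style checkpoint so the example is self-contained.
# "conv1/kernel" is a made-up variable name, not one from the answer above.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")
with tf.Graph().as_default():
    tf1.get_variable(
        "conv1/kernel",
        initializer=np.ones((3, 3, 1, 4), dtype=np.float32),
    )
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# list_variables returns (name, shape) pairs for everything in the checkpoint.
names_and_shapes = tf.train.list_variables(ckpt_prefix)
for name, shape in names_and_shapes:
    print(name, shape)
```

Any name printed this way can be put into var_name. The snippet above then appends ":0" to get the tensor name.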

Answer 2 (score: 0)

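To expand on Yaroslav's answer: print_tensors_in_checkpoint_file is a thin wrapper around py_checkpoint_reader, which lets you access the variables concisely and retrieve the tensors as numpy arrays. For example, suppose you have the following files in a folder named tf_weights: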

checkpoint  model.ckpt.data-00000-of-00001  model.ckpt.index  model.ckpt.meta

You can then use py_checkpoint_reader to interact with the weights without loading the whole model. To do so:

from tensorflow.python.training import py_checkpoint_reader

# Need to say "model.ckpt" instead of "model.ckpt.index" for tf v2
file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)

# Load dictionaries var -> shape and var -> dtype
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()

The keys of the var_to_shape_map dictionary now match the variables stored in the checkpoint, which means you can retrieve them with reader.get_tensor, for example:

ckpt_vars = list(var_to_shape_map.keys())
reader.get_tensor(ckpt_vars[1])

Putting all of the above together, you can use the following code to get a dictionary of numpy arrays:

from tensorflow.python.training import py_checkpoint_reader

file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)

state_dict = {
    v: reader.get_tensor(v) for v in reader.get_variable_to_shape_map()
}
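Since the original question was about viewing the weights as images, here is one possible way to turn a convolution kernel from such a dictionary into a grayscale image grid. It is only a sketch: it assumes the kernel uses TensorFlow's usual conv2d layout (height, width, in_channels, out_channels), it shows only the first input channel, and kernel_to_image is a hypothetical helper, not part of TensorFlow. A random array stands in for a real entry of state_dict.

```python
import numpy as np


def kernel_to_image(kernel):
    """Tile a (height, width, in_channels, out_channels) kernel into one uint8 grid."""
    h, w, _, n_out = kernel.shape
    filters = kernel[:, :, 0, :]  # first input channel only, for simplicity
    # Rescale all values linearly to the 0..255 range.
    lo, hi = filters.min(), filters.max()
    if hi > lo:
        scaled = (filters - lo) / (hi - lo)
    else:
        scaled = np.zeros_like(filters)  # constant kernel: render as black
    scaled = (scaled * 255).astype(np.uint8)
    # Put the filters side by side: (h, w, n_out) -> (h, n_out, w) -> (h, n_out * w)
    return scaled.transpose(0, 2, 1).reshape(h, n_out * w)


# Random stand-in for something like state_dict["<some conv kernel name>"].
rng = np.random.default_rng(0)
fake_kernel = rng.normal(size=(5, 5, 3, 8)).astype(np.float32)
image = kernel_to_image(fake_kernel)
print(image.shape, image.dtype)
```

The resulting array can be displayed or saved with any image library, for example matplotlib's imshow or Pillow's Image.fromarray.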