如何将从tf.get_collection返回的列表写入文件并读取它

时间:2017-10-04 15:20:00

标签: python-2.7 list tensorflow pickle tensorflow-gpu

我尝试转换为字符串和存储,但无法将其转换回原始类型

我也尝试了pickle.dump,但它给出了以下错误

raise TypeError, "can't pickle %s objects" % base.__name__
TypeError: can't pickle module objects

我的代码:

with tf.Session() as sess:
    restorer = tf.train.import_meta_graph('abcd.ckpt.meta')
    restorer.restore(sess,'abcd.ckpt')
    vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    with open("pickle_target.txt", "wb") as fp:   
        pickle.dump(vars, fp)

我需要将tf.get_collection存储到文件中,编辑它并再次将其读取到列表中。

1 个答案:

答案 0 :(得分:0)

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)为您提供变量列表,而不是存储在这些变量中的值。要获取变量的当前值,您必须在会话中运行变量列表:

vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
with tf.Session() as sess:
    restorer = tf.train.import_meta_graph('abcd.ckpt.meta')
    restorer.restore(sess,'abcd.ckpt')
    vars = sess.run(vars_list)

现在vars是一个普通的python列表,其中包含当前变量值。