通过提供其值

时间:2017-03-03 10:06:25

标签: python tensorflow

我有一个TensorFlow模型,我希望能够选择检索和更新权重(在检查点机制之外)。由于我可能会多次这样做,所以每当我这样做时,我都不想在图表中添加节点,而是让我可以调用一些操作来完成它。我的想法是让tf.assign ops对更新的变量和值使用相同的变量,并在feed_dict中提供新的权重值;所以像这样:

weights_assigns = [tf.assign(v, v) for v in tf.trainable_variables()]
weights_update = tf.group(*weights_assigns)

# Now to update the weights
weights = [...]  # List of new weight values
feed_dict = {v: w for v, w in zip(tf.trainable_variables(), weights)}
tf.run(weights_update_op, feed_dict=feed_dict)

在我看来,这应该将feed_dict中传递的值作为变量的当前值,然后通过tf.assign操作存储它们。但是,这不起作用,并且给出了一些关于意外值类型的奇怪错误。

我目前的替代方案是改为使用一些辅助节点(变量或占位符),并在赋值操作中用作值:

weights_updates = [tf.placeholder(v.dtype, v.get_shape()) for v in tf.trainable_variables()]
weights_assigns = [tf.assign(v, u) for v, u in zip(tf.trainable_variables(), weights_updates)]
weights_update_op = tf.group(*weights_assigns)

# Now to update the weights
weights = [...]  # List of new weight values
feed_dict = {u: w for u, w in zip(weights_updates, weights)}
tf.run(weights_update_op, feed_dict=feed_dict)

这真的是唯一的方法吗?或者还有其他一些我没有看到的明显方式吗?

1 个答案:

答案 0 :(得分:1)

好的,发现有一个load()方法:

with tf.Graph().as_default():
    w = tf.get_variable('weights', shape=[3, 3], 
                        initializer=tf.random_uniform_initializer(dtype=tf.float32))
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        print('initial W:')
        print(sess.run(w))
        new_vals = np.reshape(np.arange(9, dtype=np.float32), (3,3))
        w.load(new_vals)
        print('updated W:')
        print(sess.run(w))

也许这有帮助?