tf.gradients是否会通过tf.cond?

时间:2019-03-27 15:48:04

标签: tensorflow

我想创建一对递归神经网络,例如NN 1 和NN 2 ,其中NN 2 从NN 1 输出与上一时间步长不同的值时,不会在当前时间步长更新其权重。

为此,我计划将tf.cond()tf.stop_gradients()一起使用。但是,在我运行的所有玩具示例中,我无法让tf.gradients()穿过tf.cond()tf.gradients()仅返回[None]

这是一个简单的玩具示例:

import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)

mult = tf.multiply(x, y)
cond = tf.cond(pred = tf.constant(True),
               true_fn = lambda: mult,
               false_fn = lambda: mult)

grad = tf.gradients(cond, x) # Returns [None]

这是另一个简单的玩具示例,其中我在true_fn中定义了false_fntf.cond()(仍然没有骰子):

import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)
z = tf.constant(8)

cond = tf.cond(pred = x < y,
               true_fn = lambda: tf.add(x, z),
               false_fn = lambda: tf.square(y))

tf.gradients(cond, z) # Returns [None]

我本来以为梯度应该同时流过true_fnfalse_fn,但是显然没有梯度在流过。这是通过tf.cond()计算出来的渐变的预期行为吗?可能有办法解决这个问题吗?

1 个答案:

答案 0 :(得分:0)

是的,渐变将通过tf.cond()。您只需要使用浮点数而不是整数,并且(最好)使用变量而不是常量:


import tensorflow as tf

x = tf.Variable(5.0, dtype=tf.float32)
y = tf.Variable(6.0, dtype=tf.float32)
z = tf.Variable(8.0, dtype=tf.float32)

cond = tf.cond(pred = x < y,
               true_fn = lambda: tf.add(x, z),
               false_fn = lambda: tf.square(y))

op = tf.gradients(cond, z) 
# Returns [<tf.Tensor 'gradients_1/cond_1/Add/Switch_1_grad/cond_grad:0' shape=() dtype=float32>]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op)) # [1.0] 
相关问题