TensorFlow图中的条件评估

时间:2017-07-05 22:02:35

标签: python tensorflow

这可以通过tf.cond完成,但是它会更新图表的两个分支,来自manual

  

请注意,条件执行仅适用于操作   在true_fn和false_fn中定义。考虑以下简单   程序:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  

如果x < y,将执行tf.add操作并执行tf.square   操作不会被执行。因为至少需要一个z   cond的分支,tf.multiply操作始终执行,   无条件

我如何实现这一点,以便tf.multiply有条件地执行(即仅在x > Y时执行?)

更具体地说,我试图做的事情:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

def f1():
  with tf.control_dependencies([update_var1]):
    return var1*1.1

def f2():
  return var1 * 1.1

final = tf.cond(training, f1, f2)
sess.run(final, feed_dict={training:False})

每次评估final时,这将使var1增加1,无论training的值如何,问题都在tf.cond,因为手动操作:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

with tf.control_dependencies([update_var1]):
  f1 = var1 * 1.1

f2 = var1 * 1.1

sess.run(f1)
>> array([1.1,1.1,1.1,1.1])
sess.run(f1)
>> array([2.2,2.2,2.2,2.2])
# var1 gets updated every call
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
# var1 does not get updated

1 个答案:

答案 0 :(得分:5)

一般解决方案如下:将有条件地执行的代码移动到[2,4] (或者通常是可调用对象)的主体中,以获取相应的分支maxWidth的。{例如,要确保lambda仅在tf.cond()时执行,请将其移至tf.multiply(a, b) lambda:

x < y

同样的原则可以应用于变量更新操作,例如true_fn。重要的细节是您必须在其中一个分支的函数体内创建result = tf.cond(x < y, lambda: tf.add(x, tf.multiply(a, b)), lambda: tf.square(y)) op 。以下是您修改第二个示例的方法:

tf.assign()

分配的控制依赖关系有点繁琐,所以您也可以将tf.assign()写为:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
training = tf.placeholder(tf.bool)

def f1():
  with tf.control_dependencies([tf.assign(var1, var1 + 1)]):
    return var1 * 1.1

def f2():
  return var1 * 1.1

final = tf.cond(training, f1, f2)
sess.run(final, feed_dict={training: False})

...或者将整个事情放在一行:

f1()