TensorFlow control_dependencies in while_loop

Posted: 2016-12-10 09:23:35

Tags: python tensorflow

This question was originally posted in TensorFlow issues. I'm moving it to StackOverflow to see if there's a way to solve it :)

Problem description

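I need to put InceptionV3 inside a while loop to save GPU and CPU memory, because I'm processing videos that each contain hundreds of images. The problem is that InceptionV3 uses control_dependencies for BatchNorm, and TensorFlow throws a frame error when control_dependencies is used inside a while_loop. If I remove control_dependencies, it runs fine.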

Here is a minimal snippet that reproduces the error:

import tensorflow as tf

sess = tf.Session()

with tf.variable_scope('state'):
    x = tf.get_variable('x', shape=(),
                        initializer=tf.constant_initializer(1),
                        dtype=tf.float32)
    update_x = tf.assign(x, x + 1)

def iter_fun(i, y):
    # Comment out the line below and the program runs without any error,
    # but I need control_dependencies, or at least some way to replace it...
    with tf.control_dependencies([update_x]):
        y = y + x
    return (i + 1, y)

with tf.variable_scope('iteration'):
    num_iterations = 5
    initial_i = tf.constant(0, dtype=tf.int32)
    initial_y = tf.constant(0, dtype=tf.float32)
    _, result = tf.while_loop(
        cond=lambda i, *_: i < num_iterations,
        body=iter_fun,
        loop_vars=(initial_i, initial_y))

init_op = tf.global_variables_initializer()
sess.run(init_op)
sess.run(result)

Stack trace of the error:

Traceback (most recent call last):
  File "demo.py", line 28, in <module>
    sess.run(result)
  File "/workspace/bily/anoaconda2/envs/tensorflow0.12/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/workspace/bily/anoaconda2/envs/tensorflow0.12/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 964, in _run
    feed_dict_string, options, run_metadata)
  File "/workspace/bily/anoaconda2/envs/tensorflow0.12/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
    target_list, options, run_metadata)
  File "/workspace/bily/anoaconda2/envs/tensorflow0.12/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'iteration/while/add' has inputs from different frames. The input 'iteration/while/add/Enter' is in frame 'iteration/while/iteration/while/'. The input 'state/Assign' is in frame ''.
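The error message points at the cause: `iteration/while/add` gets one input from the loop frame (`iteration/while/iteration/while/`) and another, `state/Assign`, from the root frame `''`, because `update_x` was built outside the loop. A minimal sketch of one workaround, assuming the update is meant to run once per iteration, is to build the update op inside the body so that it lives in the loop's frame. It is written against the `tf.compat.v1` graph-mode API (TensorFlow 1.15 / 2.x) rather than TF 0.12; the helper `run_loop`, the switch from `tf.assign(x, x + 1)` to `assign_add`, and `parallel_iterations=1` are illustrative assumptions, not part of the original code.

```python
import tensorflow as tf


def run_loop(num_iterations=5):
    """Hypothetical helper: increments x and sums it on every iteration.

    Assumes the tf.compat.v1 graph-mode API (TF 1.15 / 2.x), not TF 0.12.
    """
    graph = tf.Graph()
    with graph.as_default():
        with tf.compat.v1.variable_scope('state'):
            x = tf.compat.v1.get_variable(
                'x', shape=(), dtype=tf.float32,
                initializer=tf.compat.v1.constant_initializer(1))

        def iter_fun(i, y):
            # Build the update inside the body so it belongs to the
            # while-loop frame instead of the root frame ''.
            update_op = x.assign_add(1.0, read_value=False)
            with tf.control_dependencies([update_op]):
                # The read is created under the control dependency,
                # so it observes the incremented value.
                y = y + x.read_value()
            return i + 1, y

        _, result = tf.compat.v1.while_loop(
            cond=lambda i, *_: i < num_iterations,
            body=iter_fun,
            loop_vars=(tf.constant(0), tf.constant(0.0)),
            parallel_iterations=1)  # run iterations strictly one at a time
        x_value = x.read_value()
        init_op = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init_op)
        total = sess.run(result)
        final_x = sess.run(x_value)
    return float(total), float(final_x)


if __name__ == '__main__':
    print(run_loop(5))
```

Setting `parallel_iterations=1` keeps the iterations sequential, so one iteration's `assign_add` cannot race another iteration's read of `x`.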

Environment

  • CentOS Linux release 7.2.1511
  • TensorFlow 0.12 built from source
  • Python 2.7.12
  • CUDA 7.5 and cuDNN v5.1

Related issues

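  1. This issue in tflearn seems related to my problem, but removing control_dependencies is not a solution for me.
  2. #4478 and #3114 are about frame errors, but those errors are caused by variables rather than control_dependencies.

Any help would be appreciated :)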

1 answer:

Answer 0 (score: 0)

As I replied on the GitHub issue, the fix should show up in the open-source release within a few days.
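The question also asks for "some way to replace" control_dependencies. One alternative, which is a different technique from the fix the answer refers to, is to keep the state out of the loop body entirely: carry it as an extra loop variable, then write it back to the variable once after the loop, where the assign and its control dependency both live in the root frame. This is a sketch assuming the `tf.compat.v1` graph-mode API (TensorFlow 1.15 / 2.x); `run_functional_loop` is a hypothetical helper. Applying the idea to BatchNorm would mean threading the moving statistics through the loop yourself, which this sketch does not attempt.

```python
import tensorflow as tf


def run_functional_loop(num_iterations=5):
    """Hypothetical helper: same computation, but the state is a loop variable.

    Assumes the tf.compat.v1 graph-mode API (TF 1.15 / 2.x), not TF 0.12.
    """
    graph = tf.Graph()
    with graph.as_default():
        with tf.compat.v1.variable_scope('state'):
            x = tf.compat.v1.get_variable(
                'x', shape=(), dtype=tf.float32,
                initializer=tf.compat.v1.constant_initializer(1))

        def iter_fun(i, x_t, y):
            # x_t is an ordinary tensor threaded through the loop, so the
            # body contains no stateful op and needs no control_dependencies.
            x_t = x_t + 1.0
            y = y + x_t
            return i + 1, x_t, y

        _, final_x_t, result = tf.compat.v1.while_loop(
            cond=lambda i, *_: i < num_iterations,
            body=iter_fun,
            loop_vars=(tf.constant(0), x.read_value(), tf.constant(0.0)))

        # Write the state back once, after the loop. The assign and the
        # control dependency are both outside the loop frame.
        write_back = x.assign(final_x_t, read_value=False)
        with tf.control_dependencies([write_back]):
            result = tf.identity(result)
        x_value = x.read_value()
        init_op = tf.compat.v1.global_variables_initializer()

    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(init_op)
        total = sess.run(result)
        final_x = sess.run(x_value)
    return float(total), float(final_x)


if __name__ == '__main__':
    print(run_functional_loop(5))
```

Because the loop body is now free of side effects, the default `parallel_iterations` is safe here; only the single write-back after the loop touches the variable.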
