Tensorflow中的AssertionError while_loop

时间:2016-05-07 00:46:44

标签: python while-loop tensorflow

这可能非常简单,但我无法找到答案。我试图在tf.while_loop的'body'中使用张量。为了简单起见,我只是将(3,4)形状的张量“x”传递给它,暂时在“体”功能中无所作为。但似乎这一论点的传递正在引发一些问题。堆栈跟踪只是告诉'AssertionError:'。请帮忙。 代码:

import tensorflow as tf
import numpy as np

def cond(sequence_len, step, x):
    return tf.less(step,sequence_len)

def body(sequence_len, step, x):
    return (sequence_len, step+1)

step = tf.constant(0)
sequence_len  = tf.constant(10)
x = tf.zeros([3, 4], tf.int32)
res,step = tf.while_loop(cond,body,[sequence_len, step, x])

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    step_eval = step.eval(session=sess)

print(step_eval)

下面还粘贴了完整的堆栈跟踪。 The image of the stack trace

1 个答案:

答案 0 :(得分:0)

tf.while_loop()你需要确保body()是一个可调用的,可以获取张量列表并返回相同长度相同类型的张量列表作为输入。这就是While_loop的工作原理。每个返回都作为输入参数发回。也就是说,先前的返回是下一次迭代的输入参数。