tensorflow中的while_loop返回类型错误

时间:2016-10-27 21:44:52

标签: tensorflow

我很困惑为什么以下代码会返回此错误消息:

Traceback (most recent call last):
  File "/Users/Desktop/TestPython/tftest.py", line 46, in <module>
    main(sys.argv[1:])
  File "/Users/Desktop/TestPython/tftest.py", line 35, in main
    result = tf.while_loop(Cond_f2, Body_f1, loop_vars=loopvars)
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2518, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2356, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2292, in _BuildLoop
    c = ops.convert_to_tensor(pred(*packed_vars))
  File "/Users/Desktop/TestPython/tftest.py", line 18, in Cond_f2
    boln = tf.less(tf.cast(tf.constant(ind), dtype=tf.int32), tf.cast(tf.constant(N), dtype=tf.int32))
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto
    _AssertCompatible(values, dtype)
  File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 287, in _AssertCompatible
    raise TypeError("List of Tensors when single Tensor expected")
TypeError: List of Tensors when single Tensor expected

如果有人可以帮我修复此错误,我将不胜感激。谢谢!

from math import *
import numpy as np
import sys
import tensorflow as tf

def Body_f1(n, ind, N, T):
    # Compute trace
    a = tf.trace(tf.random_normal(0.0, 1.0, (n, n)))
    # Update trace
    a = tf.cast(a, dtype=T.dtype)
    T = tf.scatter_update(T, ind, a)
    # Update index
    ind = ind + 1

    return n, ind, N, T

def Cond_f2(n, ind, N, T):
    boln = tf.less(tf.cast(tf.constant(ind), dtype=tf.int32), tf.cast(tf.constant(N), dtype=tf.int32))
    return boln



def main(argv):
    # Open tensorflow session
    sess = tf.Session()

    # Parameters
    N = 10
    T = tf.zeros((N), dtype=tf.float64)
    n = 4
    ind = 0

    # While loop
    loopvars = [n, ind, N, T]
    result = tf.while_loop(Cond_f2, Body_f1, loop_vars=loopvars, shape_invariants=None, \
     parallel_iterations=1, back_prop=False, swap_memory=False, name=None)
    trace = result[3]
    trace = sess.run(trace)
    print trace
    print 'Done!'

    # Close tensorflow session
    if session==None:
        sess.close()

if __name__ == "__main__":
    main(sys.argv[1:])

更新:我添加了完整的错误消息。我不知道为什么我收到此错误消息。 loop_vars是否期望单个张量而不是张量列表?我希望不会。

1 个答案:

答案 0 :(得分:2)

tf.constant需要一个非Tensor值,如Python列表或numpy数组。你可以通过迭代tf.constant来获得相同的错误,就像在tf.constant(tf.constant(5。))中一样。删除这些调用可修复第一个错误。这是一个非常糟糕的错误消息,因此我鼓励您file a bug on Github

看起来random_normal的参数有点混乱;关键字参数有助于避免这样的问题:

tf.random_normal(mean=0.0, stddev=1.0, shape=(n, n))

最后,scatter_update需要一个变量。看起来TensorArray可能是您在这里寻找的内容(或隐式使用TensorArray的higher level looping constructs之一)。

相关问题