在tensorflow中无法识别函数中创建的占位符

时间:2017-05-03 16:00:22

标签: tensorflow

我通过点击和试用找到了解决问题的方法,但不明白为什么会有效。

我正在经历getting started of tensorflow。我运行代码直到下面的代码块结束,它工作正常。

W = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b

y = tf.placeholder(tf.float32)
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas)
print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))

现在,我认为loss看起来像应该在函数中的东西。我尝试了以下内容。

def get_loss(linear_model):
    y = tf.placeholder(tf.float32)
    squared_deltas = tf.square(linear_model - y)
    return tf.reduce_sum(squared_deltas)

loss = get_loss(linear_model)
print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))

但它给了我一个错误。

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_26' with dtype float
     [[Node: Placeholder_26 = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

然后我尝试了以下内容。唯一的区别是我也返回y

def get_loss(linear_model):
    y = tf.placeholder(tf.float32)
    squared_deltas = tf.square(linear_model - y)
    return y, tf.reduce_sum(squared_deltas)

y, loss = get_loss(linear_model)
print(sess.run(loss, {x:[1,2,3,4], y:[0,-1,-2,-3]}))

有效。为什么呢?

1 个答案:

答案 0 :(得分:2)

它不起作用,因为yget_loss()范围之外不可见。如果您返回底部的y,则y现在可见,因此可以使用。

如果您不想返回y,则可以使用其名称来提供y:

def get_loss(linear_model):
    # give y a name 'y'
    y = tf.placeholder(tf.float32, name='y')
    squared_deltas = tf.square(linear_model - y)
    return tf.reduce_sum(squared_deltas)

loss = get_loss(linear_model)
print(sess.run(loss, {x:[1,2,3,4], 'y:0':[0,-1,-2,-3]}))

Here是关于"y:0"“:0”部分的讨论