Tensorflow成本函数占位符错误

时间:2017-11-30 08:19:23

标签: python tensorflow placeholder

我有以下代码尝试优化具有两个输入和三个参数(m_1,m_2和b)的线性模型。最初,我遇到了以一种方式导入数据的问题,即feed_dict会接受它们,我通过把它放在一个numpy数组中来解决。

现在,优化器功能将顺利运行(输出看起来大致与优化参数一样),但是一旦我尝试使用最后一行返回成本:

cost_val = sess.run(cost)

它返回以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_2' with dtype float and shape [?,1]
     [[Node: Placeholder_2 = Placeholder[dtype=DT_FLOAT, shape=[?,1], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

如果我单独注释掉这一行,一切都会顺利进行。

我尝试将成本函数从我使用的更复杂的函数更改为更简单的函数,但错误仍然存​​在。我知道这可能与数据输入形状(?)有关,但无法确定数据如何适用于优化器而不是成本函数。

# reading in data
filename = tf.train.string_input_producer(["file.csv"])
reader = tf.TextLineReader(skip_header_lines=1)
key, value = reader.read(filename)
rec_def = [[1], [1], [1]]
input_1, input_2, col3 = tf.decode_csv(value, record_defaults=rec_def)

# parameters
learning_rate = 0.001
training_steps = 300

x = tf.placeholder(tf.float32, [None,1])
x2 = tf.placeholder(tf.float32, [None,1])

m = tf.Variable(tf.zeros([1,1]))
m2 = tf.Variable(tf.zeros([1,1]))
b = tf.Variable(tf.zeros([1]))

y_ = tf.placeholder(tf.float32, [None,1])

y = tf.matmul(x,m) + tf.matmul(x2,m2) + b

# cost function
# cost = tf.reduce_mean(tf.log(1+tf.exp(-y_*y)))
cost = tf.reduce_sum(tf.pow((y_-y),2))
# Gradient descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# initializing variables
init = tf.global_variables_initializer()


with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    sess.run(init)

    for i in range(training_steps):

        xs = np.array([[sess.run(input_1)]])
        ys = np.array([[sess.run(input_2)]])
        label = np.array([[sess.run(col3)]])

        feed = {x:xs, x2:ys, y_:label}
        sess.run(optimizer, feed_dict=feed)
        cost_val = sess.run(cost)

    coord.request_stop()
    coord.join(threads)

1 个答案:

答案 0 :(得分:1)

cost张量是占位符张量的函数,这要求它们具有值。由于拨打sess.run(cost)的电话不会为这些占位符提供信息,因此您会看到错误。 (换句话说 - 您想要计算xy_的价值是什么?)

所以你想改变这一行:

    cost_val = sess.run(cost)

为:

    cost_val = sess.run(cost, feed_dict=feed)

希望有所帮助。