tensorflow session.run()的参数-您是否传递操作?

时间:2018-07-18 12:23:00

标签: python tensorflow

我正在关注此tutorial for tensorflow

我正在尝试了解tf.session.run()的参数。我了解您必须在会话中的图形中运行操作。

在此特定示例中,train_step是否因为封装了网络的所有操作而被传递?我试图理解为什么我不需要将任何其他变量传递给会话,例如cross_entropy

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

这是完整的代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

for _ in range(10):
    batch_xs, batch_ys = mnist.train.next_batch(100)

    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

2 个答案:

答案 0 :(得分:4)

在TensorFlow会话tf.Session中,您要运行(或执行)优化程序操作(在本例中为train_step)。优化程序会最小化您的损失函数(在这种情况下为cross_entropy),该损失函数是使用模型假设y进行评估或计算的。

在级联方法中,cross_entropy损失函数使计算y时产生的错误最小化,因此它找到权重W与{{1}组合时的最佳值}精确地逼近x

因此,将TensorFlow会话对象y用作tf.Session时,我们将运行优化器sess,该优化器随后将评估整个计算图。

train_step

由于级联方法最终会调用sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) ,而该cross_entropy使用占位符xy,因此必须使用feed_dict将数据传递给那些占位符。 / p>

答案 1 :(得分:1)

正如您提到的,Tensorflow用于构建操作图。您的train_step操作(即“通过梯度下降最小化”)是否已连接/取决于cross_entropy的结果。 cross_entropy本身依赖于y(softmax操作)和y_(数据分配)的结果;等

调用sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})时,基本上是在问Tensorflow“ 运行所有导致train_step的操作,并返回其结果(使用x = batch_xs和{{1} }作为输入)”。所以是的,Tensorflow本身将向后浏览您的图,以找出y = batch_ys的操作/输入依赖关系,然后向前执行所有这些操作,以返回您的要求。