sess.run()导致训练缓慢

时间:2017-10-27 07:50:30

标签: python numpy machine-learning tensorflow

我正在训练CNN,我相信我对sess.run()的使用导致我的训练非常缓慢。

实质上,我使用的是mnist数据集......

from tensorflow.examples.tutorials.mnist import input_data
...
...
features = input_data.read_data_sets("/tmp/data/", one_hot=True)

问题是,CNN的第一层必须接受[batch_size, 28, 28, 1]形式的图像,这意味着我必须先将每张图像转换为CNN。

我用我的脚本执行以下操作......

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])  
...
...
with tf.Session() as sess:

    for epoch in range(25):

        total_batch = int(features.train.num_examples/500)

        avg_cost = 0

        for i in range(total_batch):

            batch_xs, batch_ys = features.train.next_batch(10)

            # Notice this line.
            _, c = sess.run([train_op, loss], feed_dict={x:sess.run(tf.reshape(batch_xs, [10, 28, 28, 1])), y:batch_ys})

            avg_cost += c / total_batch

        if (epoch + 1) % 1 == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

注意注释行。我正在从训练集中获取第一批,我正在重塑为正确的格式[batch_size, 28, 28, 1]。我每次都要拨打sess.run(),我相信这是训练速度如此之慢的原因。

我该如何防止这种情况发生。我尝试使用numpy在另一个脚本中重新格式化数据,但它仍然给我带来了问题,因为我无法在不运行numpy的情况下提供sess.run()数组。有人可以告诉我如何在训练课程之外格式化数据吗?也许我可以在另一个脚本中格式化数据并将其加载到包含我的CNN的那个?

2 个答案:

答案 0 :(得分:2)

在每次迭代中你绝对不应该在新的操作上有内部sess.run()(虽然我不确定它真的减慢了多少)。你应该做以下其中一个:

  • 有一个与输入形状相同的占位符,例如[None, 28*28*1],后跟tf.reshape([None, 28, 28, 1]),位于您网络的开头(而不是tf.placeholder([None, 28, 28, 1])

OR

  • 保留神经网络,并使用numpy reshape而不是tensorflow重新格式化:_, c = sess.run([train_op, loss], feed_dict={x:batch_xs.reshape( [-1, 28, 28, 1]), y:batch_ys})

如果你只是写_, c = sess.run([train_op, loss], feed_dict={x:tf.reshape(batch_xs, [10, 28, 28, 1]), y:batch_ys}),它可能也有效,但是那样做,因为它会在每次迭代时在你的图形中创建一个新的op。

答案 1 :(得分:1)

您可以做的另一件事是重新设置开头本身的所有输入,然后将其提供给占位符。

import math
import numpy as np
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])  
...
...
with tf.Session() as sess:
    X_train=mnist.train.images.reshape(-1,28,28,1)
    y_train=mnist.train.labels
    train_indicies = np.arange(X_train.shape[0])
    num_epochs = 25 // number of epochs
    batch_size = 50
    total_batch = int(math.ceil(X_train.shape[0]/batch_size))
    for epoch in range(25):
        for i in np.arange(total_batch):
        start_idx = (i*batch_size)%X_train.shape[0]
        idx = train_indicies[start_idx:start_idx+batch_size]
        _, c = sess.run([train_op, loss], feed_dict={x:X_train[idx,:], y:y_train[idx]})
        avg_cost += c / total_batch

    if (epoch + 1) % 1 == 0:
        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

因为我们无法使用mnist.train.next_batch,所以我们需要手动计算和增加索引。

希望这有效:)

相关问题