Tensorflow Dataset迭代器消耗大量内存

时间:2019-02-24 12:47:30

标签: python tensorflow machine-learning

我最近正在学习机器学习,并尝试使用Tensorflow实现一个简单的神经网络。

我使用MNIST作为数据集,我想使用Tensorflow的Dataset API加载和批处理我的数据。

这是我的代码:


    train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_data = train_data.shuffle(500)
    train_data = train_data.batch(50)
    train_data = train_data.repeat()
    td_iter = train_data.make_one_shot_iterator()
    features, labels = td_iter.get_next()

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(n_epochs):
            for iteration in range(n_batches):
                X_batch, y_batch = sess.run([features, labels])
                sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
            acc_train = accuracy.eval(feed_dict={X:X_batch, y:y_batch})
            acc_test = accuracy.eval(feed_dict={X:X_test, y:y_test})
            print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

我能够高精度地训练模型,但是训练时它会消耗我的所有内存(8GB)。

更具体地说,它在完成第一个时期之前会消耗大量内存(并且打印第一条输出行需要花费相当长的时间),但是如果开始打印某些内容,则内存消耗会减少。

我尝试简化代码以找出问题所在:


    with tf.Session() as sess:
        sess.run(init)
        sess.run([features, labels])

上面的代码仍然会耗尽我的全部内存。

mem

我认为我的代码一定有误,您能帮我吗?

谢谢!

0 个答案:

没有答案
相关问题