通过Keras使用大于2 Gb的数据集

时间:2018-12-26 22:44:17

标签: python tensorflow keras

TensorFlow在单个张量上长期存在2 Gb的限制。这意味着您不能一次跳过超过2 Gb的数据来训练模型。参见Initializing tensorflow Variable with an array larger than 2GBUse large dataset in Tensorflow

这些帖子中引用的标准解决方案是使用占位符,并将其通过feed_dict传递给“会话”:

my_graph = tf.Graph()
sess = tf.Session(graph=my_graph)   
X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
X = tf.Variable(X_init)
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})

但是,这仅在我使用“旧” API(tf.Session()等)时才有效。当今推荐的方法是使用Keras(tensorflow.org上的所有教程都使用它)。而且,对于Keras,没有tf.Graph(),tf.Session()和run()(至少没有一个用户容易看到的)。

我如何使以上代码与Keras配合使用?

2 个答案:

答案 0 :(得分:4)

在Keras中,您不会在张量中加载整个数据集。您将其加载到numpy数组中。

如果整个数据可以在单个numpy数组中:

感谢@sebrockm的评论。

Keras最简单的用法就是将数据集加载到 numpy数组(不是tf张量)中,然后调用model.fit(arrayWithInputs, arrayWithoutputs, ...)

如果整个数据不适合numpy数组:

您将创建一个generatorkeras.utils.Sequence来一次加载一批,然后使用model.fit_generator(generatorOrSequence, ...)训练模型

限制成为批处理大小,但是您几乎没有在单个批处理中达到2GB。 因此,继续努力:

答案 1 :(得分:3)

Keras的数据集没有2GB的限制,我已经使用Keras训练了更大的数据集,没有问题。

限制可能来自TensorFlow常数,该常数确实有2GB的限制,但是在任何情况下,您都将数据集存储为常数,因为这些数据集已保存为图形的一部分,也就是说不是存储模型的想法。

Keras具有model.fit_generator函数,您可以使用该函数传递生成器函数,该函数可以动态加载数据并进行批处理。这使您可以动态加载大型数据集,并且通常可以调整批处理大小,以便在可接受的RAM使用情况下最大化性能。 TensorFlow没有类似的API,您必须按照feed_dict的说明手动实现它。