内存泄漏tf.data + Keras

时间:2018-09-25 08:33:49

标签: tensorflow keras

我的训练流水线出现内存泄漏,不知道如何解决。

我在Python 3.5.2中使用Tensorflow版本:1.9.0和Keras(tf)版本:2.1.6-tf

这是我的训练管道的样子:

for i in range(num_epochs):

    training_data = training_set.make_one_shot_iterator().get_next()
    hist = model.fit(training_data[0],[training_data[1],training_data[2],training_data[3]],
                    steps_per_epoch=steps_per_epoch_train,epochs=1, verbose=1, callbacks=[history, MemoryCallback()])


    # custom validation

在迭代器用尽之后,似乎没有释放迭代器的内存。我已经在del traininig_data之后尝试过model.fit。没用

有人可以给些提示吗?

编辑: 这就是我创建数据集的方式。

dataset = tf.data.TFRecordDataset(tfrecords_filename)
dataset = dataset.map(map_func=preprocess_fn, num_parallel_calls=8)
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)

1 个答案:

答案 0 :(得分:0)

包括repeat()方法以重新初始化迭代器可能会解决您的问题。您可以查看Input Pipeline Performance Guide来找出根据您的要求对方法进行优化的最佳顺序。

dataset = dataset.shuffle(100)
dataset = dataset.repeat() # Can specify num_epochs as input if needed
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)

如果您有能力作为fit方法的一部分进行验证,则可以使用下面的代码,而完全失去循环,从而使您的生活更轻松。

training_data = training_set.make_one_shot_iterator().get_next()
# val_data refers to your validation data and steps_per_epochs_val refers to no of your validation batches
hist = model.fit(training_data[0],training_data[1],training_data[2],training_data[3]], validation_data=val_data.make_one_shot_iterator(), validation_steps=steps_per_epochs_val, 
       steps_per_epoch=steps_per_epoch_train, epochs=num_epochs, verbose=1, callbacks=[history, MemoryCallback()])

参考:https://github.com/keras-team/keras/blob/master/examples/mnist_dataset_api.py