您如何清理tf.data.Iterator?

时间:2018-11-10 12:20:05

标签: python tensorflow memory keras

我正在尝试使用来自tensorflow.data输入管线的日期来训练Keras模型。我打算不进行单次训练,而是打算在验证性能下降时尽早停止,并以更大的批次数量继续训练。我的代码如下所示:

batch_sizes = [256, 512, 1024]
for batch_size in batch_sizes:
    input_fn = input_fn_helper(batch_size, ...)
    training_set = input_fn().make_one_shot_iterator()

    input_fn_test = input_fn_test_helper(batch_size, ...)
    testing_set = input_fn_test().make_one_shot_iterator()

    model.fit(training_set,
              steps_per_epoch=(n_train / batch_size),
              epochs=max_epochs,
              validation_data=testing_set,
              validation_steps=(n_test / batch_size),
              callbacks=callbacks)

如您所见,我为批量大小的每次增加构造了一个新的输入管道(input_fn()返回了tf.data.Dataset)。我从中得到的行为是我所期望的,因此它可以执行预期的操作。我确实遇到的问题是,每次循环运行时,我的脚本内存占用量都会增加,即training_settesting_set的先前实例似乎没有在下一步被覆盖而释放。这有几个问题:

  1. 我在这里做错了什么吗?
  2. 是否有一种规范的方法来确保在不再需要tf.data.Iterator时将其妥善处置?

0 个答案:

没有答案