在训练期间如何停止TensorFlow数据集迭代器?

时间:2018-12-20 08:58:34

标签: python tensorflow iterator tensorflow-datasets

我正在使用tf.data.Dataset作为我的净输入。我通常这样使用:

dataset = get_dataset_pipeline(mode='train')
iter = dataset.make_initializable_iterator()
element = iter.get_next()

在进行训练时,我设置了dataset.repeat(1),每次使用此dataset时,我将执行sess.run(iter.initializer)training的地雷网络,如下所示:

for epoch in range(4):
    sess.run(iter.initializer)
    try:
        while True:
            train_step += 1
            input = sess.run(elemeent)
            # some train step here           
    except tf.errors.OutOfRangeError:
        pass
    finally:
        pass

问题是我如何才能在tf.errors.OutOfRangeError出现之前提前停止训练?。我的意思是我不仅要跳出训练循环,而且还希望iter真正被禁用或从地雷系统内存中删除。谁能得到我的观点并为此提供一些建议?非常感谢~~~

0 个答案:

没有答案