正确的方法来停止TensorFlow数据集`from_generator`?

时间:2018-05-10 14:56:17

标签: python tensorflow tensorflow-datasets

我想使用使用import tensorflow as tf def make_batch_generator_fn(batch_size=10, dset_size=100): feats, targs = range(dset_size), range(1, dset_size + 1) def batch_generator_fn(): start_idx, stop_idx = 0, batch_size while True: # if stop_idx > dset_size: --- stop action? yield feats[start_idx: stop_idx], targs[start_idx: stop_idx] start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size return batch_generator_fn def test(batch_size=10): dgen = make_batch_generator_fn(batch_size) features_shape, targets_shape = [None], [None] ds = tf.data.Dataset.from_generator( dgen, (tf.int32, tf.int32), (tf.TensorShape(features_shape), tf.TensorShape(targets_shape)) ) feats, targs = ds.make_one_shot_iterator().get_next() with tf.Session() as sess: counter = 0 try: while True: f, t = sess.run([feats, targs]) print(f, t) counter += 1 if counter > 15: break except tf.errors.OutOfRangeError: print('end of dataset at counter = {}'.format(counter)) if __name__ == '__main__': test() 构建的TensorFlow数据集来访问格式化文件。除了我不知道如何在生成器数据耗尽时停止数据集迭代器(当你超出范围时,生成器只会永远返回空列表),大多数都可以工作。

我的实际代码非常复杂,但我可以通过这个简短的程序来模拟这种情况:

stop action?

如果我事先知道记录的数量,我可以调整批次的数量,但我不会总是知道。我已经尝试在上面的代码段中添加一些代码,我在其中有IndexError这样的注释行。特别是,我尝试过提升catch,但TensorFlow并不喜欢这样,即使我在执行代码中明确tf.errors.OutOfRangeError。我也试过提出sort((x,y) => x.description < y.description ? -1 : 1) ,但我不确定如何实例化它:构造函数需要三个参数 - &#39; node_def&#39;,&#39; op&#39;和& #39;消息&#39;,我不太确定要使用什么节点&#39; node_def&#39;和&#39; op&#39;总的来说。

我对此问题的任何想法或意见表示感谢。谢谢!

1 个答案:

答案 0 :(得分:0)

它适用于以下几行:

dataset_size = your dataset size
batch_size = your batch size
dataset = your tf.data.Dataset
steps_per_epoch = dataset_size // batch_size

for data, _ in zip(dataset, range(steps_per_epoch)):
    # your train_step

迭代将结束。