将Keras model.fit的`steps_per_epoch`与TensorFlow的数据集API的`batch()`相结合

时间:2019-02-07 14:35:42

标签: tensorflow keras tensorflow-datasets

我正在研究使用Keras + TensorFlow训练CNN模型期间的性能和GPU使用情况。与this question类似,我很难理解Keras model.fit的{​​{1}}和TensorFlow的Dataset API的steps_per_epoch的组合使用:我设置了一定的批处理大小在输入管道.batch()上,后来我使用

dataset = dataset.batch(batch_size)

但是我看到每个纪元实际上可以设置任何数量的步,甚至超过fit = model.fit(dataset, epochs=num_epochs, steps_per_epoch=training_set_size//batch_size) 。从文档中我了解到,在Keras上,一个纪元不一定像通常那样遍及整个培训集,但是无论如何我还是有些困惑,现在我不确定是否正确使用它。

training_set_size//batch_sizedataset.batch(batch_size)是否定义了一个微型批次SGD,它通过steps_per_epoch=training_set_size//batch_size个样本的微型批次在整个训练集中运行?如果batch_size设置为大于steps_per_epoch,则历时大于训练集的一遍吗?

1 个答案:

答案 0 :(得分:3)

steps_per_epoch是您在一个时期内通过网络运行所设置的批次大小的批次数。

您有充分理由将steps_per_epoch设置为training_set_size//batch_size。这样可以确保所有数据都在一个纪元上得到训练,只要数字精确地相除即可(如果不是,则由//运算符四舍五入)。

也就是说,如果批次大小为10,训练集大小为30,则steps_per_epoch = 3确保使用了所有数据。

并引用您的问题:

  

“如果将steps_per_epoch设置为大于training_set_size // batch_size,则历时是否超过训练集的一倍?”

是的。某些数据将在同一时期再次通过。