如何在tf.data.Dataset中填充固定的BATCH_SIZE?

时间:2018-01-18 16:11:35

标签: tensorflow tensorflow-datasets

我有一个包含11个样本的数据集。当我选择BATCH_SIZE为2时,以下代码会出错:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

问题在于dataset = dataset.batch(batch_size),当Dataset循环到最后一批时,剩余的样本数只有1,所以有没有办法从之前访问过的样本中随机选择一个生成最后一批?

2 个答案:

答案 0 :(得分:7)

@mining通过填充文件名来提出解决方案。

另一种解决方案是使用tf.contrib.data.batch_and_drop_remainder。这将使用固定的批次大小批量处理数据,并删除最后一个较小的批次。

在您的示例中,有11个输入且批量大小为2,这将产生5批2个元素。

以下是文档中的示例:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))

答案 1 :(得分:2)

您只需在对drop_remainder=True的通话中设置batch

dataset = dataset.batch(batch_size, drop_remainder=True)

来自documentation

  

drop_remainder :(可选。)tf.bool标量tf.Tensor,表示   如果最后一批较少,是否应丢弃最后一批   比batch_size元素;默认行为是不删除   较小的批次。