我有一个包含11个样本的数据集。当我选择BATCH_SIZE
为2时,以下代码会出错:
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
问题在于dataset = dataset.batch(batch_size)
,当Dataset
循环到最后一批时,剩余的样本数只有1,所以有没有办法从之前访问过的样本中随机选择一个生成最后一批?
答案 0 :(得分:7)
@mining通过填充文件名来提出解决方案。
另一种解决方案是使用tf.contrib.data.batch_and_drop_remainder
。这将使用固定的批次大小批量处理数据,并删除最后一个较小的批次。
在您的示例中,有11个输入且批量大小为2,这将产生5批2个元素。
以下是文档中的示例:
dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
答案 1 :(得分:2)
您只需在对drop_remainder=True
的通话中设置batch
。
dataset = dataset.batch(batch_size, drop_remainder=True)
drop_remainder :(可选。)tf.bool标量tf.Tensor,表示 如果最后一批较少,是否应丢弃最后一批 比batch_size元素;默认行为是不删除 较小的批次。