如何从pytorch数据加载器中获取批处理迭代的总数?

时间:2020-09-17 03:00:51

标签: for-loop pytorch dataloader

我有一个问题,如何从pytorch数据加载器中获取批处理迭代的总数?

以下是常见的培训代码

for i, batch in enumerate(dataloader):

然后,有什么方法可以获取“ for循环”的迭代总数吗?

在我的NLP问题中,迭代总数不同于int(n_train_samples / batch_size)...

例如,如果我仅截断训练数据10,000个样本并将批次大小设置为1024,那么在我的NLP问题中会发生363次迭代。

我想知道如何在“ for-loop”中获得总迭代次数。

谢谢。

2 个答案:

答案 0 :(得分:7)

len(dataloader)返回批次总数。它取决于数据集的__len__函数,因此请确保已正确设置。

答案 1 :(得分:0)

创建数据加载器时还有一个附加参数。它称为drop_last

如果drop_last=True,则长度为number_of_training_examples // batch_size。 如果drop_last=False可能是number_of_training_examples // batch_size +1

BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)

对于预定义的数据集,您可能会看到许多示例,例如:

# number of examples
len(dl_train.dataset) 

数据加载器中的正确批数始终为:

# number of batches
len(dl_train)