在PyTorch中实现“无限循环”数据集和DataLoader

时间:2019-01-25 05:08:07

标签: pytorch

我想实现无限循环数据集和数据加载器。这是我尝试过的:

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"

如您所见,这里的主要挑战是__len()__方法。如果我在此处放置足够大的数字,例如1 << 30,则症状是在火车循环的第一次迭代中,内存使用量将跃升至10 + GB。一段时间后,大概是因为OOM,工人被杀了。

如果我在其中放一个小数字,例如1或BATCH_SIZE,则火车循环中的采样“数据”将定期复制。这不是我想要的,因为我希望每次迭代都可以生成和训练新数据。

我猜想过多内存使用的罪魁祸首是堆栈中的某处,一堆东西被缓存了。随意看一下Python的一面,我无法确定具体位置。

有人可以建议实现我想要的最佳方法是什么? (使用DataLoader的并行加载,同时确保加载的每个批次都是全新的。)

3 个答案:

答案 0 :(得分:1)

DataLoader对您的数据集进行采样,而无需替换。为此,它会生成一个索引random permutation,其索引介于0到len(dataset)之间。我猜想,这种排列方式会消耗掉您的大部分内存。我认为PyTorch API不支持无限集合,但是您可以尝试分叉DataLoader中的代码并自己完成。 您可以使用batch_sampler参数,并传入基于RandomSampler实现的自定义变体。这将允许您保留DataLoader的并行加载部分。

话虽如此,基于__len____getitem__的迭代协议并不适合无限集合。重新实现Dataset.__len__仅返回1,使Dataset.__getitem__始终返回一个新样本(无论索引如何),然后对n进行“替换” 。从技术上讲,它将要求n次以第0个样本为例,但是由于您覆盖__getitem__以返回不同的样本,因此这将有效地满足您的需求。

答案 1 :(得分:1)

这似乎在不定期复制数据的情况下起作用:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

结果:

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

所以我认为问题出在您的函数sample_func_to_be_parallelized()中。


编辑:如果我在torch.randint(0, 10, (3,))中使用np.random.randint(10, size=3)(例如__getitem__)而不是sample_func_to_be_parallelized(),则数据为确实在每批重复。参见此issue

因此,如果您在sample_func_to_be_parallelized()中某处使用numpy的RGN,则解决方法是使用

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

,并在每次调用np.random.seed()之前通过data = next(iter(data_loader))重置种子。

答案 2 :(得分:0)

尝试使用cycle中的itertools。这是简单数据集的示例:

代码:

from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader


# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])


class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)

输出:

batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...

batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
...
相关问题