对PyTorch使用大数据集的最有效方法?

时间:2018-12-01 23:56:02

标签: python memory pytorch hdf5 data-processing

也许这个问题曾经被问过,但是我很难找到适合自己情况的相关信息。

我正在使用PyTorch创建用于图像数据回归的CNN。我没有正式的学术程序设计背景,所以我的许多方法都是临时的,效率很低。可能有时我可以回顾我的代码并稍后进行清理,因为效率低下并没有严重影响性能。但是,在这种情况下,我使用图像数据的方法会花费很长时间,会占用大量内存,并且每次我要测试模型更改时都会执行此操作。

我所做的基本上是将图像数据加载到numpy数组中,然后将这些数组保存在.npy文件中,然后当我要将所述数据用于模型时,我将导入该文件中的所有数据。我不认为数据集真的那么大,因为它包含5000个3x大小为64x64的彩色通道图像。但是,在加载时,我的内存使用率高达70%-80%(在16gb中),并且每次加载都需要20-30秒。

我的猜测是我对加载方式很傻,但是坦率地说我不确定标准是什么。我应该以某种方式在需要之前将图像数据放在某处,还是应该直接从图像文件加载数据?在这两种情况下,独立于文件结构的最佳,最有效的方法是什么?

在此方面,我将不胜感激。

4 个答案:

答案 0 :(得分:1)

这里是一个具体的例子来说明我的意思。假设您已经使用TestCase将图像转储到hdf5文件(train_images.hdf5)中。

h5py

简单来说,import h5py hf = h5py.File('train_images.hdf5', 'r') group_key = list(hf.keys())[0] ds = hf[group_key] # load only one example x = ds[0] # load a subset, slice (n examples) arr = ds[:n] # should load the whole dataset into memory. # this should be avoided arr = ds[:] 现在可以用作迭代器,它可以即时提供图像(即,它不会在内存中加载任何内容)。这应该使整个运行时间快速增长。

答案 1 :(得分:0)

为了提高速度,我建议使用 HDF5 LMDB

  

使用LMDB的原因:

     

LMDB使用内存映射文件,从而提供了更好的I / O性能。   非常适合大型数据集。始终读取HDF5文件   完全存储在内存中,因此您的HDF5文件不能超过   内存容量。您可以轻松地将数据拆分为多个HDF5   但是文件(只需在文本文件中放置几个​​指向h5文件的路径)。   再说一次,与LMDB的页面缓存相比,I / O性能不会   几乎一样好。   [http://deepdish.io/2015/04/28/creating-lmdb-in-python/]

如果您决定使用 LMDB

ml-pyxis是用于使用LMDB创建和读取深度学习数据集的工具。

它允许创建二进制Blob(LMDB),并且可以非常快速地读取它们。上面的链接附带了一些有关如何创建和读取数据的简单示例。包括python generators / iteratos。

notebook上有一个示例,说明如何使用pytorch创建数据集并在同位读取它。

如果您决定使用 HDF5

PyTables是一个用于管理分层数据集的程序包,旨在高效,轻松地处理大量数据。

https://www.pytables.org/

答案 2 :(得分:0)

除了上述答案外,由于Pytorch世界的最新进展(2020年),以下内容可能会有用。

您的问题:我应该以某种方式在需要之前将图像数据放在某个地方,还是应该直接从图像文件加载数据?在这两种情况下,独立于文件结构的最佳,最有效的方法是什么?

您可以将原始格式的图像文件(.jpg,.png等)保留在本地磁盘或云存储上,但又增加了一个步骤-将目录压缩为tar文件。请阅读以下详细信息:

Pytorch博客(2020年8月):高效的PyTorch I / O库,适用于大型数据集,多个文件,多个GPU(https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/

此软件包用于数据文件太大而无法容纳在内存中进行训练的情况。因此,您提供数据集位置的URL(本地,云,..),它将成批和并行引入数据。

(当前)唯一的要求是数据集必须为tar文件格式。

tar文件可以位于本地磁盘或云上。这样,您不必每次都将整个数据集加载到内存中。您可以使用torch.utils.data.DataLoader批量加载以进行随机梯度下降。

答案 3 :(得分:0)

无需将图像保存到 npy 并将所有内容加载到内存中。只需加载一批图像路径并转换为张量即可。

下面的代码定义了MassiveDataset,并传入DataLoader,一切顺利。

from torch.utils.data.dataset import Dataset
from typing import Optional, Callable
import os
import multiprocessing

def apply_transform(transform: Callable, data):
    try:
        if isinstance(data, (list, tuple)):
            return [transform(item) for item in data]

        return transform(data)
    except Exception as e:
        raise RuntimeError(f'applying transform {transform}: {e}')


class MassiveDataset(Dataset):
    def __init__(self, filename, transform: Optional[Callable] = None):
        self.offset = []
        self.n_data = 0

        if not os.path.exists(filename):
            raise ValueError(f'filename does not exist: {filename}')

        with open(filename, 'rb') as fp:
            self.offset = [0]
            while fp.readline():
                self.offset.append(fp.tell())
            self.offset = self.offset[:-1]

        self.n_data = len(self.offset)

        self.filename = filename
        self.fd = open(filename, 'rb', buffering=0)
        self.lock = multiprocessing.Lock()

        self.transform = transform

    def __len__(self):
        return self.n_data

    def __getitem__(self, index: int):
        if index < 0:
            index = self.n_data + index
        
        with self.lock:
            self.fd.seek(self.offset[index])
            line = self.fd.readline()

        data = line.decode('utf-8').strip('\n')

        return apply_transform(self.transform, data) if self.transform is not None else data

注意open file with buffering=0multiprocessing.Lock() 用于避免加载错误数据(通常来自文件的一部分和文件另一部分的一点)。

另外,如果在 DataLoader 中使用 multiprocessing,可能会得到这样的异常 TypeError: cannot serialize '_io.BufferedReader' object。这是由 multiprocessing 中使用的 pickle 模块引起的,它无法序列化 _io.BufferedReader,但 dill 可以。用 multiprocess 替换 multiprocessing,一切顺利(与 multiprocessing 相比,主要变化,使用 dill 完成了增强的序列化)

同样的事情在this issue

中讨论过
相关问题