Pytorch: 使用torch.utils.data封装数据
数据集集成框架
数据的流向:先将数据交给自定义的数据集类(其任务为对源数据预处理,包装成数据加载器能够访问的数据集),通常转化为张量 通过数据加载类进行采样并送给模型计算并优化(其任务为根据默认或自定义的采样器
Sampler进行真正的采样)map-style和iterable-style:索引式数据集和迭代式数据集- 索引式数据集的做法是将整个数据集读取到内存,采样的时候通过
[]访问即可 适用于批量读取、小数据集,由于数据均在内存中,往往更快 - 迭代式数据集的做法是事先定义迭代的行为,采样的时候从硬盘中取出数据并作为迭代器返回 适用于各种数据无法事先全部读入内存的情况(例如训练过程中会产生新样本、或内存过小),当然是比较慢的
- 索引式数据集的做法是将整个数据集读取到内存,采样的时候通过
torch.utils.data.Dataset(简称Dataset)是所有数据集类型的基类,用户如果需要构建自己的索引式数据集,必须直接继承该类并重载以下方法(其实就是封装成一个索引式数据集):__len__():应返回数据集的长度,以便可以用len(obj)访问数据集长度__getitem__(idx):应返回该索引指向的数据,以便可以用[]访问数据
torch.utils.data.IterableDataset(简称IterableDataset)是所有迭代式数据集的基类,继承该类后必须重载:__iter__():yield一个迭代器,一般是在方法内临时读取文件并返回一条数据
torch.utils.data.DataLoader(简称DataLoader)是数据集对象的迭代器,相比诸多语言的内置迭代器封装了很多功能,不需要手动实现乱序、分批读取等,加载数据集十分方便,其构造方法__init__()参数用法为:dataset:Dataset(或子类)对象,表示要遍历的数据集batch_size:批数据的大小,即通过该迭代器遍历时返回的样本个数;将其设置为None以关闭自动批处理sampler:可以指定采样器进行更多样的采样行为batch_sampler:类似sampler,但要求用户提供的采样器必须返回一批张量的索引而不是一个张量的索引,这一批张量均用用户提供的Sampler对象来采样,它和sampler、batch_size是冲突的(因为认为用户已经在sampler里定义了批量大小)shuffle:布尔值,决定是否乱序采样,和sampler、batch_sampler冲突(因为要使用默认的RandomSampler),没有更高需求的采样行为而懒得自定义采样器时,使shuffle=True就足够了
只提供
dataset,其它参数保持默认下,迭代器会顺序地进行批量为一的自动批处理采样
其它不常用的参数:num_workers:将数据分给多个子进程加载,默认只用主进程加载collate_fn:是一个可调用对象,表示对数据集采样后,返回给用户前进行的行为 默认的collate_fn常在一个开启了自动批处理的索引式数据集中用到,会将采样的数据合并为一批 因为迭代式数据集的批量采样等行为通常是由用户在__iter__()中实现的,所以对迭代式数据集不会提供默认的collate_fn用户可以自定义collate_fnpin_memory:由于GPU中的内存全是锁页内存,但CPU中一般只是不锁页内存(即会被交换到虚存中),设置pin_memory=True,将使读入的样本变成锁页内存,通过消耗CPU内存的方式加快其载入GPU的过程drop_last:布尔值,表示是否丢弃最后一个不完整的批次timeout:表示数据加载的时限
torch.utils.data.Sampler(简称Sampler)是采样器类,表示采样的策略,具体为提供一个索引序列 内置的Sampler有:SequentialSampler:顺序采样,每次返回一个索引,是未提供sampler且shuffle=False时的默认采样器RandomSampler:全集的随机采样,是未提供sampler且shuffle=True时的默认采样器WeightedRandomSampler:根据权重的随机采样SubsetRandomSampler:对数据集子集的随机采样BatchSampler:对其它Sampler产生的索引序列分批,顺序地提供的各个批次的索引序列(一次返回整个批次的所有索引) 开启自动批处理时,最终会经过这个采样器进行分批
用户可以继承
Sampler或以上这些特化的采样器类,来自定义一个Sampler不过大多数情况下,上述采样器已经足够了