如何切片 Kinetics400 训练数据集? (pytorch)

时间:2021-01-22 08:08:23

标签: pytorch

我正在尝试运行官方 script 进行视频分类。 我想调整一些功能并运行所有示例会花费我太多时间。 我想知道如何根据该脚本对训练动力学数据集进行切片。 这是我之前添加的代码 train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video) 在脚本中:(假设我只想运行 100 个示例。)

tr_split_len = 100
dataset = torch.utils.data.random_split(dataset, [tr_split_len, len(dataset)-tr_split_len])[0]

然后当点击 train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video) ,它弹出错误:

AttributeError: 'Subset' object has no attribute 'video_clips'

是的,所以 dataset 的类型从 torchvision.datasets.kinetics.Kinetics400 转换为 torch.utils.data.dataset.Subset。 我明白。那么我该怎么做呢? (希望不是在数据加载器循环中使用 break 的方式)。 谢谢。

1 个答案:

答案 0 :(得分:0)

似乎torchvision.datasets.kinetics.Kinetics400在内部使用了一个VideoClips类的对象来存储有关剪辑的信息。它存储在成员变量 Kinetics4000().video_clips 中。

VideoClips 类有一个名为 subset 的函数,它接受索引列表并返回一个新的 VideoClips 对象,其中仅包含具有指定索引的剪辑。然后,您可以用数据集中的新对象替换旧的 VideoClips 对象。