我正在尝试为自定义数据集运行在Pytorch中实现的Siamese Network。 Github代码链接:https://github.com/viral-parmar/Voice_Dissimilarity 当我尝试运行代码并提供仅包含10个图像的文件夹的Training数据的路径时。代码的数据加载部分需要花费大量执行时间,并且永远不会成功执行。但是,当我提供到相同训练数据的路径时,现在拥有25个文件夹(每个文件夹具有10张图像),加载部分将成功执行。我不了解代码的问题。
我尝试在Google Co Lab和本地计算机上运行它,但是结果是相同的。
folder_dataset = dset.ImageFolder(root=Config.training_dir)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
transform=transforms.Compose([transforms.Resize((100,100)),
transforms.ToTensor()
])
,should_invert=False)
#Visualising some of the data
vis_dataloader = DataLoader(siamese_dataset,
shuffle=True,
num_workers=8,
batch_size=8)
dataiter = iter(vis_dataloader)
example_batch = next(dataiter)
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())
这需要很长时间,并且永远不会完全执行。
答案 0 :(得分:0)
所以我认为num_workers=8
是导致问题的部分。基本上,分发的开销要花费更多时间。尝试将其更改为1,即num_workers=1
,然后再次运行。对于10张图片,您不需要8个核心:P
希望这会有所帮助!
答案 1 :(得分:0)
我现在很确定这些while循环的原因。基本上检查在这10张图片中您是否同时拥有这两个类别?如果所有10张图片都使用相同的类,则循环永远不会中断。
while True:
#keep looping till the same class image is found
img1_tuple = random.choice(self.imageFolderDataset.imgs)
if img0_tuple[1]==img1_tuple[1]:
break
else:
while True:
#keep looping till a different class image is found
img1_tuple = random.choice(self.imageFolderDataset.imgs)
if img0_tuple[1] !=img1_tuple[1]:
break