根据pytorch数据集中的文件名拆分数据集

时间:2018-09-24 06:07:46

标签: python-3.x dataset pytorch

是否有一种基于文件名将数据集分为训练和测试的方法。我有一个包含两个文件夹的文件夹:输入和输出。输入文件夹包含图像,输出是该图像的标签。输入文件夹中的文件名类似于input01_train.pnginput01_test.png,如下所示。

                          Dataset
                          /     \
                     Input       Output
                      |             |
           input01_train.png   output01_train.png
                    .                 .
                    .                 .
           input01_test.png    output01_test.png

我仅有的代码仅将数据集分为输入和标签,而不是测试和训练。

class CancerDataset(Dataset):
  def __init__(self, dataset_folder):#,label_folder):
    self.dataset_folder = torchvision.datasets.ImageFolder(dataset_folder ,transform = transforms.Compose([transforms.Resize(512),transforms.ToTensor()]))
    self.label_folder = torchvision.datasets.ImageFolder(dataset_folder ,transform = transforms.Compose([transforms.Resize(512),transforms.ToTensor()]))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    label = self.label_folder[index]
    return img,label

  def __len__(self):
    return len(self.dataset_folder)

trainset = CancerDataset(dataset_folder = '/content/drive/My Drive/cancer_data/')
trainsetloader = DataLoader(trainset,batch_size = 1, shuffle = True,num_workers = 0,pin_memory = True)

如果可以的话,我希望能够按名称对火车和测试仪进行划分。

1 个答案:

答案 0 :(得分:1)

您可以自己在__getitem__中加载图像,只选择包含'_train.png'或'_test.png'的图像。

class CancerDataset(Dataset):
    def __init__(self, datafolder, datatype='train', transform = transforms.Compose([transforms.Resize(512),transforms.ToTensor()]):
        self.datafolder = datafolder
        self.image_files_list = [s for s in os.listdir(datafolder) if
                                 '_%s.png' % datatype in s]
        # Same for the labels files
        self.label_files_list = ...
        self.transform = transform

    def __len__(self):
        return len(self.image_files_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.datafolder,
                                self.image_files_list[idx])
        image = Image.open(img_name)
        image = self.transform(image)
        # Same for the labels files
        label = .... # Load in etc
        label = self.transform(label)
        return image, label

现在您可以制作两个数据集(trainsettestset)。

trainset = CancerDataset(dataset_folder = '/content/drive/My Drive/cancer_data/', datatype='train')
testset = CancerDataset(dataset_folder = '/content/drive/My Drive/cancer_data/', datatype='test')
相关问题