带有PyTorch的多标签,多类别图像分类器(ConvNet)

时间:2018-06-22 06:12:31

标签: python classification pytorch convolutional-neural-network multilabel-classification

我正在尝试使用PyTorch实现图像分类器(CNN / ConvNet),在这里我想从csv文件中读取标签。我有4个不同的类别,一张图片可能属于多个类别。

我已经阅读了PyTorch Tutorialthis Stanford tutorialthis one,但都没有涵盖我的具体情况。我设法建立了print Hotel.getSortedHotels(sortKey=lambda x:x.hotel_name) 类的自定义函数,该函数仅对于从用于二进制分类器的csv文件读取标签有效。

这是我到目前为止torch.utils.data.Dataset类的代码(与上面链接的第三个教程稍有修改):

torch.utils.data.Dataset

具体来说,我正在尝试从具有以下结构的文件中读取标签:

CSV Data

我的具体问题是,我不知道如何在我的import torch import torchvision.transforms as transforms import torch.utils.data as data from PIL import Image import numpy as np import pandas as pd class MyCustomDataset(data.Dataset): # __init__ function is where the initial logic happens like reading a csv, # assigning transforms etc. def __init__(self, csv_path): # Transforms self.random_crop = transforms.RandomCrop(800) self.to_tensor = transforms.ToTensor() # Read the csv file self.data_info = pd.read_csv(csv_path, header=None) # First column contains the image paths self.image_arr = np.asarray(self.data_info.iloc[:, 0]) # Second column is the labels self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # Calculate len self.data_len = len(self.data_info.index) # __getitem__ function returns the data and labels. This function is # called from dataloader like this def __getitem__(self, index): # Get image name from the pandas df single_image_name = self.image_arr[index] # Open image img_as_img = Image.open(single_image_name) img_cropped = self.random_crop(img_as_img) img_as_tensor = self.to_tensor(img_cropped) # Get label(class) of the image based on the cropped pandas column single_image_label = self.label_arr[index] return (img_as_tensor, single_image_label) def __len__(self): return self.data_len 类中实现它。我想我在csv中的标签的(手动)分配与PyTorch如何读取它们之间缺少联系,因为我对框架不是很熟悉。
我非常感谢您提供有关如何使其正常工作的帮助,或者,如果确实有涉及此方面的示例,那么也将非常感谢您的链接!

1 个答案:

答案 0 :(得分:0)

也许我遗漏了一些东西,但是如果您想将1..N(此处为N = 4的列转换为标签矢量或形状为(N,)的标签(例如,给定示例数据,{{ 1}},label(img1) = [0, 0, 0, 1],...),为什么不这样做:

  1. 将所有标签列读入label(img3) = [1, 0, 1, 0]

    self.label_arr
  2. 相应地返回self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # columns 1 to N 中的标签(此处不变):

    __getitem__()

要训练您的分类器,您可以计算例如single_image_label = self.label_arr[index] 预测与目标标签之间的交叉熵。