ImageDataGenerator的替代方案,用于自定义数据集

时间:2018-08-16 23:27:53

标签: python csv keras annotations computer-vision

以下是我的csv文件

file,pt1,pt2,pt3,,pt4,pt5,pt6
object/obj0.png,66.0335639098,39.0022736842,30.2270075188,36.4216781955,59.582075188,39.6474225564
object/obj0.png,66.0335639098,39.0022736842,30.2270075188,36.4216781955,59.582075188,39.6474225564
object/obj0.png,66.0335639098,39.0022736842,30.2270075188,36.4216781955,59.582075188,39.6474225564

我如何加载这些图像和注释来训练我的简单cnn?

我尝试如下使用'ImagedataGenerator',但这没有帮助...还有其他选择吗?

train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

1 个答案:

答案 0 :(得分:2)

ImageDataGenerator对象允许从numpy arrays或直接从目录产生数据。在后一种情况下,标签是从数据的文件夹结构自动推断出来的:每类图像应位于单独的文件夹中。只要标签结构更复杂(如您的情况),您就可以选择编写自己的自定义生成器。如果这样做,请使用Keras' Sequence object,这样可以进行安全的多重处理。 Keras网站包含一个样板示例。就您而言,您的代码应如下所示:

from keras.utils import Sequence
from keras.preprocessing.image import load_img
import pandas as pd
import random 

class DataSequence(Sequence):

    def __init__(self, csv_path, batch_size, mode='train'):
        self.df = pd.read_csv(csv_path) # read your csv file with pandas
        self.bsz = batch_size # batch size
        self.mode = mode # shuffle when in train mode

        # Take labels and a list of image locations in memory
        self.labels = self.df[['pt1', 'pt2', 'pt3', 'pt4', 'pt5', 'pt6']].values
        self.im_list = self.df['file'].tolist()

    def __len__(self):
        # compute number of batches to yield
        return int(math.ceil(len(self.df) / float(self.bsz)))

    def on_epoch_end(self):
        # Shuffles indexes after each epoch if in training mode
        self.indexes = range(len(self.im_list))
        if self.mode == 'train':
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return self.labels[idx * self.bsz: (idx + 1) * self.bsz,:]

    def get_batch_features(self, idx):
        # Fetch a batch of inputs
        return np.array([load_img(im) for im in self.im_list[idx * self.bsz: (1 + idx) * self.bsz]])

    def __getitem__(self, idx):
        batch_x = self.get_batch_features(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y

您可以使用此Sequence对象通过model.fit_generator()来训练模型:

sequence = DataSequence('./path_to/csv_file.csv', batch_size)
model.fit_generator(sequence, epochs=1, use_multiprocessing=True)

另请参阅this related question

相关问题