Keras predict_generator with corrupted images

Time: 2018-10-11 14:57:45

Tags: python keras

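I am trying to use a trained model to predict millions of images with predict_generator, in Python 3 with Keras on the TensorFlow backend. The generator and the model predictions work, but some images in the directory are broken or corrupted, which makes predict_generator stop and raise an error. Once I delete the offending image it works again, until the next broken or corrupted image is fed through the function.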

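Because there are so many images, running a script that opens every image and deletes the ones that raise an error is not feasible. Is there a way to build a "skip image if broken" argument into the generator or the flow_from_directory function?
Any help would be greatly appreciated!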

2 answers:

Answer 0 (score: 0)

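There is no such argument in ImageDataGenerator, nor in its flow_from_directory method, as the Keras documentation for both shows. One workaround is to extend the ImageDataGenerator class and override the flow_from_directory method so that it checks whether an image is corrupted before the generator yields it. The source code of the class is on GitHub.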
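
The per-image check inside such an override does not have to decode the whole file. As a minimal sketch, assuming the images are JPEG or PNG, the function below (the name looks_intact is hypothetical, not part of Keras) reads only the first and last 8 bytes and compares them with the format's start and end markers. This is a heuristic: it catches truncated or mislabeled files cheaply, but it will miss corruption in the middle of a file and will reject a valid JPEG that carries trailing bytes after its end marker.

```python
import os

JPEG_START = b"\xff\xd8"          # JPEG start-of-image marker
JPEG_END = b"\xff\xd9"            # JPEG end-of-image marker
PNG_START = b"\x89PNG\r\n\x1a\n"  # 8-byte PNG signature
PNG_END = b"IEND\xaeB`\x82"       # IEND chunk type followed by its CRC


def looks_intact(path):
    """Heuristic check: True if the file starts and ends like a JPEG or PNG.

    Hypothetical helper for illustration. It does not decode the image, so a
    True result is not a guarantee that the image can be loaded.
    """
    try:
        if os.path.getsize(path) < 16:
            return False
        with open(path, "rb") as f:
            head = f.read(8)
            f.seek(-8, os.SEEK_END)
            tail = f.read(8)
    except OSError:
        # Unreadable or missing files are treated as corrupted.
        return False

    if head.startswith(JPEG_START):
        return tail.endswith(JPEG_END)
    if head == PNG_START:
        return tail == PNG_END
    # Any other format is rejected under the JPEG/PNG assumption.
    return False
```

Because only 16 bytes per file are read, a filter such as `[p for p in paths if looks_intact(p)]` is far cheaper than opening every image, at the cost of the blind spots listed above.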

Answer 1 (score: 0)

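Because this happens during prediction, skipping any image or batch means you have to keep track of what was skipped, so that the prediction scores can still be mapped to the correct image file names.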

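Based on this idea, my DataGenerator is implemented with a tracker of valid image indices. Pay particular attention to the variable valid_index, a dict that maps each batch number to the DataFrame indices of the images that loaded successfully.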

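import keras
import numpy as np
# Assumption: resnet50_preprocess is the ResNet50 preprocess_input shipped with Keras.
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess


class DataGenerator(keras.utils.Sequence):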
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ',idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]

        # return a list whose element is either an image array (when image is valid) or None(when image is corrupted)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])

        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist()) if
               u is not None]

        # Boundary case: every image in this batch failed to load.
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # Returning None relies on the Keras enqueuer skipping None batches, see
            # https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621
            return None

        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)

        x, batch_index = zip(*tmp) 
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index

        # Apply the preprocessing function provided by Keras.
        # np.float was removed in NumPy 1.24; np.float64 is the same type.
        x = resnet50_preprocess(np.array(x, dtype=np.float64))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0 # reset count after one epoch ends.
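
The class above calls load_batch_image_to_arrays without showing it. The following is only a sketch of what such a helper might look like, not the author's implementation: it assumes Pillow is used for decoding, that images are resized to 224x224 RGB (the usual ResNet50 input size), and that any exception raised while loading means the file is corrupted. The load_fn parameter and the _load_with_pillow name are hypothetical additions that make the decoder swappable.

```python
def _load_with_pillow(path, target_size=(224, 224)):
    """Decode one image into an RGB array (assumed loader, not from the answer)."""
    # Imported lazily so that the helper below can be used with another loader
    # even when Pillow and NumPy are not installed.
    import numpy as np
    from PIL import Image

    with Image.open(path) as img:
        return np.asarray(img.convert("RGB").resize(target_size))


def load_batch_image_to_arrays(paths, load_fn=_load_with_pillow):
    """Return one entry per path: the decoded image, or None if loading failed."""
    arrays = []
    for path in paths:
        try:
            arrays.append(load_fn(path))
        except Exception as exc:
            # Broad on purpose: decoders report corrupted files with several
            # exception types, and one bad file should not stop the batch.
            print('[WARN] could not load {}: {}'.format(path, exc))
            arrays.append(None)
    return arrays
```

Keeping a None placeholder for each failed image, rather than dropping it inside the helper, is what lets __getitem__ zip the results against batch_df.index and record exactly which rows survived.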

During prediction:

predictions = model.predict_generator(
    generator=data_gen,
    workers=10,
    use_multiprocessing=False,
    max_queue_size=20,
    verbose=1
).squeeze()
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
result_df = df.loc[indexes]
result_df['score'] = predictions