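Keras fit and fit_generator return completely different results

Time: 2018-10-03 17:26:20

Tags: debugging keras

Keras fit and fit_generator give me completely different results: accuracy is almost 20% lower with fit_generator. I do shuffle in my data generator. I have attached my data_generator below. Thanks!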

    def data_generator(input_x, input_y, batch_size = BATCH_SIZE):
        loopcount = len(input_x) // batch_size
        while True:
            i = random.randint(0, loopcount - 1)
            x_batch = input_x[i*batch_size:(i+1)*batch_size]
            y_batch = input_y[i*batch_size:(i+1)*batch_size]
            yield x_batch, y_batch

My model.fit_generator call looks like this:

    model.fit_generator(generator = data_generator(x_train, y_train, batch_size = BATCH_SIZE),
                        steps_per_epoch = len(x_train) // BATCH_SIZE,
                        epochs = 20,
                        validation_data = data_generator(x_val, y_val, batch_size = BATCH_SIZE),
                        validation_steps = len(x_val) // BATCH_SIZE)

2 answers:

Answer 0: (score: 0)

Note: this is not a solution but a similar problem. I think the two issues are related, and solving one may solve the other.

fit_generator is not working properly

Answer 1: (score: 0)

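Your generator appears to pick batches at random, so within a single epoch some batches may be used more than once while others are never used at all.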
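To get a feel for how uneven that sampling is, here is a small simulation. It is only a sketch, not Keras code: `simulate_epoch` is a hypothetical helper, and `loopcount = 100` and the seed are arbitrary values I picked. It draws batch indices the same way the generator in the question does and counts which batches were skipped or repeated in one epoch.

```python
import random
from collections import Counter

def simulate_epoch(loopcount, seed=0):
    """Draw `loopcount` batch indices with replacement, like the question's generator."""
    rng = random.Random(seed)
    draws = [rng.randint(0, loopcount - 1) for _ in range(loopcount)]
    return Counter(draws)

counts = simulate_epoch(loopcount=100)
unused = 100 - len(counts)                           # batches never drawn this epoch
repeated = sum(1 for c in counts.values() if c > 1)  # batches drawn at least twice

print("batches never used:", unused)
print("batches used more than once:", repeated)
```

With this kind of sampling, a bit over a third of the batches (about 1/e) are expected to go unused in any given epoch.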

To avoid this, you could do something like the following (untested):

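    import random

    def data_generator(input_x, input_y, batch_size):
        loopcount = len(input_x) // batch_size
        # range() cannot be shuffled in place in Python 3, so build a list
        batches = list(range(loopcount))
        while True:
            # reshuffle at the start of every epoch, then use each batch exactly once
            random.shuffle(batches)
            for b in batches:
                x_batch = input_x[b*batch_size:(b+1)*batch_size]
                y_batch = input_y[b*batch_size:(b+1)*batch_size]
                yield x_batch, y_batch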

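However, if your data size is not a multiple of batch_size, you will lose the leftover samples at the end. You can code around this, though, and return a final batch that is smaller than batch_size.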
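One way to keep the leftover samples is to shuffle sample indices rather than batch numbers. The following is a minimal sketch, assuming `input_x` and `input_y` are NumPy arrays of equal length; the name `data_generator_full`, the `seed` parameter, and the toy data are my own and are not part of Keras.

```python
import numpy as np

def data_generator_full(input_x, input_y, batch_size, seed=None):
    """Yield shuffled batches forever; the last batch of an epoch may be smaller."""
    rng = np.random.RandomState(seed)
    n = len(input_x)
    while True:
        # One permutation per epoch, applied to x and y alike so pairs stay aligned.
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield input_x[idx], input_y[idx]

# Toy data (hypothetical): 10 samples with batch size 4 give batches of 4, 4 and 2.
x = np.arange(10).reshape(10, 1)
y = np.arange(10)
gen = data_generator_full(x, y, batch_size=4, seed=0)
steps = int(np.ceil(len(x) / 4))
sizes = [len(next(gen)[0]) for _ in range(steps)]
print(sizes)  # [4, 4, 2]
```

With a generator like this, steps_per_epoch has to be rounded up instead of down so that the short final batch is consumed within the epoch.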

Alternatively, you can use an Iterator:

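    # In Keras 2.x the Iterator base class lives in keras.preprocessing.image
    from keras.preprocessing.image import Iterator

    class MyIterator(Iterator):
        def __init__(self, x, y, batch_size, shuffle=True, seed=None):
            self.x = x
            self.y = y
            # the base class builds (and shuffles) the index array for each epoch
            super(MyIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

        def _get_batches_of_transformed_samples(self, index_array):
            # x and y must be NumPy arrays so they can be indexed by an index array
            return self.x[index_array], self.y[index_array]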

Then start training like this:

    train_iterator = MyIterator(x_train, y_train, BATCH_SIZE)
    val_iterator = MyIterator(x_val, y_val, BATCH_SIZE)
    model.fit_generator(generator=train_iterator,
                        steps_per_epoch=len(train_iterator),
                        validation_data=val_iterator,
                        validation_steps=len(val_iterator),
                        epochs=20)

The iterator takes care of the leftover samples for you, effectively creating a final batch that is smaller than batch_size.
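To make the step counts concrete, here is a small arithmetic sketch. It assumes the iterator's length is the number of samples divided by batch_size, rounded up, which is my understanding of how the Keras Iterator behaves; `batch_sizes` is a hypothetical helper and the numbers are made up.

```python
import math

def batch_sizes(n_samples, batch_size, keep_remainder=True):
    """Sizes of the batches produced in one pass over n_samples."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if keep_remainder and rest:
        sizes.append(rest)  # the short final batch
    return sizes

n, bs = 1050, 100  # made-up numbers

# Rounding down, as len(x_train) // BATCH_SIZE does: 10 steps, 50 samples dropped.
print(len(batch_sizes(n, bs, keep_remainder=False)))  # 10

# Rounding up: 11 steps, the last one holding the remaining 50 samples.
print(len(batch_sizes(n, bs)), math.ceil(n / bs))     # 11 11
print(batch_sizes(n, bs)[-1])                         # 50
```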

Edit: following the discussion at https://github.com/keras-team/keras/issues/2389, it is very important to shuffle the data yourself when writing a custom generator.