使用生成器导致内存错误不足

时间:2017-03-30 19:21:32

标签: python generator keras

我在python中有以下生成器。它接收的样本是一个由x行组成的数组,其中每行看起来像这样

['some\\path\\center.jpg', 'some\\path\\left.jpg', 'some\\path\\right.jpg', 'someNumber']

def generator(samples, batch_size):
num_samples = len(samples)
while 1: # Loop forever so the generator never terminates
    shuffle(samples)
    for offset in range(0, num_samples, batch_size):
        batch_samples = samples[offset:offset+batch_size]

        for batch_sample in batch_samples:

            source_path = np.random.choice([batch_sample[0], batch_sample[1], batch_sample[2]])
            filename = source_path.split(os.sep)[-1]
            current_path = 'data20/IMG/' + filename
            current_image = cv2.imread(current_path)
            current_angle = float(batch_sample[3])

            if source_path == line[1]:
                current_angle += 0.2

            elif source_path == line[2]:
                current_angle -= 0.2

            images.append(current_image)
            angles.append(current_angle)

        # trim image to only see section with road
        X_train = np.array(images)
        y_train = np.array(angles)
        yield (X_train, y_train)

然后我有一些CNN,我使用具有给定功能的发生器训练网络

model.fit_generator(train_generator, samples_per_epoch=len(train_samples), validation_data=validation_generator, nb_val_samples=len(validation_samples), nb_epoch=5)

我的问题:为什么培训期间的时间会增加,因此需要更长的时间...导致批量大小的内存不足让我们说32?

1 个答案:

答案 0 :(得分:0)

如Dref360所述,我不清楚发生器内的图像和角度。主要问题是我把

images = []
angles = []

在我的发电机之外,这导致每个周期都在增长,它们应该在我的发电机内部。以下是工作代码:

['some\\path\\center.jpg', 'some\\path\\left.jpg', 'some\\path\\right.jpg', 'someNumber']

def generator(samples, batch_size):
num_samples = len(samples)
while 1: # Loop forever so the generator never terminates
    shuffle(samples)
    for offset in range(0, num_samples, batch_size):
        batch_samples = samples[offset:offset+batch_size]

        images = []
        angles = []
        for batch_sample in batch_samples:

            source_path = np.random.choice([batch_sample[0], batch_sample[1], batch_sample[2]])
            filename = source_path.split(os.sep)[-1]
            current_path = 'data20/IMG/' + filename
            current_image = cv2.imread(current_path)
            current_angle = float(batch_sample[3])

            if source_path == line[1]:
                current_angle += 0.2

            elif source_path == line[2]:
                current_angle -= 0.2

            images.append(current_image)
            angles.append(current_angle)

        # trim image to only see section with road
        X_train = np.array(images)
        y_train = np.array(angles)
        yield (X_train, y_train)

非常感谢 Dref360 :D

相关问题