使用predict_generator和VGG16的内存错误

时间:2018-07-23 13:48:29

标签: python memory machine-learning keras computer-vision

我正在尝试在我自己的数据集上应用转移学习,该数据集存在于33.000个训练图像中(共1,4GB)。在Keras(2.2.0)中使用predict_generator进行预测时,遇到了内存错误。当查看我的任务管理器时,我可以看到内存正在缓慢工作,直到我的Tesla K80(1GPU)的5GB最大VRAM。我正在使用以下代码:

#Train
print('train dataset:')
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
train_generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=None,
    shuffle=False)

num_classes = len(train_generator.class_indices)
nb_train_samples = len(train_generator.filenames)
predict_size_train = int(math.ceil(nb_train_samples / batch_size))
VGG16_bottleneck_features_train = model.predict_generator(train_generator, predict_size_train, verbose=1)
np.save('XVGG16_bottleneck_features_train.npy', VGG16_bottleneck_features_train)

我尝试了很多事情,但似乎无法使其成功。我已经阅读了许多建议使用批处理的解决方案,但我认为我的预报生成器已经在接收批处理形式的数据了吗?这里是否有人可以验证这对于我的系统不起作用,还是有其他解决方案?

1 个答案:

答案 0 :(得分:0)

推理所需的内存与您尝试同时通过网络的图像数量(即批大小)成比例。您可以尝试较小的批量,直到运行为止。批次大小越小,生成器将需要产生更多的批次来传递整个数据集(代码中的predict_size_train)。

相关问题