使用Keras的fit_generator

时间:2019-04-11 11:17:47

标签: python keras

问题摘要:

我正在尝试针对二进制分类任务的基本ANN模型。我可以说大数据总共2 GB 150个csv文件组成。数据由6个功能和1个目标组成。

重要说明::这是一个二进制分类任务,每个文件仅包含一个标签。例如。 file_1仅包含标签0,file_2仅包含标签1。

问题1:我使用Keras的fit_generator方法逐文件读取数据文件,并逐批读取数据。我开始训练模型,但是在每次训练结束时模型给出不同的结果。此外,有时准确性会随着时间而下降。我认为这是因为eacy文件仅包含一个标签。

问题2:我不确定我是否正确编写了data_generator方法。我需要从不同的CSV文件中获取数据。任何建议都将不胜感激。

某些代码

简单的ANN模型:

def create_model():
    model = Sequential()

    model.add(Dense(32, kernel_initializer='normal',
                    activation='relu', input_dim=(6)))
    model.add(Dropout(0.5))
    model.add(Dense(16, kernel_initializer='normal', activation='relu'))
    model.add(Dense(8, kernel_initializer='normal', activation='relu'))
    model.add(Dense(16, kernel_initializer='normal', activation='relu'))
    model.add(Dense(32, kernel_initializer='normal', activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))

    model.compile(optimizer='adam', loss="binary_crossentropy",
                  metrics=['accuracy'])
    return model

数据生成器: 我正在尝试从其他CSV文件生成数据

def data_generotto(path: str, batchsize: int):
    while True:
        for csv_file in os.listdir(path):
            chunks = pd.read_csv(os.path.join(
                path, csv_file), sep=';', chunksize=batchsize)

            for i, chunk in enumerate(chunks):
                X, y = preprocess.preprocess(chunk)

                yield (X, y)

用于获取数据总大小的代码:

def get_total_size(path: str):
    for csv_file in os.listdir(path):
        global SIZE
        with open(os.path.join(path, csv_file)) as f:
            for line in f:
                SIZE += 1

            SIZE -= 1 # minus header line

主程序流程:

np.random.seed(7)

SIZE = 0
BS = 1000
EPOCHS = 5

if __name__ == "__main__":
    model = cnn.create_model()

    get_total_size("./complete_csv")
    print("size calculated")

    H = model.fit_generator(data_generotto(
        "./complete_csv", BS), steps_per_epoch=SIZE // BS, epochs=EPOCHS, workers=-1)

    save_model(model, "./MODEL.h5")

1 个答案:

答案 0 :(得分:0)

对不起,我对您的问题有误解。现在,我对您的任务有了一些想法:

  1. 神经网络很好地解决了高维问题,但是您的数据只有6个特征,对于神经网络来说太短了。
  2. 也许您可以尝试一些机器学习方法,例如决策树,SVM和一些提升方法。我认为这些方法将适合您的任务。