Keras: handling threads and large datasets

Date: 2017-10-24 14:53:53

Tags: python multithreading dataset deep-learning keras

I am trying to handle a large training dataset with Keras.

I am using model.fit_generator with a custom generator that reads the data from a SQL file.

I get an error message telling me that I cannot use SQLite objects in two different threads:

ProgrammingError: SQLite objects created in a thread can only be used in that 
same thread.The object was created in thread id 140736714019776 and this is 
thread id 123145449209856
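For context, the error can be reproduced independently of Keras (a minimal sketch using only the standard library): a sqlite3 connection created in the main thread refuses to be used from a worker thread, because the connection's check_same_thread flag defaults to True.

```python
import sqlite3
import threading

# A connection created in the main thread...
conn = sqlite3.connect(':memory:')
conn.execute('create table data (id integer)')

errors = []

def worker():
    # ...raises ProgrammingError when used from a different thread,
    # because sqlite3 enforces check_same_thread=True by default.
    try:
        conn.execute('select id from data')
    except sqlite3.ProgrammingError as exc:
        errors.append(exc)

t = threading.Thread(target=worker)
t.start()
t.join()
print(len(errors))  # 1 -- the cross-thread call failed
conn.close()
```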

I tried the same thing with an HDF5 file and ran into a segmentation fault, which I now believe is also related to the multithreaded nature of fit_generator (see the bug report here).

What is the correct way to use such generators? A dataset that does not fit into memory has to be read from file in batches somehow.

Here is the code of the generator:

import pandas

class DataGenerator:
    def __init__(self, inputfile, batch_size, **kwargs):
        self.inputfile = inputfile
        self.batch_size = batch_size

    def generate(self, labels, idlist):
        while 1:
            for batch in self._read_data_from_hdf(idlist):
                batch = pandas.merge(batch, labels, how='left', on=['id'])
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)    

    def _read_data_from_hdf(self, idlist):
        chunklist = [idlist[i:i + self.batch_size] for i in range(0, len(idlist), self.batch_size)]
        for chunk in chunklist:
            yield pandas.read_hdf(self.inputfile, key='data', where='id in {}'.format(chunk))

# [...]

model.fit_generator(generator=training_generator,
                    steps_per_epoch=len(partitions['train']) // config['batch_size'],
                    validation_data=validation_generator,
                    validation_steps=len(partitions['validation']) // config['batch_size'],
                    epochs=config['epochs'])
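For reference, the list-comprehension slicing in `_read_data_from_hdf` simply splits `idlist` into consecutive batch-sized chunks, with the last chunk possibly shorter (illustrative values only; batch_size is 4 here for readability):

```python
idlist = list(range(10))
batch_size = 4

# Same slicing pattern as in _read_data_from_hdf: one slice per batch,
# stepping through the id list in batch_size increments.
chunklist = [idlist[i:i + batch_size]
             for i in range(0, len(idlist), batch_size)]
print(chunklist)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```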

See the full example repository here.

Thanks for your support.

Cheers,

1 answer:

Answer 0 (score: 1):

Facing the same problem, I worked out a solution by combining a thread-safe decorator with a sqlalchemy engine, which manages concurrent access to the database:

import threading

import pandas
from sqlalchemy import create_engine

class threadsafe_iter:
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def threadsafe_generator(f):
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


class DataGenerator:
    def __init__(self, inputfile, batch_size, **kwargs):
        self.inputfile = inputfile
        self.batch_size = batch_size
        self.sqlengine = create_engine('sqlite:///' + self.inputfile)

    def __del__(self):
        self.sqlengine.dispose()

    @threadsafe_generator
    def generate(self, labels, idlist):
        while 1:
            for batch in self._read_data_from_sql(idlist):
                Y = batch['label']
                X = batch.drop(['id', 'label'], axis=1)
                yield (X, Y)

    def _read_data_from_sql(self, idlist):
        chunklist = [idlist[i:i + self.batch_size]
                     for i in range(0, len(idlist), self.batch_size)]
        for chunk in chunklist:
            # Build the IN clause by hand: formatting a one-element tuple
            # would produce an invalid trailing comma, e.g. "in (5,)"
            query = 'select * from data where id in ({})'.format(
                ','.join(str(i) for i in chunk))
            df = pandas.read_sql(query, self.sqlengine)
            yield df

# Build keras model and instantiate generators

model.fit_generator(generator=training_generator,
                    steps_per_epoch=train_steps,
                    validation_data=validation_generator,
                    validation_steps=valid_steps,
                    epochs=10,
                    workers=4)
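The behaviour of the `threadsafe_iter` wrapper can be checked in isolation (a standalone sketch without Keras or a database, repeating the class from above so the snippet is self-contained): several threads pulling from one shared, wrapped iterator each receive distinct items, and nothing is lost or duplicated.

```python
import threading

class threadsafe_iter:
    """Serializes next() calls on a shared iterator with a lock."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

results = []

def consume(it):
    # Drain the shared iterator until it is exhausted.
    while True:
        try:
            results.append(next(it))
        except StopIteration:
            break

shared = threadsafe_iter(iter(range(1000)))
threads = [threading.Thread(target=consume, args=(shared,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every item was consumed exactly once across the four threads.
print(len(results), sum(results))  # 1000 499500
```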

I hope that helps!