Training a Keras model with fit_generator and TFRecords

Date: 2019-03-19 14:11:56

Tags: python tensorflow keras conv-neural-network

I want to train my ConvNet with Keras. After following some tutorials, I wrote the code below.

I am not sure it is correct; in particular, I have some doubts about the way I feed the training with a generator.

Previously I fed the training with a generator of numpy arrays, but I read that TFRecords can be used to improve performance.

At first, in the create_dataset function below, I converted the tensors to numpy arrays before yielding them, but then I read that:

  there is indeed a more efficient way to use a Dataset, without converting the tensors to numpy arrays.

So I tried to edit my code this way:

    input_image = tf.keras.Input(tensor=x)
    model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y])

Before that, I used neither target_tensors in model.compile nor tensor=x in tf.keras.Input (I only specified the input shape).

import tensorflow as tf
import keras
import compute_loss #my loss function

dataset_train_path="dataset_train.tfrecords"
dataset_val_path="dataset_val.tfrecords"

filepath_checkpoint="weights-best.hdf5"

# callbacks_list is referenced in fit_generator below but was missing from the
# snippet; presumably a checkpoint callback on filepath_checkpoint, e.g.:
callbacks_list = [tf.keras.callbacks.ModelCheckpoint(filepath_checkpoint, save_best_only=True)]


Adam=tf.keras.optimizers.Adam
optimizer = Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)

BATCH_SIZE=32
TRAINING_SIZE=5717
VALIDATION_SIZE=5823
TRAINING_STEPS=TRAINING_SIZE//BATCH_SIZE
VALIDATION_STEPS=VALIDATION_SIZE//BATCH_SIZE

"""-----------------Here I define my generator-----------------"""
def _parse_function(proto):

    keys_to_features = {'image': tf.FixedLenFeature([], tf.string),
                        'label': tf.FixedLenFeature([], tf.string)}


    parsed_features = tf.parse_single_example(proto, keys_to_features)

    parsed_features['image'] = tf.decode_raw(parsed_features['image'], tf.float16)
    parsed_features['label'] = tf.decode_raw(parsed_features['label'], tf.float16)
    return parsed_features['image'], parsed_features["label"]


def create_dataset(filepath, batch_size=BATCH_SIZE):

    dataset = tf.data.TFRecordDataset(filepath)

    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(100)
    dataset = dataset.batch(batch_size)  # use the parameter, not the global

    iterator = dataset.make_one_shot_iterator()
    image, label = iterator.get_next()

    image = tf.reshape(image, [batch_size, 416, 416, 3])
    label = tf.reshape(label, [batch_size, 75, 25])

    while True:
        yield image, label

"""-----------------Here I create my train/val generators-----------------"""
training_generator=create_dataset(dataset_train_path)
validation_generator=create_dataset(dataset_val_path)

"""-----------------Now I can define my model-----------------"""
x, y = next(training_generator)
def net():
    input_image=tf.keras.Input(tensor=x)
    inputs=tf.keras.layers.Conv2D(16,3,padding='same', activation='relu', name='conv_1')(input_image)
    inputs=tf.keras.layers.BatchNormalization(name='norm_1')(inputs)
    ...
    ...
    outputs = tf.keras.layers.Conv2D(75, 1, name='conv_13')(inputs)
    model = tf.keras.Model(inputs=input_image, outputs=outputs)
    return model

if __name__ == '__main__':
    model=net()
    model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y])
    model.fit_generator(generator=training_generator,
                        validation_data=validation_generator,
                        epochs=1000,
                        max_queue_size=1000,
                        steps_per_epoch=TRAINING_STEPS,
                        validation_steps=VALIDATION_STEPS,
                        callbacks=callbacks_list)
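
One thing worth noting about the generator above: the while True: yield image, label loop hands Keras the same two symbolic graph tensors on every call to next(), so any per-batch variation has to come from the tf.data pipeline inside the graph, not from the Python generator itself. A TensorFlow-free sketch of that pattern, with placeholder objects standing in for the symbolic tensors:

```python
def make_gen():
    # Built once, like iterator.get_next() above; these placeholder
    # objects stand in for the symbolic image/label tensors.
    image = object()
    label = object()
    while True:
        # The very same objects are yielded on every iteration.
        yield image, label

gen = make_gen()
img1, lbl1 = next(gen)
img2, lbl2 = next(gen)
print(img1 is img2, lbl1 is lbl2)  # True True
```

This is why a generator built this way only works together with target_tensors: the actual data flow happens inside the TensorFlow graph.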

Now training runs fast, but I suspect there is an error somewhere. Can you help me?

Edit: if I pass the dataset iterator directly to fit_generator, I get the following:

>>> train(model, DataGenerator, filepath_checkpoint="weights-best-tiny-test.hdf5")
Epoch 1/5000
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "tiny.py", line 239, in train
    model.fit_generator(generator=training_generator,validation_data=validation_generator, epochs=5000, max_queue_size=100, steps_per_epoch=TRAINING_STEPS, validation_steps=VALIDATION_STEPS, callbacks=callbacks_list)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1586, in fit_generator
    steps_name='steps_per_epoch')
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 211, in model_iteration
    batch_data = _get_next_batch(output_generator, mode)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 323, in _get_next_batch
    generator_output = next(output_generator)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 767, in get
    six.reraise(*sys.exc_info())
  File "C:\Program Files\Python35\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 743, in get
    inputs = self.queue.get(block=True).get()
  File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 644, in get
    raise self._value
  File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 680, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
TypeError: 'Iterator' object is not an iterator
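
The TypeError at the bottom seems to come from Keras's generator queue calling Python's built-in next() on the object it is given: the v1 tf.data Iterator exposes get_next() but does not implement the Python iterator protocol (__next__). A minimal TensorFlow-free reproduction, using a hypothetical stand-in class:

```python
class GraphIterator:
    """Stand-in for tf.data's (v1) Iterator: offers get_next() only,
    no __iter__/__next__, so Python's next() rejects it."""
    def get_next(self):
        return "symbolic tensor"

it = GraphIterator()
try:
    next(it)  # what Keras's queue effectively does with the iterator
except TypeError as err:
    print(err)  # 'GraphIterator' object is not an iterator
```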

0 Answers