High training accuracy from the start, but validation accuracy stays low and does not change

Date: 2019-07-01 16:31:10

Tags: tensorflow

I tried to write my first TensorFlow program, but it gets very high training accuracy while the validation accuracy is low and does not change. Can anyone help me?

I am trying to classify 4 kinds of images with an Inception V3 network, using the InceptionV3 model from keras.applications.

import math

import tensorflow as tf

train_files = tf.data.Dataset.list_files(r"\train\train*", seed=36)
validation_files = tf.data.Dataset.list_files(r"\test\test*", seed=8)
standard_image_shape = tf.stack([299, 299, 3])
batch_size = 64
num_classes = 4
epochs = 2000
train_steps_per_epoch = 6500
vali_steps_per_epoch = 735

def train_preprocess(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image=image, max_delta=32/255)
    image = tf.image.random_contrast(image, 0.5, 1.5)
    image = tf.image.random_hue(image, 0.2)
    return image

def decode_example(example_proto):
    image_feature_description = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/colorspace': tf.io.FixedLenFeature([], tf.string),
    'image/channels': tf.io.FixedLenFeature([], tf.int64),
    'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    'image/class/raw': tf.io.FixedLenFeature([], tf.int64),
    'image/class/source': tf.io.FixedLenFeature([], tf.int64),
    'image/class/text': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/id': tf.io.FixedLenFeature([], tf.int64),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),

    }
    parsed_features = tf.io.parse_single_example(example_proto, image_feature_description)
    height = tf.cast(parsed_features['image/height'], tf.int32)
    width = tf.cast(parsed_features['image/width'], tf.int32)
    channels = tf.cast(parsed_features['image/channels'], tf.int32)
    label = tf.cast(parsed_features['image/class/label'], tf.int32)
    image_buffer = parsed_features['image/encoded']
    image = tf.io.decode_jpeg(image_buffer, channels=3)
    image = tf.cast(image, tf.float32) / 255
    image = tf.image.resize(image, [299, 299])
    return image, label

def decode_train_example(example_proto):
    image_feature_description = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/colorspace': tf.io.FixedLenFeature([], tf.string),
    'image/channels': tf.io.FixedLenFeature([], tf.int64),
    'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    'image/class/raw': tf.io.FixedLenFeature([], tf.int64),
    'image/class/source': tf.io.FixedLenFeature([], tf.int64),
    'image/class/text': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/id': tf.io.FixedLenFeature([], tf.int64),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),

    }
    parsed_features = tf.io.parse_single_example(example_proto, image_feature_description)
    height = tf.cast(parsed_features['image/height'], tf.int32)
    width = tf.cast(parsed_features['image/width'], tf.int32)
    channels = tf.cast(parsed_features['image/channels'], tf.int32)
    label = tf.cast(parsed_features['image/class/label'], tf.int32)
    image_buffer = parsed_features['image/encoded']
    image = tf.io.decode_jpeg(image_buffer, channels=3)
    image = tf.cast(image, tf.float32) / 255
    image = tf.image.central_crop(image, 0.7)
    image = tf.image.resize(image, [299, 299])
    image = train_preprocess(image)
    return image, label

def processed_dataset(filenames):
    dataset = filenames.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=46)
    dataset = dataset.map(decode_example)
    dataset = dataset.shuffle(buffer_size=20000, seed=6)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return dataset

def processed_train_dataset(filenames):
    dataset = filenames.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=407)
    dataset = dataset.map(decode_train_example)
    dataset = dataset.shuffle(buffer_size=40000, seed=6)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return dataset

def data_generator(dataset):
    iter = tf.compat.v1.data.make_one_shot_iterator(dataset)
    batch = iter.get_next()
    while True:
        yield tuple(batch)

train_data = processed_train_dataset(train_files)
validation_data = processed_dataset(validation_files)
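
To check that the input pipeline itself behaves, one way is to pull a single batch and inspect it (a minimal debugging sketch, assuming TF 2.x eager execution; these prints are not part of the training code):

# Debugging only: take one batch from each dataset and inspect shapes and label values.
# Labels should lie in [0, 3], since sparse_categorical_crossentropy with 4 classes
# expects integer labels in the range 0..num_classes-1.
for images, labels in train_data.take(1):
    print(images.shape)                                 # expected: (64, 299, 299, 3)
    print(labels.numpy().min(), labels.numpy().max())   # expected range: 0..3

for images, labels in validation_data.take(1):
    print(images.shape)
    print(labels.numpy()[:10])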

model = tf.keras.applications.inception_v3.InceptionV3(include_top=True, weights=None,
    input_tensor=None, input_shape=None, pooling=None, classes=4)

def step_decay(epoch):
    initial_lrate = 0.015
    drop = 0.94
    epochs_drop = 2.0
    lrate = initial_lrate * math.pow(drop,  
       math.floor((1+epoch)/epochs_drop))
    return lrate
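
For reference, this schedule drops the learning rate by a factor of 0.94 every two epochs; a quick check of the first few values (a small sketch using the step_decay above):

# Quick check of the schedule: prints the learning rate used for epochs 0-5.
# epoch 0 -> 0.015, epochs 1-2 -> 0.0141, epochs 3-4 -> ~0.01325, epoch 5 -> ~0.01246
for epoch in range(6):
    print(epoch, step_decay(epoch))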

model.compile(loss='sparse_categorical_crossentropy',
          optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.015, momentum=0.9, epsilon=0.1, decay=0.9, rho=0.9),
          metrics=['accuracy'])

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("classification_model_crop.h5",
                                                   save_best_only=True)

early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience = 10,
                                                  restore_best_weights = True)

lrate = tf.keras.callbacks.LearningRateScheduler(step_decay)

history = model.fit_generator(
    data_generator(train_data),
    steps_per_epoch=train_steps_per_epoch,
    epochs=epochs,
    verbose=1,
    callbacks=[lrate, checkpoint_cb, early_stopping_cb],
    validation_data=data_generator(validation_data),
    validation_steps=vali_steps_per_epoch,
    workers = 0  # runs generator on the main thread
)
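
Since tf.keras can also consume a tf.data.Dataset directly in model.fit, I believe the same training loop could be written without the generator wrapper (a sketch under that assumption, TF 2.x):

# Alternative (sketch): pass the datasets to fit() directly instead of wrapping
# them in a Python generator; the rest of the arguments stay the same.
history = model.fit(
    train_data,
    steps_per_epoch=train_steps_per_epoch,
    epochs=epochs,
    verbose=1,
    callbacks=[lrate, checkpoint_cb, early_stopping_cb],
    validation_data=validation_data,
    validation_steps=vali_steps_per_epoch,
)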

It runs, but the results are very bad:

Epoch 1/2000
6500/6500 [==============================] - 5546s 853ms/step - loss: 0.5267 - accuracy: 0.9660 - val_loss: 2.7306 - val_accuracy: 0.4531
Epoch 2/2000
6500/6500 [==============================] - 3088s 475ms/step - loss: 0.5115 - accuracy: 0.9688 - val_loss: 2.7646 - val_accuracy: 0.4531
Epoch 3/2000
6500/6500 [==============================] - 3093s 476ms/step - loss: 0.5106 - accuracy: 0.9688 - val_loss: 2.7780 - val_accuracy: 0.4531
Epoch 4/2000
6500/6500 [==============================] - 3086s 475ms/step - loss: 0.5101 - accuracy: 0.9688 - val_loss: 2.7845 - val_accuracy: 0.4531
Epoch 5/2000
6500/6500 [==============================] - 3088s 475ms/step - loss: 0.5099 - accuracy: 0.9688 - val_loss: 2.7887 - val_accuracy: 0.4531
Epoch 6/2000
6500/6500 [==============================] - 3087s 475ms/step - loss: 0.5097 - accuracy: 0.9688 - val_loss: 2.7913 - val_accuracy: 0.4531
Epoch 7/2000
6500/6500 [==============================] - 3089s 475ms/step - loss: 0.5096 - accuracy: 0.9688 - val_loss: 2.7934 - val_accuracy: 0.4531
Epoch 8/2000
6500/6500 [==============================] - 3087s 475ms/step - loss: 0.5095 - accuracy: 0.9688 - val_loss: 2.7948 - val_accuracy: 0.4531
Epoch 9/2000
6500/6500 [==============================] - 3090s 475ms/step - loss: 0.5094 - accuracy: 0.9688 - val_loss: 2.7961 - val_accuracy: 0.4531
Epoch 10/2000
6500/6500 [==============================] - 3086s 475ms/step - loss: 0.5094 - accuracy: 0.9688 - val_loss: 2.7971 - val_accuracy: 0.4531
Epoch 11/2000
6500/6500 [==============================] - 3085s 475ms/step - loss: 0.5093 - accuracy: 0.9688 - val_loss: 2.7978 - val_accuracy: 0.4531

0 Answers:

There are no answers yet.