Loading a pretrained model to continue training after changing the optimizer

Date: 2018-11-15 07:12:58

Tags: python tensorflow optimization

In short: when restoring a pretrained model, I want to switch the optimizer to AdamOptimizer for further training. To my surprise, however, restoring fails with the error shown below:


NotFoundError (see above for traceback): Key beta1_power not found in checkpoint

 [[Node: save_1/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2/tensor_names, save_1/RestoreV2/shape_and_slices)]]

I had simply assumed that the corresponding variables would be added to the computation graph automatically, the way tf.get_variable does, without any manual intervention.
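The likely cause: a plain tf.train.Saver() tries to restore every variable in the current graph, and switching to AdamOptimizer adds bookkeeping variables (beta1_power, beta2_power, and per-variable Adam slots) that the stage-1 MomentumOptimizer checkpoint never contained. This can be illustrated in plain Python with hypothetical variable names (not the question's real graph):

```python
# Illustrative sketch, no TensorFlow required: why RestoreV2 fails.
# Hypothetical names saved by the stage-1 (MomentumOptimizer) checkpoint.
ckpt_names = {
    'global_step',
    'conv1/weights',
    'conv1/weights/Momentum',   # Momentum's slot variable
}

# Hypothetical names the stage-2 graph asks a plain tf.train.Saver()
# to restore once AdamOptimizer has created its own variables.
graph_names = {
    'global_step',
    'conv1/weights',
    'conv1/weights/Adam',       # Adam's first-moment slot
    'conv1/weights/Adam_1',     # Adam's second-moment slot
    'beta1_power',              # the key named in the NotFoundError
    'beta2_power',
}

# Every name in this set has no counterpart in the checkpoint, so the
# restore op raises NotFoundError on the first missing key it looks up.
missing = graph_names - ckpt_names
print(sorted(missing))
```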

The code I used is as follows:

# 0. only 1 gpu
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 1. define global parameters

args = get_parser()
global_step = tf.Variable(name='global_step', initial_value=0, trainable=False)
inc_op = tf.assign_add(global_step, 1, name='increment_global_step')
images = tf.placeholder(name='img_inputs', shape=[None, *args.image_size, 3], dtype=tf.float32)
labels = tf.placeholder(name='img_labels', shape=[None, ], dtype=tf.int64)
dropout_rate = tf.placeholder(name='dropout_rate', dtype=tf.float32)

# 2 prepare train datasets and test datasets by using tensorflow dataset api
# 2.1 train datasets 

tfrecords_f = os.path.join(args.tfrecords_file_path, 'tran_asia.tfrecords')
dataset = tf.data.TFRecordDataset(tfrecords_f)
dataset = dataset.map(parse_function)
dataset = dataset.shuffle(buffer_size=args.buffer_size)
dataset = dataset.batch(args.batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# 3. define network, loss, optimize method, learning rate schedule, summary writer, saver
# 3.1 inference phase

w_init_method = tf.contrib.layers.xavier_initializer(uniform=False)
net = get_resnet(...)

# 3.2 loss

logit = self_define_loss(embedding=net.outputs, labels=labels, w_init=w_init_method, out_num=args.num_output)
...
# 3.3 calculate loss

infer_loss = ...

# 3.4 optimizer (changed after pretraining)

# stage1
# opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=args.momentum)
# stage2
opt = tf.train.AdamOptimizer(learning_rate=lr)

# 3.5 get train op

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = opt.apply_gradients(grads, global_step=global_step)

# 4. restore stage1 model

# 4.1 saver

saver = tf.train.Saver(max_to_keep=10)

# 4.2 init all variables

sess.run(tf.global_variables_initializer())

# 4.3 restore stage1 model and change optimizer to do further training!

restore_saver = tf.train.Saver()
restore_saver.restore(sess, 'xxx.ckpt')


# Omit training part
...
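One hedged workaround for the snippet above is to restore only the variable names that actually exist in the checkpoint, and let tf.global_variables_initializer() give Adam's fresh slot variables their initial values. The filtering logic itself is plain Python; the TensorFlow calls it would plug into are named in the comments, and the name lists below are hypothetical stand-ins:

```python
# Sketch of the filter: keep only names the checkpoint knows about.
def restorable_subset(graph_names, ckpt_names):
    """Return the sorted graph-variable names that can safely be restored."""
    return sorted(set(graph_names) & set(ckpt_names))

# Hypothetical stand-ins for [v.op.name for v in tf.global_variables()]
# and for the names returned by tf.train.list_variables('xxx.ckpt').
graph_names = ['conv1/weights', 'conv1/weights/Adam', 'beta1_power', 'global_step']
ckpt_names  = ['conv1/weights', 'conv1/weights/Momentum', 'global_step']

print(restorable_subset(graph_names, ckpt_names))
# → ['conv1/weights', 'global_step']
```

In TF 1.x terms, one would list the checkpoint's contents with `tf.train.list_variables('xxx.ckpt')`, filter `tf.global_variables()` by `v.op.name` against that set, and pass the surviving variables as `tf.train.Saver(var_list=...)` before calling `restore`; keeping the `tf.global_variables_initializer()` run before the restore means Adam's unrestored slots are still initialized.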

The TensorFlow version I am using is 1.7.0. Any help would be greatly appreciated, thanks!

0 Answers:

No answers yet.