Restoring checkpoint fails: Key not found in checkpoint

Date: 2019-03-22 11:43:26

Tags: tensorflow

I was able to train an RNN successfully and can see the accuracy/loss in TensorBoard. The problem is that when I try to load the model from the checkpoint file, I get the following error:

Key fully_connected/Variable not found in checkpoint
     [[node save/RestoreV2 (defined at train.py:87)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Here is the code causing the problem (I have omitted the parts I believe are irrelevant):

tf.reset_default_graph()

with tf.name_scope('input_data'):
    input_data = tf.placeholder(tf.int32, [batchSize, maxSeqLength])

with tf.name_scope('labels'):
    labels = tf.placeholder(tf.float32, [batchSize, numClasses])

with tf.name_scope('embeddings'):
    data = tf.nn.embedding_lookup(wordVectors, input_data)

with tf.name_scope('lstm_layer'):
    lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
    lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)

with tf.name_scope('rnn_feed_forward'):
    value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)

with tf.name_scope('fully_connected'):
    weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
    bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))

with tf.name_scope('predictions'):
    value = tf.transpose(value, [1, 0, 2])
    last = tf.gather(value, int(value.get_shape()[0]) - 1)
    prediction = (tf.matmul(last, weight) + bias)

with tf.name_scope('accuracy'):
    correctPred = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))
    accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))

with tf.name_scope('cost'):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))

with tf.name_scope('train'):
    optimizer = tf.train.AdamOptimizer().minimize(loss)

merged = tf.summary.merge_all() # None at this point (no summaries defined yet); re-assigned inside the session below

saver = tf.train.Saver() # Saving and loading

# Train the model
print('Training has begun.')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    tf.summary.scalar('Loss', loss)
    tf.summary.scalar('Accuracy', accuracy)

    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(logdir, sess.graph)

    for i in range(iterations):
        nextBatch, nextBatchLabels = get_train_batch()
        sess.run(optimizer, { input_data: nextBatch, labels: nextBatchLabels })

        if (i % 50 == 0):
            print('Entering iteration ' + str(i))
            summary = sess.run(merged, {input_data: nextBatch, labels: nextBatchLabels})
            writer.add_summary(summary, i)

        if (i % 10000 == 0 and i != 0):
            save_path = saver.save(sess, modelsDir, global_step=i)
            print('Saved to %s' % save_path)
    writer.close()

My thinking is that when I add the optimizer to the session with sess.run(optimizer, ...), I am actually adding all of its variables, and the variables they depend on, to the graph.

What puzzles me is how the key 'fully_connected', even though it is just a name_scope, could seemingly be skipped at random.
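(A quick way to check that hypothesis, not part of the original post: the variables already exist once the graph is built, before anything is ever run in the session.)

# a default tf.train.Saver() saves/restores exactly the variables listed here
for v in tf.global_variables():
    print(v.op.name)   # e.g. 'fully_connected/Variable', plus Adam's '<var>/Adam' slots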

Details

The output of chkp.print_tensors_in_checkpoint_file("./models/pretrained_lstm.ckpt-10000", tensor_name='', all_tensors=True) gives me a bunch of variables whose names are not very helpful:

Variable
Variable_1
Variable_1/Adam
Variable_1/Adam_1
etc.
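For reference, chkp above is TensorFlow's bundled checkpoint inspector (TF 1.x); the import that makes the call work is:

from tensorflow.python.tools import inspect_checkpoint as chkp

chkp.print_tensors_in_checkpoint_file(
    "./models/pretrained_lstm.ckpt-10000", tensor_name='', all_tensors=True)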

Now I am wondering whether this has to do with me not explicitly naming my variables? Trying that now.
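(An illustrative aside, not from the original post: with tf.Variable, the enclosing tf.name_scope plus an explicit name determines the key stored in the checkpoint.)

with tf.name_scope('fully_connected'):
    # explicit names produce checkpoint keys like 'fully_connected/weight'
    # instead of the anonymous 'Variable' / 'Variable_1'
    weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]), name='weight')
    bias = tf.Variable(tf.constant(0.1, shape=[numClasses]), name='bias')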

Questions

  1. Can someone with more experience spot what I am doing wrong? Can you enlighten me?

  2. As a more open-ended question: apart from TensorBoard (which cannot help me here, since it does not actually read the checkpoint files), what tools would you recommend for inspecting sessions and graphs?

1 answer:

Answer 0: (score: 0)

Nothing is wrong. The Adam optimizer combines the AdaGrad (adaptive gradient) and RMSProp (root mean square propagation) techniques. The latter maintains per-parameter learning rates, which are adapted based on a moving average of recent squared gradients.

What matters for restoring the model is that the algorithm keeps an exponential moving average (EMA) of both the gradient and the squared gradient, so internal variables that track these averages (plus accumulators for the decay rates beta1 and beta2) are added alongside each layer's variables. These are the Variable_1/Adam and Variable_1/Adam_1 entries in your checkpoint listing above.
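A self-contained sketch (illustrative, not from the original answer) showing where those extra entries come from:

import tensorflow as tf  # TF 1.x

tf.reset_default_graph()
w = tf.Variable(tf.zeros([2, 2]))             # anonymous variable, op name 'Variable'
opt = tf.train.AdamOptimizer()
train_op = opt.minimize(tf.reduce_sum(w * w))

print(opt.get_slot_names())                   # ['m', 'v'], the two EMA accumulators
print(sorted(v.op.name for v in tf.global_variables()))
# 'Variable', 'Variable/Adam', 'Variable/Adam_1',
# plus beta1/beta2 power accumulators (exact names vary by TF version)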

You can avoid restoring these special variables by excluding them from the dictionary of variables that the restoring Saver is built from. You can create this dictionary like so:

# tf.train.list_variables yields (name, shape) pairs for everything in the checkpoint
vars_to_restore = [name for name, _ in tf.train.list_variables('file.ckpt')]
restore_dict = {variable.op.name: variable for variable in tf.global_variables()
                if variable.op.name in vars_to_restore}
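Used roughly like this (a sketch, reusing the checkpoint path from the question):

restore_saver = tf.train.Saver(restore_dict)  # covers only variables present in the checkpoint
with tf.Session() as sess:
    restore_saver.restore(sess, './models/pretrained_lstm.ckpt-10000')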

Then you only need to initialize the Adam variables:

# note: 'optimizer' must be the AdamOptimizer instance here, not the op returned by minimize()
sess.run(tf.variables_initializer(optimizer.variables()))

You can use a simple function like the one below to inspect the variable/scope names in a checkpoint and in the current graph.
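The helper itself did not survive in this copy of the answer; a minimal sketch of what such a function could look like (the name compare_names is illustrative, TF 1.x assumed):

import tensorflow as tf

def compare_names(ckpt_path):
    """Print which variable names appear only in the checkpoint, only in the current graph, or in both."""
    ckpt_names = {name for name, _ in tf.train.list_variables(ckpt_path)}
    graph_names = {v.op.name for v in tf.global_variables()}
    print('only in checkpoint:', sorted(ckpt_names - graph_names))
    print('only in graph:', sorted(graph_names - ckpt_names))
    print('in both:', sorted(ckpt_names & graph_names))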