Tensoflow估算器:如何使用tf.graph_util.convert_variables_to_constants

时间:2018-07-18 08:32:09

标签: python tensorflow tensorflow-estimator

我想知道是否有可能在训练/评估循环中使用函数 tf.graph_util.convert_variables_to_constants (以存储图形的冻结版本),而我m使用自定义估算器。例如:

best_validation_accuracy = -1
for _ in range(steps // how_often_validation):

    # Train the model
    estimator.train(input_fn=train_input_fn, steps=how_often_validation)

    # Evaluate the model
    validation_accuracy = estimator.evaluate(input_fn=eval_input_fn)

    # Save best model
    if validation_accuracy["accuracy"] > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy["accuracy"]
        # Save best model perfomances
        # I WANT TO USE  tf.graph_util.convert_variables_to_constants HERE

1 个答案:

答案 0 :(得分:0)

要使用功能tf.graph_util.convert_variables_to_constants,需要图形和模型会话。

经过TensorFlow code defining the estimators后,看起来:

  • 此代码已弃用,
  • 该图是动态创建的,并且不容易访问(至少,我无法检索它)。

因此,我们将不得不使用良好的旧方法。

调用estimator.train时,模型的检查点将保存在指定目录(estimator.model_dir)中。您可以使用这些文件来访问图形和会话并冻结变量,如下所示:

1。加载元图

saver = tf.train.import_meta_graph('/path/to/meta')

2。负载重量

sess = tf.Session
saver.restore(sess, '/path/to/weights')

3。冻结变量

tf.graph_util.convert_variables_to_constants(sess,
                                             sess.graph.as_graph_def(),
                                             ['output'])