如何在Python中保存/恢复tf.estimator.DNNClassifier?

时间:2020-03-24 12:53:40

标签: python tensorflow model save restore

我有一个模型(在Tensorflow 1.3.0中)将数据分类为:

input_func = tf.estimator.inputs.pandas_input_fn(x=X_train, y=y_train, batch_size=100, num_epochs=None, shuffle=True)

#Create the model
model = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[10, 10], n_classes=2)

#Train the model
model.train(input_fn=input_func, steps=5000)

虽然我没有使用tf.Session()运行我的代码,但不能使用saver = tf.train.Saver()保存我的模型。因此,如何保存我的模型以在以后还原它?

我看到也许我应该使用

estimator = tf.estimator.Estimator(model_fn, 'model', params={})
estimator.export_saved_model('saved_model', serving_input_receiver_fn)

但是,我不知道应该在serving_input_receiver_fn,model_fn中输入什么?如何定义保存模型的目录? 有想法吗?

0 个答案:

没有答案