C ++ Tensorflow中的DNNClassifier预测

时间:2018-07-25 22:13:19

标签: c++ tensorflow tensorflow-estimator

我在Python tensorflow中创建了DNNClassifier对象,并将其保存到protobuf文件中:

dnn_classifier = tf.estimator.DNNClassifier(...)
dnn_classifier.train(...)
dnn_classifier.export_savedmodel('saved_model.pb',...)

我需要将所述模型导入C ++,并对新的数据向量进行预测。我已经编译了tensorflow库,并确认它可以与测试代码一起使用。

我以这种方式加载了模型:

std::unique_ptr<tensorflow::Session>* session
SavedModelBundle* const bundle = new SavedModelBundle();

LoadSavedModel(SessionOptions(), RunOptions(), "saved_model.pb", {"serve"}, bundle);
session->reset(bundle->session.get());

我有12个输入功能(例如:A,B,C,...,L),并且需要基于这些功能进行类别预测。

Tensor A(DT_FLOAT, TensorShape()); A.scalar<float>()() = 1.4;
Tensor B(DT_FLOAT, TensorShape()); B.scalar<float>()() = -0.6;
...
Tensor L(DT_FLOAT, TensorShape()); L.scalar<float>()() = 2.6;

const std::vector<std::pair<string, Tensor> >& inputs = {{"A",A},{"B",B}...,{"L",L}}
const std::vector<string>& output_tensor_names = {"A","B",...,"L"}
std::vector<Tensor> outputs;

在这里,我认为应该进行预测:

session->Run(inputs, output_tensor_names, {}, &outputs);

此代码可以运行,但我认为它没有达到我想要的目的。

有更好的方法吗? SavedModelBundle还具有一个MetaGraphDef数据成员,该成员可能有用,但我不知道如何从中进行预测。帮助吗?

0 个答案:

没有答案
相关问题