Tensorflow Estimator预测签名和保存/加载

时间:2019-02-20 15:46:37

标签: python tensorflow keras

我有一个从this file加载的预先训练的Keras模型。输入形状为(128,513,1)。

我需要将此模型导出到Estimator以便进行部署。但是从保存的文件加载后无法使用它-出于某种原因,它似乎需要输入形状(-1,)

保存代码:

classifier = keras.estimator.model_to_estimator(keras.models.load_model(<SAVED MODEL PATH>)

feature_spec = {
    'input_1': tf.FixedLenFeature(dtype=tf.float32, shape=(128, 513, 1))
}
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

exported_model = classifier.export_savedmodel(
    export_dir_base = 'export/Servo/', 
    serving_input_receiver_fn = serving_input_fn)

with tarfile.open('model.tar.gz', mode='w:gz') as archive:
    archive.add('export', recursive=True)

加载代码

estimator_predict_fn = tf.contrib.predictor.from_saved_model(exported_model)

estimator_predict_fn({'examples': np.zeros((1, 128, 513, 1))})

投掷

  

INFO:tensorflow:从中还原参数   出口/伺服/ 1550677087 /变量/变量INFO:tensorflow:恢复   导出/伺服/ 1550677087 /变量/变量中的参数   -------------------------------------------------- ------------------------- ValueError追踪(最近的呼叫   最后)在()         1 estimator_predict_fn = tf.contrib.predictor.from_saved_model(exported_model)         2   ----> 3 estimator_predict_fn({'examples':np.zeros((1,128,513,1))})

     

〜/ anaconda3 / envs / tensorflow_p36 / lib / python3.6 / site-packages / tensorflow / contrib / predictor / predictor.py   在通话中(自己,input_dict)        75,如果值不为None:        76 feed_dict [self.feed_tensors [key]] =值   ---> 77返回self._session.run(fetches = self.fetch_tensors,feed_dict = feed_dict)

     

〜/ anaconda3 / envs / tensorflow_p36 / lib / python3.6 / site-packages / tensorflow / python / client / session.py   在运行中(自我,获取,feed_dict,选项,run_metadata)       875尝试:       第876章   -> 877 run_metadata_ptr)       第878章真相       879 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

     

〜/ anaconda3 / envs / tensorflow_p36 / lib / python3.6 / site-packages / tensorflow / python / client / session.py   在_run(自身,句柄,访存,feed_dict,选项,run_metadata)中
  1074'形状为%r'%1075
  (np_val.shape,subfeed_t.name,   -> 1076 str(subfeed_t.get_shape())))1077(如果不是self.graph.is_feedable(subfeed_t):1078
  引发ValueError('Tensor%s可能无法馈入。'%subfeed_t)

     

ValueError:无法为Tensor输入形状(1,128,513,1)的值   'input_example_tensor:0',其形状为'(?,)'

让我感到困惑的是,保存时,日志中有这样一行:

  

INFO:tensorflow:导出中包含的Predicts签名:   ['serving_default']

那是什么意思?为什么它不使用来自serve_input_fn的东西呢?

如何更改?

0 个答案:

没有答案
相关问题