Tensorflow r1.4估算器保存/加载

时间:2017-12-08 21:04:06

标签: python tensorflow

我正在尝试加载以前保存的Tensor DNNRegressor来预测新数据,以使用估算器类API r1.4验证保存/加载过程。问题可能在于我如何使用加载器或我如何保存模型。数据是来自csv的实数,用于未来未知输入大小的训练,所以我没有在feature_columns中使用过键

培训

train_inputs =numpy.float32(dftrain.values)
train_outputs =numpy.float32(dftest.values)

feature_columns=[tf.feature_column.numeric_column("XX", shape=[103])]

estimator=tf.estimator.DNNRegressor(feature_columns=feature_columns, hidden_units=[20])

input_fn_train = tf.estimator.inputs.numpy_input_fn(
    x={"XX": numpy.array(train_inputs)}, y=numpy.array(train_outputs),
    batch_size=batch_size, num_epochs=None, shuffle=False)

estimator.train(input_fn_train, steps=num_steps)

保存

feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)

export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

export_dir = estimator.export_savedmodel(output_directory, export_input_fn)

加载和新预测

predict_fn = predictor.from_saved_model(load_folder)
train_X =numpy.float32(dftrain.values)

预测第一行

traininput=train_X[0]

predict_input_fn_test = tf.estimator.inputs.numpy_input_fn(
x={"XX": numpy.array(traininput)},
num_epochs=1, shuffle=False)

predictions = predict_fn(predict_input_fn_test)

给出错误

input_keys = set(input_dict.keys()) AttributeError:'功能'对象没有属性'键'

0 个答案:

没有答案