使用export_savedmodel导出KMeans模型以在ml-engine上部署

时间:2017-10-19 05:43:42

标签: python-2.7 csv tensorflow k-means google-cloud-ml-engine

我正在使用tensorflow.contrib.learn.KMeansClustering进行K-means聚类。

我可以使用默认模型来预测本地,但由于我想使用ml-engine在线预测,我必须将其导出为ex​​port_savedmodel格式。

我有谷歌的地方,但由于KMeansClustering类不需要功能列,所以我不知道如何为export_savedmodel构建正确的serving_input_fn

这是我的代码

# Generate input_fn
def gen_input(data):
    return tf.constant(data.as_matrix(), tf.float32, data.shape), None

# Declare dataset + export model path
TRAIN = 'train.csv'
MODEL = 'model'

# Read dataset
body = pd.read_csv(
    file_io.FileIO(TRAIN, mode='r'),
    delimiter=',',
    header=None,
    engine='python'
)

# Declare K-Means
km = KMeansClustering(
    num_clusters=2,
    model_dir=MODEL,
    relative_tolerance=0.1
)

est = km.fit(input_fn=lambda: gen_input(body))

# This place is where I stuck
fcols = [tf.contrib.layers.real_valued_column('x', dimension=5)]
fspec = tf.contrib.layers.create_feature_spec_for_parsing(fcols)
serving_input_fn = tf.contrib.learn.python.learn.\
                   utils.input_fn_utils.build_parsing_serving_input_fn(fspec)
est.export_savedmodel(MODEL, serving_input_fn)

这是我的玩具train.csv

1,2,3,4,5
2,3,4,5,6
3,4,5,6,7
5,4,3,2,1
7,6,5,4,3
8,7,6,5,4

导出的模型的格式为saved_model.pb及其变量文件夹

将模型部署到ml-engine是成功的,但是当使用相同的train.csv进行预测时,我收到以下错误

{"error": "Prediction failed: Exception during model execution: AbortionError(code=StatusCode.INVALID_ARGUMENT, details=\"Name: <unknown>, Feature: x (data type: float) is required but could not be found.\n\t [[Node: ParseExample/ParseExample = ParseExample[Ndense=1, Nsparse=0, Tdense=[DT_FLOAT], _output_shapes=-1,5, dense_shapes=5, sparse_types=[], _device=\"/job:localhost/replica:0/task:0/cpu:0\"](_arg_input_example_tensor_0_0, ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/Const)]]\")"}

我一直在努力解决这个问题,而我找到的所有文档都是针对纯API的

我期待你的建议

提前致谢

1 个答案:

答案 0 :(得分:1)

人口普查示例shows如何为CSV设置serving_input_fn。根据您的示例调整:

    A1 =  (A['wap'][i]-A['wap'][i%len(A)+1]) + (A['wap'][i%len(A)+1])

TensorFlow 1.4将至少简化其中的一部分。

另外,请考虑使用JSON,因为这是更标准的服务方法。很高兴根据要求提供详细信息。