Tensorflow服务客户端批处理

时间:2018-08-14 17:45:52

标签: python tensorflow tensorflow-serving object-detection-api batching

Python:3.6.6

Tensorflow:1.10.0

Tensorflow服务:1.10.0

我已经看到多个示例(例如How to do batching in Tensorflow Serving?),这些示例使用以下代码来解决此问题:

# Create Stub
channel = grpc.insecure_channel(FLAGS.server)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

image_data = []
for image in FLAGS.input_image.split(','):
    with open(image, 'rb') as f:
        image_data.append(f.read())

# Create prediction request object
request = predict_pb2.PredictRequest()

# Specify model name (must be same as when TF server started)
request.model_spec.name = 'inference'

# Initialize prediction
# Specify signature name should be same as specified when exporting model)
request.model_spec.signature_name = 'detection_signature'

request.inputs['inputs'].CopyFrom(
    make_tensor_proto(image_data, shape=[len(image_data)])
)

# Call the prediction server
results = stub.Predict(request, 10.0) # 10 secs timeout

但是,出现以下错误:

Traceback (most recent call last):
  File "client_batch.py", line 64, in <module>
    results = stub.Predict(request, 10.0) # 10 secs timeout
  File "/path/to/python3.6/site-packages/grpc/_channel.py", line 514, in __call__
    return _end_unary_response_blocking(state, call, False, None)
  File "/path/to/w/lib/python3.6/site-packages/grpc/_channel.py", line 448, in _end_unary_response_blocking
    raise _Rendezvous(state, None, None, deadline)
grpc._channel._Rendezvous: <_Rendezvous of RPC that terminated with:
	status = StatusCode.INVALID_ARGUMENT
	details = "Expects arg[0] to be uint8 but string is provided"
	debug_error_string = "{"created":"@1534265330.005987356","description":"Error received from peer","file":"src/core/lib/surface/call.cc","file_line":1095,"grpc_message":"Expects arg[0] to be uint8 but string is provided","grpc_status":3}"

在阅读make_tensor_proto之后,我发现make_tensor_proto接受python标量,python列表,numpy ndarray或numpy标量的“值”。因此,在此示例中,当前版本似乎应支持字符串标量。

我能够通过替换

使代码适用于非批处理输入
make_tensor_proto(image_data, shape=[len(image_data)])

使用

make_tensor_proto(scipy.misc.imread(FLAGS.input_image), shape=[1] + list(img.shape))

指定传递ndarray和输入的确切形状。

但是,这似乎无法针对多个阵列进行扩展。您最终会收到如下错误:

got shape [2,1120, 1152, 3], but wanted [2]

在新版本的tensorflow中是否有新方法可以做到这一点?也许我显然做错了什么。

0 个答案:

没有答案