使用带有c ++的tensorflow预训练模型

时间:2017-09-10 08:49:52

标签: python c++ tensorflow tensorflow-serving

我已经训练了一个带有tensorflow的GAN,现在我想在我的c ++项目中使用它。 我的GAN是这样的(输入和输出都是图像):

image = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 3, SIZE, SIZE])  
input = 2*((tf.cast(image, tf.float32)/255.)-.5) #0~255  to -1~1
output = GAN(input) #GAN is my network including many modules

我注意到有一个saved_model工具可以将我的模型保存到saved_model.pb中,我可以直接在C ++中使用它。 我这样做的代码是这样的:

tensor_input_info = tf.saved_model.utils.build_tensor_info(input)
tensor_output_info = tf.saved_model.utils.build_tensor_info(output)

gan_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'image': tensor_input_info},
        outputs={'result': tensor_output_info},
        method_name='gan'
    )
)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
    session, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
    'my_gan':gan_signature
    },
    legacy_init_op=legacy_init_op)
builder.save()

这里我不确定dict中的键。在这段代码中,我使用“image”作为我输入的关键,但我不知道它是否正确。即使我成功了saved_model.pb

现在我不知道该怎么做,我怎样才能在我的C ++项目中使用它?

0 个答案:

没有答案