Tensorflow / Keras保存的模型加载未经训练的C ++

时间:2018-06-21 21:38:26

标签: python c++ tensorflow keras

我基本上遵循this的示例回答如何在python中保存模型并在C ++中加载模型。我的代码如下。在加载过程中,不会以TF状态抛出错误。但是,当我对模型进行推论时,其输出与经过训练的python模型输出相同。我认为问题可能是在保存模型或加载过程中发生的。

这是我总结的python培训/保存代码:

model = tf.keras.Sequential([
tf.keras.layers.Convolution2D(filters, (5, 5), input_shape=(7, 19, 19), data_format='channels_first', name='Input'),
tf.keras.layers.MaxPooling2D(pool_size=(2,2), name='Pool0'),
tf.keras.layers.Flatten(name='Flatten0'),
tf.keras.layers.Dense(BoardSize, activation='softmax', name='Output'),
])

optimizer = tf.train.AdamOptimizer(learning_rate=0.0018)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) 

# Train the model
# Do I need to save checkpoints here with a callback?
history = model.fit_generator(generator=gen.generator(),
                steps_per_epoch=gen.stepsPerEpoch(),
                validation_data=valGen.generator(),
                validation_steps=valGen.stepsPerEpoch(),
                epochs=numEpochs)

# Save the model
K.set_learning_phase(0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess, './models/latestModel')

这是我的C ++图形/砝码加载代码,该代码没有任何状态错误,应该加载已保存的砝码,对吗?:

static const std::string pathToCheckpoint = "Models/PolicyModel/latestModel";
static const std::string pathToGraph = pathToCheckpoint + ".meta";

// Set memory growth manually. Resolves and issue with tf 1.5 for windows
double memFraction = 1.0;
bool allowMemGrowth = true;
options.config.mutable_gpu_options()->set_allow_growth(allowMemGrowth);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(memFraction);

session = tf::NewSession(options);
if (!session)
    throw std::runtime_error("Could no create Tensorflow Session!");

// Read in the protobuf
status = tf::ReadBinaryProto(tf::Env::Default(), pathToGraph, &graphDef);
if (!status.ok())
    throw std::runtime_error("Error reading graph: " + status.ToString());

status = session->Create(graphDef.graph_def());
if (!status.ok())
    throw std::runtime_error("Error creating graph: " + status.ToString());

// Read the weights
tf::Tensor checkpointPathTensor(tf::DT_STRING, tf::TensorShape());
checkpointPathTensor.scalar<std::string>()() = pathToCheckpoint;

status = session->Run(
    { { graphDef.saver_def().filename_tensor_name(), checkpointPathTensor }, },
    {},
    { graphDef.saver_def().restore_op_name() },
    nullptr
);

if (!status.ok())
    throw std::runtime_error("Error loading checkpoint from " + pathToCheckpoint + ": " + status.ToString());

我也可以发布推理代码,为了使它相对简短,我一开始就将其保留。任何人对可能出什么问题有任何想法?我是否正确保存了模型和权重,或者缺少一步?权重是否已使用C ++正确加载?

我尝试了另一种保存/加载图形的方法here,但是无法加载使用此方法保存的图形。

0 个答案:

没有答案