保存并恢复keras模型GRU图层隐藏状态,同时预测

时间:2018-03-23 08:48:18

标签: python tensorflow machine-learning keras rnn

我使用带有张量流后端的keras。

我的模特是:

X_input = Input(name="x_input", batch_shape=(BATCH_SIZE, 1, len(alphabet_dict)))
GRU_layer, hs = GRU(32, return_sequences=True, dropout=0.25, recurrent_dropout=0.25, stateful=True, return_state=True)(X_input)
GRU_layer2, hs2 = GRU(32, return_sequences=True, dropout=0.25, recurrent_dropout=0.25, stateful=True, return_state=True)(GRU_layer)
y = Dense(len(alphabet_dict), activation="softmax", name="y")(GRU_layer2)
model = Model(inputs=[X_input], outputs=[y, hs, hs2])
model.compile(loss={"y": categorical_crossentropy}, optimizer="adam", metrics=["acc"])

在几个时代之后保存它。然后在播放器脚本中从检查点恢复此操作,并使用一个输入独立于batch_size模型:

with tf.device('/cpu:0'):
    model = load_model(filepath=model_file, compile=True)
    old_weights = model.get_weights()
    del model

    X_input = Input(name="x_input", batch_shape=(1, 1, len(alphabet_dict)))
    GRU_layer, hs = GRU(32, return_sequences=True, dropout=0.25, recurrent_dropout=0.25, stateful=True, return_state=True)(X_input)
    GRU_layer2, hs2 = GRU(32, return_sequences=True, dropout=0.25, recurrent_dropout=0.25, stateful=True, return_state=True)(GRU_layer)
    y = Dense(len(alphabet_dict), activation="softmax", name="y")(GRU_layer2)
    model = Model(inputs=[X_input], outputs=[y, hs, hs2])
    model.compile(loss={"y": categorical_crossentropy}, optimizer="adam", metrics=["acc"])
    model.set_weights(old_weights)

那么我用模型预测:

prediction = model.predict(np_utils.to_categorical(input_idx, len(alphabet_dict)).reshape(1,1,len(alphabet_dict)))
prediction[0] - is model output "y"
prediction[1] - is GRU_layer hidden state hs
prediction[2] - is GRU_layer2 hidden state hs2

所以我可以"阅读"我的GRU图层的隐藏状态。 但后来我需要像"分支"这样的东西。我需要保存当前隐藏的状态,然后从THIS语句中为我的模型预测几个不同的输入。

我需要这样的东西:

  • 1)预测,预测,预测,预测......, - 这里我预测一个 一个输入,使用theese输出并累积模型内隐藏 的状态。
  • 2)保存该步骤的模型状态(例如步骤#500)
  • 3)为我的模型使用#501_1输入。获取并保存模型输出和模型 内隐的状态。
  • 4)从步骤#500
  • 加载模型内部状态
  • 5)为我的模型使用#501_2输入。获取并保存模型输出和模型 内隐的状态。
  • 6)选择最佳输出,并根据此选择正确的#501节点。
  • 7)喜欢1 - "预测,预测,......"

我希望你理解我的想法。这对我来说很重要,因为我不想从每个分支开始预测所有内容。

老实说,模型的python脚本播放器不是我的观点。 我将模型导出为C ++格式。然后我用C ++代码加载它,它比python解释更快。 因此,在理想情况下,我需要两种解决方案,用于python和C ++播放器。

0 个答案:

没有答案
相关问题