如何在Keras解码器中处理批处理样本的结束令牌

时间:2019-05-14 02:02:09

标签: python keras sequence encoder-decoder

训练了内切解码器序列模型。为了将模型应用于推理,将第t-1个解码器输出的预测用作t的输入。如果我们有一个ending_token,并且我想让解码器在ending_token上输出t后停止。

在训练阶段,可以使用Masking(mask_value=ending_token)解决问题。在推断阶段,如果每次仅测试一个例句,我可以使用:

output = []
while decoder_output_id != ending_token:
    decoder_output_id, decoder_status = DECODER_model.predict([decoder_output_id, decoder_status])
    output.append(decoder_out_id)

但是,如果我想对一批样本进行推断,我应该如何处理批次中的每个样本可能在不同的时间结束。一种想法是固定输出时间长度,并在ending_token看起来为0之后强制输出,但这会导致开销。

0 个答案:

没有答案
相关问题