训练了内切解码器序列模型。为了将模型应用于推理,将第t-1
个解码器输出的预测用作t
的输入。如果我们有一个ending_token
,并且我想让解码器在ending_token
上输出t
后停止。
在训练阶段,可以使用Masking(mask_value=ending_token)
解决问题。在推断阶段,如果每次仅测试一个例句,我可以使用:
output = []
while decoder_output_id != ending_token:
decoder_output_id, decoder_status = DECODER_model.predict([decoder_output_id, decoder_status])
output.append(decoder_out_id)
但是,如果我想对一批样本进行推断,我应该如何处理批次中的每个样本可能在不同的时间结束。一种想法是固定输出时间长度,并在ending_token
看起来为0之后强制输出,但这会导致开销。