Lambda error when converting BERT-pytorch to ONNX

Posted: 2019-07-10 00:53:13

Tags: pytorch transformer onnx

I'm trying to convert the PyTorch implementation of BERT from here (https://github.com/codertimo/BERT-pytorch) to ONNX (and hopefully on to CoreML), but the implementation of the Transformer block:

class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))  # <-- Error!
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

is causing the error:

builtins.ValueError: Auto nesting doesn't know how to process an input object of type bert_pytorch.model.transformer.TransformerBlock.forward.<locals>.<lambda>. Accepted types: Tensors, or lists/tuples of them

I understand that the lambda is causing the error (or at least I think that's what needs to be addressed), but I'm not sure how to fix it, or whether a workaround is even possible with ONNX as it stands. For example, is there a way to rewrite this without the lambda? Would that fix the problem without breaking the model? (I'm still fairly new to Python.)
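For reference, this is the kind of lambda-free rewrite I had in mind: a minimal, untested sketch that assumes SublayerConnection only ever calls sublayer(x) on the tensor it receives. The AttentionSublayer wrapper and the mask-stashing trick are my own invention, not something from the repo:

import torch.nn as nn

class AttentionSublayer(nn.Module):
    """Wraps the self-attention call in a plain nn.Module so that
    SublayerConnection receives a module instead of a lambda."""

    def __init__(self, attention):
        super().__init__()
        self.attention = attention
        self.mask = None  # stashed by TransformerBlock.forward before each call

    def forward(self, x):
        # Self-attention: query, key, and value are all the same tensor
        return self.attention(x, x, x, mask=self.mask)

TransformerBlock would then create self.attn_sublayer = AttentionSublayer(self.attention) in __init__, and its forward would become:

    def forward(self, x, mask):
        self.attn_sublayer.mask = mask
        x = self.input_sublayer(x, self.attn_sublayer)
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

I don't know whether that alone is enough to make the ONNX tracer happy, but it at least removes the lambda from the forward pass.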

My conversion code is pretty simple:

export_model = bert
model_name = "bert.onnx"

dummy_input = (torch.randn(1, 40).long().cuda(), torch.randn(1, 40).long().cuda())
torch.onnx.export(export_model, 
                  dummy_input, 
                  model_name, 
                  input_names=['query_sequence'], 
                  output_names=['token_prediction'], 
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

The torch.randn(1, 40).long().cuda() line came from trial and error, and the use of ONNX_ATEN_FALLBACK came from Googling around about converting transformers.
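On the dummy input: randn().long() collapses almost every value to -1, 0, or 1, so if the model expects integer token IDs, torch.randint might be a cleaner way to build the trace input. This is just my guess at a tidier version; vocab_size is a placeholder, and I'm assuming the second tensor is the segment-label input this repo's BERT.forward takes:

import torch

vocab_size = 30000  # placeholder: substitute the model's actual vocabulary size
seq_len = 40

# Sample genuine integer token IDs in [0, vocab_size)
dummy_tokens = torch.randint(0, vocab_size, (1, seq_len)).cuda()
# Segment labels (the repo appears to use 1/2 for the two sentences, 0 for padding)
dummy_segments = torch.randint(0, 3, (1, seq_len)).cuda()
dummy_input = (dummy_tokens, dummy_segments)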

On a related note, I'm also curious whether anyone knows if a full BERT-PyTorch > ONNX > CoreML conversion just isn't possible at this stage (at least until CoreML 3 is ready). If it's definitely never going to happen, I'll save myself some flailing around!

Any thoughts are appreciated.
