尝试将双线性层转换为onnx时,上采样ONNX提供INVALID_GRAPH

时间:2019-06-20 09:20:10

标签: pytorch onnx

当我将在Pytorch上训练有双线性层的网络转换为ONNX时,会出现以下错误

  

RuntimeError:[ONNXRuntimeError]:10:INVALID_GRAPH:加载模型   从test.onnx失败:类型错误:输入的类型'张量(int64)'   节点()中运算符(Floor)的参数(11)无效。

我不确定为什么会发生此错误,我尝试从源代码构建ONNX,但问题似乎并没有解决。

关于什么可能导致此错误的任何想法?或如何解决该问题?

繁殖方式-

from torch import nn

import torch
import torch.nn.functional as F
import onnxruntime as rt

class Upsample(torch.nn.Module):
    def forward(self, x):
        #l = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=1, bias=True)
        return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)

m = Upsample()
v = torch.randn(1,3,128,128, dtype=torch.float32, requires_grad=False)

torch.onnx.export(m, v, "test.onnx")
sess = rt.InferenceSession("test.onnx")

1 个答案:

答案 0 :(得分:4)

此错误已在https://github.com/pytorch/pytorch/pull/21434中进行了修复(此修复已在functional.py中进行了修复),因此,如果您安装pytorch的夜间版本,则应该能够得到此错误。

但是,在同一PR中,已禁止以双线性模式转换Upsample;原因是Pytorch的双线性模式与ONNX的不一致,而“最近”模式是当前唯一支持的模式。

ONNX中的upsample(现在称为Resize)正在opset 11中进行更新,以支持与https://github.com/onnx/onnx/pull/2057中的Pytorch对齐的双线性模式,但尚未推送。