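After linger.init(...) has been called, torch.onnx.export is automatically replaced by linger.onnx.export (i.e. torch.onnx.export = linger.onnx.export), so any call to torch.onnx.export actually runs linger.onnx.export: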
import torch
import linger
.....
linger.init(...)
torch.onnx.export(...)  # actually calls linger.onnx.export
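The replacement described above is ordinary attribute monkey-patching. The following torch-free sketch shows the mechanism only and assumes nothing about how linger actually implements it: `fake_torch`, `fake_linger`, `original_export`, `linger_export` and `init` are all hypothetical stand-ins.

```python
from types import SimpleNamespace


def original_export(model, args, f, **kwargs):
    """Hypothetical stand-in for torch.onnx.export."""
    return "exported by torch"


def linger_export(model, args, f, **kwargs):
    """Hypothetical stand-in for linger.onnx.export."""
    return "exported by linger"


# Hypothetical module objects; only the attribute layout mirrors torch.onnx / linger.onnx.
fake_torch = SimpleNamespace(onnx=SimpleNamespace(export=original_export))
fake_linger = SimpleNamespace(onnx=SimpleNamespace(export=linger_export))


def init():
    """Hypothetical init(): rebind the attribute so existing call sites
    of fake_torch.onnx.export now reach the linger function."""
    fake_torch.onnx.export = fake_linger.onnx.export


before = fake_torch.onnx.export(None, None, "model.onnx")
init()
after = fake_torch.onnx.export(None, None, "model.onnx")
print(before)  # exported by torch
print(after)   # exported by linger
```

In this sketch the rebinding happens on the shared module object, so every caller that looks up `fake_torch.onnx.export` at call time sees the replacement. A name bound earlier with something like `from torch.onnx import export` would still point at the original function, because Python copied the reference at import time.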
A complete call, where torch_model is the model and x an example input:
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=12,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                                'output': {0: 'batch_size'}})
dynamic_axes accepts several forms:
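List form: dimensions 0, 2 and 3 of input_1 are dynamic while dimension 1 stays fixed, dimension 0 of input_2 is dynamic, and dimensions 0 and 1 of output are dynamic. Because no names are given, PyTorch automatically generates a name for each dynamic dimension to stand in for its size:
dynamic_axes = {'input_1': [0, 2, 3],
                'input_2': [0],
                'output': [0, 1]}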
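Dict form: for input_1, dynamic dimensions 0, 1 and 2 are named batch, width and height respectively; the other entries work the same way:
dynamic_axes = {'input_1': {0: 'batch',
                            1: 'width',
                            2: 'height'},
                'input_2': {0: 'batch'},
                'output': {0: 'batch',
                           1: 'detections'}}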
Mixed form: list entries and dict entries can be combined in the same dictionary:
dynamic_axes = {'input_1': [0, 2, 3],
                'input_2': {0: 'batch'},
                'output': [0, 1]}
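To compare the three forms, the sketch below converts every list entry into dict form by giving each listed axis a generated name, which is roughly what the exporter does internally. The helper `normalize_dynamic_axes` is hypothetical, and the `<tensor name>_dynamic_axes_<k>` naming pattern is an assumption modelled on recent PyTorch versions; the exact generated names are an internal detail and may differ.

```python
from typing import Dict, List, Union

AxesSpec = Union[List[int], Dict[int, str]]


def normalize_dynamic_axes(dynamic_axes: Dict[str, AxesSpec]) -> Dict[str, Dict[int, str]]:
    """Hypothetical helper: turn every list-form entry into dict form.

    List entries get names of the assumed form
    '<tensor name>_dynamic_axes_<1-based position in the list>'.
    Dict entries are copied unchanged. The input is not modified.
    """
    normalized: Dict[str, Dict[int, str]] = {}
    for tensor_name, spec in dynamic_axes.items():
        if isinstance(spec, dict):
            normalized[tensor_name] = dict(spec)
            continue
        named: Dict[int, str] = {}
        for position, axis in enumerate(spec, start=1):
            if not isinstance(axis, int):
                raise ValueError(f"axis index must be an int, got {axis!r}")
            # Keep the first name if an axis index is listed twice.
            named.setdefault(axis, f"{tensor_name}_dynamic_axes_{position}")
        normalized[tensor_name] = named
    return normalized


mixed = {'input_1': [0, 2, 3],
         'input_2': {0: 'batch'},
         'output': [0, 1]}
print(normalize_dynamic_axes(mixed))
```

Naming the axes explicitly (dict form) is usually preferable when several tensors share a dimension, such as the batch size, because matching names make that relationship visible in the exported graph.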
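torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: with this export type, operators that have no ONNX equivalent are exported as ATen operators instead of making the export fail.
import torch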
import torch.onnx
torch_model = ...
# set the model to inference mode
torch_model.eval()
dummy_input = torch.randn(1,3,244,244)
torch.onnx.export(torch_model, dummy_input, "test.onnx",
                  opset_version=11,
                  input_names=["input"],
                  output_names=["output"],
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
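Conceptually, the fallback only changes what happens when an operator has no ONNX conversion: instead of aborting, the exporter keeps it as an ATen node. The toy dispatcher below illustrates that single decision; the registry, the `convert` function and the operator name `aten::my_custom_op` are invented for illustration and are not PyTorch's exporter code.

```python
from typing import Callable, Dict, List, Tuple

Node = Tuple[str, str]  # (node kind, operator name): a toy stand-in for a graph node

# Hypothetical registry of ATen operators that have an ONNX conversion.
ONNX_SYMBOLICS: Dict[str, Callable[[], Node]] = {
    "aten::relu": lambda: ("onnx", "Relu"),
    "aten::add": lambda: ("onnx", "Add"),
}


def convert(op_names: List[str], aten_fallback: bool) -> List[Node]:
    """Toy exporter: map each ATen op to an ONNX node, optionally falling back."""
    nodes: List[Node] = []
    for op in op_names:
        symbolic = ONNX_SYMBOLICS.get(op)
        if symbolic is not None:
            nodes.append(symbolic())
        elif aten_fallback:
            # Fallback mode: keep the unsupported op as an ATen node.
            nodes.append(("ATen", op))
        else:
            raise RuntimeError(f"export failed: no ONNX conversion for {op}")
    return nodes


ops = ["aten::add", "aten::my_custom_op", "aten::relu"]
print(convert(ops, aten_fallback=True))
```

The trade-off is that a model exported this way can contain ATen nodes, which a runtime can only execute if it provides an implementation for them.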
Run the export inside with torch.no_grad(), i.e.:
import torch
import torch.onnx
torch_model = ...
# set the model to inference mode
torch_model.eval()
dummy_input = torch.randn(1,3,244,244)
with torch.no_grad():
    torch.onnx.export(torch_model, dummy_input, "test.onnx",
                      opset_version=11,
                      input_names=["input"],
                      output_names=["output"],
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
Warning: if the export is not wrapped in with torch.no_grad(), it fails with the following error:
RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/functions/utils.h":59, please report a bug to PyTorch.
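torch.no_grad() is a context manager: it switches gradient recording off for the enclosed block and restores the previous mode on exit, even when the block raises (torch.is_grad_enabled() reports the current mode). The toy sketch below reproduces only that save/restore pattern with a module-level flag; `_grad_enabled`, `is_grad_enabled` and `toy_no_grad` are hypothetical names, not PyTorch's implementation.

```python
from contextlib import contextmanager
from typing import Iterator

_grad_enabled = True  # hypothetical flag standing in for autograd's global grad mode


def is_grad_enabled() -> bool:
    return _grad_enabled


@contextmanager
def toy_no_grad() -> Iterator[None]:
    """Turn the flag off inside the block; restore the previous value on exit."""
    global _grad_enabled
    previous = _grad_enabled
    _grad_enabled = False
    try:
        yield
    finally:
        _grad_enabled = previous  # runs on normal exit and when an exception escapes


states = [is_grad_enabled()]
with toy_no_grad():
    states.append(is_grad_enabled())
    with toy_no_grad():  # nesting is safe: the inner exit restores False, not True
        states.append(is_grad_enabled())
    states.append(is_grad_enabled())
states.append(is_grad_enabled())
print(states)  # [True, False, False, False, True]
```

Because the previous mode is saved and restored rather than hard-coded to "on", wrapping only the export call leaves the rest of the program's gradient settings untouched.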