Reputation: 1769
I am trying to export my LSTM anomaly-detection PyTorch model to ONNX, but I'm running into errors. Please take a look at my code below.
Note: My data is shaped as [2685, 5, 6]. Here is where I define my model:
class Model(torch.nn.Module):
    """Stacked LSTM followed by two linear layers applied at every timestep.

    The network maps a ``(batch, timesteps, input_dim)`` tensor to a tensor of
    the same shape.

    Args:
        input_dim: Number of features per timestep (also the output size).
        hidden_dim: Size of the LSTM hidden state and of the first linear layer.
        layer_dim: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: the LSTM consumes (batch, timesteps, features).
        self.lstm = torch.nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Run the LSTM from a zero initial state and project back to ``input_dim``.

        Args:
            x: Input tensor of shape ``(batch, timesteps, input_dim)``.

        Returns:
            Tensor of shape ``(batch, timesteps, input_dim)``.
        """
        # new_zeros() creates the initial states on the same device and with
        # the same dtype as x, so the model also works on GPU or in float64.
        # Freshly created zeros do not require grad, so no detach() is needed.
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc1(out)
        return self.fc2(out)
# Hyper-parameters: features per timestep, LSTM hidden size, stacked LSTM layers.
input_dim, hidden_dim, layer_dim = 6, 3, 2
model = Model(
    input_dim,
    hidden_dim,
    layer_dim,
)
I can train it and test with it fine. But the problem comes when exporting:
import torch.onnx

# Put the model in evaluation mode before exporting.
model.eval()

# Example input handed to the exporter.
dummy_input = torch.randn(2685, 5, 6)
torch_out = torch.onnx.export(
    model,
    dummy_input,
    "onnx_model.onnx",
    export_params=True,
)
But I have the following error:
LSTM(6, 3, num_layers=2, batch_first=True)
Linear(in_features=3, out_features=3, bias=True)
Linear(in_features=3, out_features=6, bias=True)
['input_1', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear']
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/symbolic.py:173: UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-264-28c6c55537ab> in <module>()
10 torch.randn(2685, 5, 6),
11 "onnx_model.onnx",
---> 12 export_params = True
13 )
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py in export(*args, **kwargs)
23 def export(*args, **kwargs):
24 from torch.onnx import utils
---> 25 return utils.export(*args, **kwargs)
26
27
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
129 operator_export_type=operator_export_type, opset_version=opset_version,
130 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131 strip_doc_string=strip_doc_string)
132
133
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
367 if export_params:
368 proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type,
--> 369 strip_doc_string)
370 else:
371 proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string)
RuntimeError: ONNX export failed: Couldn't export operator aten::lstm
Defined at:
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(522): forward_impl
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(539): forward_tensor
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(559): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
<ipython-input-255-468cef410a2c>(14): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(294): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(493): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(231): get_trace_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(225): _trace_and_get_graph_from_model
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(266): _model_to_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(363): _export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(131): export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py(25): export
<ipython-input-264-28c6c55537ab>(12): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2963): run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2903): run_ast_nodes
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2785): _run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2662): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/events.py(145): _run
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(1432): _run_once
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(422): run_forever
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py(3): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(85): _run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(193): _run_module_as_main
Graph we tried to export:
graph(%input.1 : Float(2685, 5, 6),
%lstm.weight_ih_l0 : Float(12, 6),
%lstm.weight_hh_l0 : Float(12, 3),
%lstm.bias_ih_l0 : Float(12),
%lstm.bias_hh_l0 : Float(12),
%lstm.weight_ih_l1 : Float(12, 3),
%lstm.weight_hh_l1 : Float(12, 3),
%lstm.bias_ih_l1 : Float(12),
%lstm.bias_hh_l1 : Float(12),
%fc1.weight : Float(3, 3),
%fc1.bias : Float(3),
%fc2.weight : Float(6, 3),
%fc2.bias : Float(6)):
%13 : Long() = onnx::Constant[value={0}](), scope: Model
%14 : Tensor = onnx::Shape(%input.1), scope: Model
%15 : Long() = onnx::Gather[axis=0](%14, %13), scope: Model
%16 : Long() = onnx::Constant[value={2}](), scope: Model
%17 : Long() = onnx::Constant[value={3}](), scope: Model
%18 : Tensor = onnx::Unsqueeze[axes=[0]](%16)
%19 : Tensor = onnx::Unsqueeze[axes=[0]](%15)
%20 : Tensor = onnx::Unsqueeze[axes=[0]](%17)
%21 : Tensor = onnx::Concat[axis=0](%18, %19, %20)
%22 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%21), scope: Model
%23 : Long() = onnx::Constant[value={0}](), scope: Model
%24 : Tensor = onnx::Shape(%input.1), scope: Model
%25 : Long() = onnx::Gather[axis=0](%24, %23), scope: Model
%26 : Long() = onnx::Constant[value={2}](), scope: Model
%27 : Long() = onnx::Constant[value={3}](), scope: Model
%28 : Tensor = onnx::Unsqueeze[axes=[0]](%26)
%29 : Tensor = onnx::Unsqueeze[axes=[0]](%25)
%30 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
%31 : Tensor = onnx::Concat[axis=0](%28, %29, %30)
%32 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%31), scope: Model
%33 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
%34 : Long() = onnx::Constant[value={2}](), scope: Model/LSTM[lstm]
%35 : Double() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%36 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%37 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%38 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
%input.2 : Float(2685!, 5!, 3), %40 : Float(2, 2685, 3), %41 : Float(2, 2685, 3) = aten::lstm(%input.1, %22, %32, %lstm.weight_ih_l0, %lstm.weight_hh_l0, %lstm.bias_ih_l0, %lstm.bias_hh_l0, %lstm.weight_ih_l1, %lstm.weight_hh_l1, %lstm.bias_ih_l1, %lstm.bias_hh_l1, %33, %34, %35, %36, %37, %38), scope: Model/LSTM[lstm]
%42 : Float(3!, 3!) = onnx::Transpose[perm=[1, 0]](%fc1.weight), scope: Model/Linear[fc1]
%43 : Float(2685, 5, 3) = onnx::MatMul(%input.2, %42), scope: Model/Linear[fc1]
%44 : Float(2685, 5, 3) = onnx::Add(%43, %fc1.bias), scope: Model/Linear[fc1]
%45 : Float(3!, 6!) = onnx::Transpose[perm=[1, 0]](%fc2.weight), scope: Model/Linear[fc2]
%46 : Float(2685, 5, 6) = onnx::MatMul(%44, %45), scope: Model/Linear[fc2]
%47 : Float(2685, 5, 6) = onnx::Add(%46, %fc2.bias), scope: Model/Linear[fc2]
return (%47)
What does this mean? What should I do to export properly?
Upvotes: 4
Views: 7144
Reputation: 188
If you're coming here from Google, the previous answers are no longer up to date: ONNX now supports an LSTM operator. Take care, as exporting from PyTorch will fix the input sequence length by default unless you use the dynamic_axes parameter.
Below is a minimal LSTM export example I adapted from the torch.onnx FAQ
import torch
import onnx
from torch import nn
import numpy as np
import onnxruntime.backend as backend
import numpy as np
# Minimal example: export a bidirectional multi-layer LSTM to ONNX, first with
# a fixed sequence length and then with a dynamic one, and check the exported
# model against PyTorch.
torch.manual_seed(0)

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    # x is (sequence, batch, features); named x to avoid shadowing the
    # builtin input(). A bidirectional LSTM has layer_count * 2 state slices.
    x = torch.randn(1, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(x, (h0, c0))

# default export
torch.onnx.export(model, (x, (h0, c0)), 'lstm.onnx')
onnx_model = onnx.load('lstm.onnx')
# input shape [1, 3, 10]
print(onnx_model.graph.input[0])

# export with `dynamic_axes`
torch.onnx.export(model, (x, (h0, c0)), 'lstm.onnx',
                  input_names=['input', 'h0', 'c0'],
                  output_names=['output', 'hn', 'cn'],
                  dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
onnx_model = onnx.load('lstm.onnx')
# input shape ['sequence', 3, 10]
print(onnx_model.graph.input[0])

# Check export: the ONNX model must reproduce the PyTorch outputs.
y, (hn, cn) = model(x, (h0, c0))
y_onnx, hn_onnx, cn_onnx = backend.run(
    onnx_model,
    [x.numpy(), h0.numpy(), c0.numpy()],
    device='CPU'
)

np.testing.assert_almost_equal(y_onnx, y.detach(), decimal=5)
np.testing.assert_almost_equal(hn_onnx, hn.detach(), decimal=5)
np.testing.assert_almost_equal(cn_onnx, cn.detach(), decimal=5)
I've tested this example with torch==1.4.0 and onnx==1.7.0.
Upvotes: 10
Reputation: 1
Try with batch_first=False; batch_first=True is not supported by the ONNX export. You might need to transpose your data, because the input will then be shaped (timesteps, batch, features) instead of (batch, timesteps, features).
Upvotes: 0
Reputation: 1962
You are not doing anything wrong
RuntimeError: ONNX export failed: Couldn't export operator aten::lstm
LSTM is not in the list of supported operators in the ONNX export limitations.
Checking the GitHub issue queue for RuntimeError on unsupported aten:: operators shows there are more types that are not (yet) supported.
Upvotes: 0