This model works in PyTorch. However, after exporting it to ONNX format with PyTorch, ONNX Runtime fails with a 'Trilu NOT_IMPLEMENTED' error when loading it. (I do not have this issue with my other models that use torch.tril().)
How do I make this model run in ONNX Runtime?
This is a visualisation of the Onnx graph of the Model.
The Model in PyTorch
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        # Boolean mask: torch.tril() receives a bool tensor here
        attention_mask = item_seq < 100
        tril_mask = torch.tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask


model = MyModel()
model.eval()
x_train = torch.ones([1, 2], dtype=torch.long)

# Demonstrate that eager mode works
print(model.forward(x_train))

bigmodel_onnx_filename = 'mymodel.onnx'
torch.onnx.export(
    model,
    x_train,
    bigmodel_onnx_filename,
    input_names=['x'],
    output_names=['output'],
)

# Load the exported file with the onnx package
onnx.load(bigmodel_onnx_filename)

# ONNX Runtime crashes when loading the model
ort_sess = ort.InferenceSession(bigmodel_onnx_filename, providers=['CPUExecutionProvider'])
key = {'x': x_train.numpy()}
print(ort_sess.run(None, key))
This results in the following error for ort.InferenceSession():
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/net/Trilu'
How can I make this model run in the Onnxruntime?
[GitHub: code to reproduce the error and the model.onnx file](https://github.com/bkersbergen/pytorch_onnx_runtime_error/blob/main/main.py)
I'm using Python 3.9; these are the project requirements:
torch==1.13.1
jupyter==1.0.0
onnxruntime==1.13.1
onnx==1.13.0
Torch nightly version 2.0.0.dev20230205 gave the same error.
I then decided to implement my own tril function:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        attention_mask = item_seq < 100
        tril_mask = self.my_tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask

    def my_tril(self, x):
        l = x.size(-1)
        arange = torch.arange(l)
        mask = arange.expand(l, l)
        arange = arange.unsqueeze(-1)
        mask = torch.le(mask, arange)
        return x.masked_fill(mask == 0, 0)
But with this version I get a NOT_IMPLEMENTED error for the Where(9) node with name '/Where_1' instead.
Upvotes: 0
Views: 1318
Reputation: 31
Passing the boolean output of torch.lt() to torch.tril() works in PyTorch's eager and JIT modes. However, it breaks ONNX Runtime with the "Trilu NOT_IMPLEMENTED" error.
I was able to work around it by casting the torch.tril() input to float with .float():
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, item_seq):
        # Cast the boolean mask to float before calling torch.tril()
        attention_mask = torch.lt(item_seq, 100).float()
        tril_mask = torch.tril(attention_mask)
        query_layer = torch.rand((1, 2, 2, 32))
        key_layer = torch.rand((1, 2, 32, 2))
        attention_scores = torch.matmul(query_layer, key_layer)
        return attention_scores + tril_mask
Based on this experience, my hypothesis is that the Trilu NOT_IMPLEMENTED error occurs only when the input is a boolean tensor. ONNX Runtime reports it as a generic Trilu NOT_IMPLEMENTED error, which made me believe that there was no Trilu support at all. That is clearly not the case.
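If the hypothesis holds, the cast can be wrapped in a small helper so that torch.tril() never receives a boolean tensor. The following is only a sketch: tril_via_float is a hypothetical name, not part of PyTorch, and the check only confirms that the cast leaves the eager-mode result unchanged. It does not export the model or run it in ONNX Runtime.

```python
import torch


def tril_via_float(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: lower-triangular part of a mask, computed on floats."""
    # Cast first so that torch.tril() sees a float tensor instead of a bool one.
    return torch.tril(mask.to(torch.float32))


item_seq = torch.ones([1, 2], dtype=torch.long)
bool_mask = item_seq < 100              # dtype: torch.bool
float_tril = tril_via_float(bool_mask)  # dtype: torch.float32

# In eager mode both routes agree, so the cast does not change the values.
assert torch.equal(float_tril, torch.tril(bool_mask).float())
```

Because the model adds the mask to float attention scores, returning a float mask keeps the final result the same as with the boolean version.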