Reputation: 24874
torch.nn.functional.grid_sample
(source here, click on docs for documentation) is currently an unsupported operation in CoreML (and its conversion utilities library, coremltools).
What I'm looking for is a way to export the layer shown below from PyTorch's TorchScript (docs here) to CoreML (either using a custom op created via Swift or via an efficient PyTorch rewrite of grid_sample).
For details and tips to get you started, see the Tips section.
import coremltools as ct
import torch
class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Rest could be the default behaviour, e.g. bilinear
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)

# Image could also have more in_channels, different dimensions etc.,
# for example (2, 32, 64, 64)
image = torch.randn(2, 3, 32, 32)  # (batch, in_channels, height, width)
grid = torch.randint(low=-1, high=2, size=(2, 64, 64, 2)).float()
layer = GridSample()
# You could use `torch.jit.script` if preferable
scripted = torch.jit.trace(layer, (image, grid))
# Sanity check
print(scripted(image, grid).shape)
# Error during conversion
coreml_layer = ct.converters.convert(
    scripted,
    source="pytorch",
    inputs=[
        ct.TensorType(name="image", shape=image.shape),
        ct.TensorType(name="grid", shape=grid.shape),
    ],
)
which raises the following error:
Traceback (most recent call last):
File "/home/REDACTED/Downloads/sample.py", line 23, in <module>
coreml_layer = ct.converters.convert(
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 175, in convert
mlmodel = mil_convert(
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 128, in mil_convert
proto = mil_convert_to_proto(, convert_from, convert_to,
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 171, in mil_convert_to_proto
prog = frontend_converter(, **kwargs)
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 85, in __call__
return load(*args, **kwargs)
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 81, in load
raise e
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 73, in load
prog = converter.convert()
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 227, in convert
convert_nodes(self.context, self.graph)
File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 54, in convert_nodes
raise RuntimeError(
RuntimeError: PyTorch convert function for op 'grid_sampler' not implemented.
Python (conda):
coremltools==4.1
torch==1.8.0
You could also use nightly/master builds (at least as of the day of writing: 2021-03-20).
Tips - these are split into the two possible solutions I currently see:
Rewrite torch.nn.functional.grid_sample from scratch (a rough sketch of one possible rewrite follows this section).
Keep __getitem__ on list or related types in mind - conversion seems to work with torch.Tensor, but I had problems with it, so be aware of it if you get RuntimeError: PyTorch convert function for op '__getitem__' not implemented.
Pros:
Cons:
Register a custom layer which is responsible for running grid_sample. A CPU-only implementation would be fine (although using Apple's Metal for GPU speedups would be great).
As I'm not into Swift, I've gathered a few resources which might help you:
Pros:
Cons:
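A rough sketch of what the first approach (rewriting grid_sample from scratch) could look like, restricted to bilinear mode with zeros padding; the helper name grid_sample_bilinear is just illustrative and this is untested against coremltools:
import torch

def grid_sample_bilinear(image, grid, align_corners=True):
    # image: (N, C, H, W), grid: (N, H_out, W_out, 2) with coordinates in [-1, 1]
    N, C, H, W = image.shape

    # Un-normalize grid coordinates to pixel coordinates
    x, y = grid[..., 0], grid[..., 1]
    if align_corners:
        x = (x + 1) * (W - 1) / 2
        y = (y + 1) * (H - 1) / 2
    else:
        x = ((x + 1) * W - 1) / 2
        y = ((y + 1) * H - 1) / 2

    # Corner coordinates and bilinear weights
    x0, y0 = torch.floor(x), torch.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    wa, wb = (x1 - x) * (y1 - y), (x1 - x) * (y - y0)
    wc, wd = (x - x0) * (y1 - y), (x - x0) * (y - y0)

    def gather(ix, iy):
        # "zeros" padding: out-of-range corners contribute nothing
        mask = ((ix >= 0) & (ix < W) & (iy >= 0) & (iy < H)).float()
        idx = (iy.clamp(0, H - 1) * W + ix.clamp(0, W - 1)).long()
        idx = idx.reshape(N, 1, -1).expand(-1, C, -1)
        vals = torch.gather(image.reshape(N, C, H * W), 2, idx)
        return vals.reshape(N, C, ix.shape[1], ix.shape[2]) * mask.unsqueeze(1)

    return (gather(x0, y0) * wa.unsqueeze(1) + gather(x0, y1) * wb.unsqueeze(1)
            + gather(x1, y0) * wc.unsqueeze(1) + gather(x1, y1) * wd.unsqueeze(1))

# Quick sanity check against the reference implementation
image = torch.randn(2, 3, 32, 32)
grid = torch.randint(low=-1, high=2, size=(2, 64, 64, 2)).float()
print(torch.allclose(
    grid_sample_bilinear(image, grid),
    torch.nn.functional.grid_sample(image, grid, align_corners=True),
    atol=1e-6,
))
Note that the grid[..., 0] indexing might itself hit the __getitem__ issue mentioned above; in that case something like torch.unbind(grid, dim=-1) could be substituted.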
Upvotes: 4
Views: 1454
Reputation: 24874
Apparently some good soul saw our struggles and provided a custom op using MIL (CoreML's intermediate representation language).
I am not sure why the author did not post it here, but please do respond with a comment if you want to take some SO points for your solution!
Full operation conversion code below:
from coremltools.converters.mil import Builder as mb, register_torch_op, register_op
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.mil.ops.defs._op_reqs import *
# Custom operator for `torch.nn.functional.grid_sample`
@register_op(doc_str="Custom Grid Sampler", is_custom_op=True)
class custom_grid_sample(Operation):
    input_spec = InputSpec(
        x=TensorInputType(),
        grid=TensorInputType(),
        mode=StringInputType(const=True, optional=True),
        padding_mode=StringInputType(const=True, optional=True),
        align_corners=BoolInputType(const=True, optional=True),
    )

    # Metadata binding this op to the Swift-side custom layer
    bindings = {
        "class_name": "CustomGridSampler",
        "input_order": ["x", "grid"],
        "parameters": ["mode", "padding_mode", "align_corners"],
        "description": "Custom Grid Sampler",
    }

    def __init__(self, **kwargs):
        super(custom_grid_sample, self).__init__(**kwargs)

    def type_inference(self):
        x_type = self.x.dtype
        x_shape = self.x.shape
        grid_type = self.grid.dtype
        grid_shape = self.grid.shape

        assert len(x_shape) == len(grid_shape) == 4
        assert grid_shape[-1] == 2

        # Output keeps the input's batch and channels; spatial size comes from the grid
        shape = list(x_shape)
        shape[-2] = grid_shape[1]
        shape[-1] = grid_shape[2]
        return types.tensor(x_type, tuple(shape))


@register_torch_op
def grid_sampler(context, node):
    inputs = _get_inputs(context, node)
    x = inputs[0]
    grid = inputs[1]
    mode = node.attr.get("mode", "bilinear")
    padding_mode = node.attr.get("padding_mode", "zeros")
    align_corners = node.attr.get("align_corners", False)

    x = mb.custom_grid_sample(
        x=x,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        name=node.name,
    )
    context.add(x)
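For completeness, re-running the conversion from the question with the op above registered would look roughly like this (reusing GridSample, image and grid from the question; an actual CustomGridSampler layer still has to be implemented on the Swift side for the resulting model to run):
import coremltools as ct
import torch

# Reusing GridSample, image and grid defined in the question
scripted = torch.jit.trace(GridSample(), (image, grid))

coreml_layer = ct.converters.convert(
    scripted,
    source="pytorch",
    inputs=[
        ct.TensorType(name="image", shape=image.shape),
        ct.TensorType(name="grid", shape=grid.shape),
    ],
)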
Upvotes: 1
Reputation: 1680
Well, this is not an exact answer, rather some research. grid_sample is by its nature a sparse matrix operation, and the idea is to try to make it dense. The code below demonstrates how this could be done. It may be slow, and it requires grid to be static so that grid_sample can be eliminated from the model being converted, but it kinda works.
The goal is to express the transformation in linear form. Here, to get the dense matrix, we feed a unit diagonal (identity matrix) to grid_sample, and the result is a matrix holding the transform we are looking for. To apply that transform, multiply the flattened image by this matrix.
As you can see, batch=1 here, so the conversion must be done for each grid independently.
Your code:
import torch

in_sz = 2; out_sz = 4; batch = 1; ch = 3

class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Rest could be the default behaviour, e.g. bilinear
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)

image = torch.randn(batch, ch, in_sz, in_sz)  # (batch, in_channels, height, width)
grid = torch.randint(low=-1, high=2, size=(batch, out_sz, out_sz, 2)).float()

layer = GridSample()
scripted = torch.jit.trace(layer, (image, grid))
print(scripted(image, grid))
out:
tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
[-0.4457, -0.0052, -0.8226, -0.6341],
[-0.4457, -0.8226, -0.4457, -0.6341],
[-0.4510, -0.3382, -0.4457, -0.0424]],
[[-1.0090, -1.6029, -1.3813, -0.1212],
[-1.6029, -2.7920, -1.0090, -1.3060],
[-1.6029, -1.0090, -1.6029, -1.3060],
[-0.5651, -1.3813, -1.6029, -1.4566]],
[[ 0.1482, 0.7313, 0.8916, 1.8723],
[ 0.7313, 0.8144, 0.1482, 0.4398],
[ 0.7313, 0.1482, 0.7313, 0.4398],
[ 1.0103, 0.8916, 0.7313, 1.3434]]]])
Conversion:
oness = torch.ones( in_sz*in_sz )
diagg = torch.diag( oness ).reshape( 1, in_sz*in_sz, in_sz, in_sz )
denser = torch.nn.functional.grid_sample( diagg, grid, align_corners=True).reshape( in_sz*in_sz, out_sz*out_sz ).transpose(0,1)
print (denser.shape)
print (image.shape)
image_flat = image.reshape( batch, ch, in_sz*in_sz )
print (image_flat.shape)
print( torch.nn.functional.linear( image_flat, denser ).reshape( batch, ch, out_sz, out_sz ) )
Out:
torch.Size([16, 4])
torch.Size([1, 3, 2, 2])
torch.Size([1, 3, 4])
tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
[-0.4457, -0.0052, -0.8226, -0.6341],
[-0.4457, -0.8226, -0.4457, -0.6341],
[-0.4510, -0.3382, -0.4457, -0.0424]],
[[-1.0090, -1.6029, -1.3813, -0.1212],
[-1.6029, -2.7920, -1.0090, -1.3060],
[-1.6029, -1.0090, -1.6029, -1.3060],
[-0.5651, -1.3813, -1.6029, -1.4566]],
[[ 0.1482, 0.7313, 0.8916, 1.8723],
[ 0.7313, 0.8144, 0.1482, 0.4398],
[ 0.7313, 0.1482, 0.7313, 0.4398],
[ 1.0103, 0.8916, 0.7313, 1.3434]]]])
Well, it may not be very efficient, but I hope this amuses at least.
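To tie this back to the original question: once denser has been computed for a fixed grid, it could be wrapped in a module that only uses reshape and linear, which coremltools should be able to convert. A rough, untested sketch (the DenseGridSample name is just illustrative, reusing denser, image and out_sz from above):
import coremltools as ct
import torch

class DenseGridSample(torch.nn.Module):
    def __init__(self, denser, out_sz):
        super().__init__()
        # Precomputed dense matrix of shape (out_sz*out_sz, in_sz*in_sz)
        self.register_buffer("denser", denser)
        self.out_sz = out_sz

    def forward(self, image):
        batch, ch = image.shape[0], image.shape[1]
        flat = image.reshape(batch, ch, -1)
        out = torch.nn.functional.linear(flat, self.denser)
        return out.reshape(batch, ch, self.out_sz, self.out_sz)

# Reusing denser, image and out_sz from above
dense_layer = DenseGridSample(denser, out_sz)
scripted_dense = torch.jit.trace(dense_layer, image)

mlmodel = ct.converters.convert(
    scripted_dense,
    source="pytorch",
    inputs=[ct.TensorType(name="image", shape=image.shape)],
)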
Upvotes: 2