Szymon Maszke

Reputation: 24874

PyTorch's grid_sample conversion to CoreML (via coremltools)

torch.nn.functional.grid_sample is currently not supported by CoreML (nor by its conversion utilities library, coremltools).

What I'm looking for is a way to export the layer shown below from PyTorch's TorchScript to CoreML, either using a custom op created via Swift or via an efficient PyTorch rewrite of grid_sample.

For details and tips to get you started, see the Tips section.

Minimal verifiable example

import coremltools as ct
import torch

class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Remaining arguments keep their defaults (e.g. mode="bilinear")
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)

# Image could also have more in_channels, different dimension etc.,
# for example (2, 32, 64, 64)
image = torch.randn(2, 3, 32, 32)  # (batch, in_channels, height, width)
grid = torch.randint(low=-1, high=2, size=(2, 64, 64, 2)).float()

layer = GridSample()
# You could use `torch.jit.script` if preferable
scripted = torch.jit.trace(layer, (image, grid))

# Sanity check
print(scripted(image, grid).shape)

# Error during conversion
coreml_layer = ct.converters.convert(
    scripted,
    inputs=[
        ct.TensorType(name="image", shape=image.shape),
        ct.TensorType(name="grid", shape=grid.shape),
    ],
)

which raises the following error:

Traceback (most recent call last):
  File "/home/REDACTED/Downloads/", line 23, in <module>
    coreml_layer = ct.converters.convert(
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/", line 175, in convert
    mlmodel = mil_convert(
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/", line 128, in mil_convert
    proto = mil_convert_to_proto(, convert_from, convert_to,
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/", line 171, in mil_convert_to_proto
    prog = frontend_converter(, **kwargs)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/", line 85, in __call__
    return load(*args, **kwargs)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/", line 81, in load
    raise e
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/", line 73, in load
    prog = converter.convert()
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/", line 227, in convert
    convert_nodes(self.context, self.graph)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/", line 54, in convert_nodes
    raise RuntimeError(
RuntimeError: PyTorch convert function for op 'grid_sampler' not implemented.


Python (conda):

You could also use nightly/master builds (at least as of the day of writing: 2021-03-20).


Tips

These are split into the two possible solutions I currently see:

PyTorch only

Rewrite torch.nn.functional.grid_sample from scratch.
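A minimal sketch of what such a rewrite might look like, assuming only mode="bilinear" and padding_mode="zeros" are needed: bilinear sampling re-expressed with floor, clamp, elementwise arithmetic and torch.gather instead of the grid_sampler op. The helper name grid_sample_bilinear is hypothetical, and whether each of these ops converts cleanly with a particular coremltools version is an assumption that has not been checked here.

```python
import torch
import torch.nn.functional as F


def grid_sample_bilinear(inputs, grid, align_corners=True):
    """Hypothetical bilinear, zeros-padding grid_sample built from basic ops.

    inputs: (N, C, H_in, W_in); grid: (N, H_out, W_out, 2) holding (x, y) in [-1, 1].
    Returns (N, C, H_out, W_out).
    """
    n, c, h_in, w_in = inputs.shape
    _, h_out, w_out, _ = grid.shape

    # Un-normalize grid coordinates to pixel space (PyTorch's convention).
    x, y = grid[..., 0], grid[..., 1]
    if align_corners:
        x = (x + 1) / 2 * (w_in - 1)
        y = (y + 1) / 2 * (h_in - 1)
    else:
        x = ((x + 1) * w_in - 1) / 2
        y = ((y + 1) * h_in - 1) / 2

    x0, y0 = torch.floor(x), torch.floor(y)
    x1, y1 = x0 + 1, y0 + 1

    # Bilinear weights of the four neighbouring pixels.
    w_tl = (x1 - x) * (y1 - y)  # (x0, y0)
    w_bl = (x1 - x) * (y - y0)  # (x0, y1)
    w_tr = (x - x0) * (y1 - y)  # (x1, y0)
    w_br = (x - x0) * (y - y0)  # (x1, y1)

    flat = inputs.reshape(n, c, h_in * w_in)

    def gather(ix, iy):
        # Zeros padding: neighbours outside the image contribute nothing.
        valid = ((ix >= 0) & (ix <= w_in - 1) & (iy >= 0) & (iy <= h_in - 1)).to(inputs.dtype)
        ix = ix.clamp(0, w_in - 1)
        iy = iy.clamp(0, h_in - 1)
        idx = (iy * w_in + ix).long().reshape(n, 1, h_out * w_out).expand(n, c, h_out * w_out)
        vals = torch.gather(flat, 2, idx).reshape(n, c, h_out, w_out)
        return vals * valid.unsqueeze(1)

    return (gather(x0, y0) * w_tl.unsqueeze(1) + gather(x0, y1) * w_bl.unsqueeze(1)
            + gather(x1, y0) * w_tr.unsqueeze(1) + gather(x1, y1) * w_br.unsqueeze(1))


# Same shapes as the minimal example above.
image = torch.randn(2, 3, 32, 32)
grid = torch.randint(low=-1, high=2, size=(2, 64, 64, 2)).float()
out = grid_sample_bilinear(image, grid, align_corners=True)
ref = F.grid_sample(image, grid, align_corners=True)
print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

Nearest and bicubic modes, or the border/reflection padding modes, would need extra branches on top of this.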



Swift & CoreML

Register a custom layer responsible for running grid_sample. A CPU-only implementation would be fine (although using Apple's Metal for GPU speedups would be great).

As I'm not into Swift, I've gathered a few resources which might help you:



Upvotes: 4

Views: 1454

Answers (2)

Szymon Maszke

Reputation: 24874

Apparently some good soul saw our struggles and provided a custom op using MIL (CoreML's intermediate representation language).

The blog post where I found the solution, and the gist with grid_sample.

I am not sure why OP did not post it here, but please do respond with a comment if you want to take some SO points for your solution!

Full operation conversion code below:

from import register_torch_op, register_op
from import *

# Custom operator for `torch.nn.functional.grid_sample`
@register_op(doc_str="Custom Grid Sampler", is_custom_op=True)
class custom_grid_sample(Operation):
    input_spec = InputSpec(
        x = TensorInputType(),
        grid = TensorInputType(),
        mode = StringInputType(const=True, optional=True),
        padding_mode = StringInputType(const=True, optional=True),
        align_corners = BoolInputType(const=True, optional=True)
    )

    # Names the Swift/Obj-C class and parameters the custom layer receives
    bindings = {
        "class_name": "CustomGridSampler",
        "input_order": ["x", "grid"],
        "parameters": ["mode", "padding_mode", "align_corners"],
        "description": "Custom Grid Sampler"
    }

    def __init__(self, **kwargs):
        super(custom_grid_sample, self).__init__(**kwargs)

    def type_inference(self):
        x_type = self.x.dtype
        x_shape = self.x.shape

        grid_type = self.grid.dtype
        grid_shape = self.grid.shape

        assert len(x_shape) == len(grid_shape) == 4
        assert grid_shape[-1] == 2

        shape = list(x_shape)
        shape[-2] = grid_shape[1]
        shape[-1] = grid_shape[2]
        return types.tensor(x_type, tuple(shape))

# The function name matches the TorchScript op name 'grid_sampler'
@register_torch_op
def grid_sampler(context, node):
    inputs = _get_inputs(context, node)
    x = inputs[0]
    grid = inputs[1]
    mode = node.attr.get("mode", "bilinear")
    padding_mode = node.attr.get("padding_mode", "zeros")
    align_corners = node.attr.get("align_corners", False)
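    x = mb.custom_grid_sample(
        x=x,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        name=node.name,
    )
    # Register the output so later nodes in the graph can consume it
    context.add(x)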

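The shape rule in type_inference (keep N and C from x, take H_out and W_out from dimensions 1 and 2 of the grid) can be sanity-checked against PyTorch without coremltools. grid_sample_output_shape below is a hypothetical helper that only mirrors that rule; it is not part of the gist.

```python
import torch
import torch.nn.functional as F


def grid_sample_output_shape(x_shape, grid_shape):
    """Hypothetical mirror of custom_grid_sample.type_inference."""
    assert len(x_shape) == len(grid_shape) == 4
    assert grid_shape[-1] == 2
    shape = list(x_shape)
    shape[-2] = grid_shape[1]  # H_out comes from grid dimension 1
    shape[-1] = grid_shape[2]  # W_out comes from grid dimension 2
    return tuple(shape)


x = torch.randn(2, 3, 32, 48)
grid = torch.rand(2, 16, 24, 2) * 2 - 1
out = F.grid_sample(x, grid, align_corners=True)
print(grid_sample_output_shape(tuple(x.shape), tuple(grid.shape)), tuple(out.shape))
```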
Upvotes: 1

Alexey Birukov

Reputation: 1680

Well, this is not an exact answer, rather some research. grid_sample is by its nature a sparse matrix operation; the idea is to try to make it dense. The code below demonstrates how it could be done. It may be slow, and it requires the grid to be static so that grid_sample can be eliminated from the model to be converted, but it kinda works.

The goal is to express our transformation in linear form. To get the dense matrix, we feed a unit diagonal (identity) matrix to grid_sample, and the result is a matrix holding the transform we are looking for. To apply the transform, multiply the flattened image by this matrix. As you can see, batch=1 here; the conversion must be done for each grid independently.
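The same idea generalizes to non-square inputs and to batches where every element has its own grid, building one matrix per grid. This is a sketch; grid_sample_matrix and apply_matrix are hypothetical helper names, and torch.eye is used in place of torch.diag(torch.ones(...)), which produces the same identity matrix.

```python
import torch
import torch.nn.functional as F


def grid_sample_matrix(grid, in_h, in_w, align_corners=True):
    """Hypothetical helper: dense (H_out*W_out, H_in*W_in) matrix for ONE grid."""
    assert grid.shape[0] == 1, "build one matrix per grid"
    h_out, w_out = grid.shape[1], grid.shape[2]
    n_in = in_h * in_w
    # Row k of the identity becomes a one-hot in_h x in_w "image" in channel k
    eye = torch.eye(n_in, dtype=grid.dtype).reshape(1, n_in, in_h, in_w)
    sampled = F.grid_sample(eye, grid, align_corners=align_corners)
    # Column k holds the weights with which input pixel k feeds each output pixel
    return sampled.reshape(n_in, h_out * w_out).t()


def apply_matrix(image, matrix, h_out, w_out):
    n, c, in_h, in_w = image.shape
    flat = image.reshape(n, c, in_h * in_w)
    # F.linear(x, W) computes x @ W.T
    return F.linear(flat, matrix).reshape(n, c, h_out, w_out)


torch.manual_seed(0)
image = torch.randn(2, 3, 5, 7)        # non-square input
grid = torch.rand(2, 4, 6, 2) * 2 - 1  # a different grid per batch element
outs = [apply_matrix(image[b:b + 1], grid_sample_matrix(grid[b:b + 1], 5, 7), 4, 6)
        for b in range(image.shape[0])]
dense = torch.cat(outs)
ref = F.grid_sample(image, grid, align_corners=True)
print(dense.shape, torch.allclose(dense, ref, atol=1e-5))
```

The matrix has (H_in*W_in) x (H_out*W_out) entries, so memory grows quickly with image size.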

Your code:

import torch

in_sz = 2   # input is in_sz x in_sz
out_sz = 4  # output is out_sz x out_sz
batch = 1
ch = 3

class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Remaining arguments keep their defaults (e.g. mode="bilinear")
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)

image = torch.randn(batch, ch, in_sz, in_sz)  # (batch, in_channels, height, width)
grid = torch.randint(low=-1, high=2, size=(batch, out_sz, out_sz, 2)).float()

layer = GridSample()
scripted = torch.jit.trace(layer, (image, grid))
print(scripted(image, grid))


tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
          [-0.4457, -0.0052, -0.8226, -0.6341],
          [-0.4457, -0.8226, -0.4457, -0.6341],
          [-0.4510, -0.3382, -0.4457, -0.0424]],

         [[-1.0090, -1.6029, -1.3813, -0.1212],
          [-1.6029, -2.7920, -1.0090, -1.3060],
          [-1.6029, -1.0090, -1.6029, -1.3060],
          [-0.5651, -1.3813, -1.6029, -1.4566]],

         [[ 0.1482,  0.7313,  0.8916,  1.8723],
          [ 0.7313,  0.8144,  0.1482,  0.4398],
          [ 0.7313,  0.1482,  0.7313,  0.4398],
          [ 1.0103,  0.8916,  0.7313,  1.3434]]]])


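# Identity matrix reshaped into in_sz*in_sz one-hot "images", one per channel
oness = torch.ones(in_sz * in_sz)
diagg = torch.diag(oness).reshape(1, in_sz * in_sz, in_sz, in_sz)
# Sampling every one-hot image with the grid yields the dense transform matrix
denser = torch.nn.functional.grid_sample(diagg, grid, align_corners=True)
denser = denser.reshape(in_sz * in_sz, out_sz * out_sz).transpose(0, 1)
print(denser.shape)
print(image.shape)
image_flat = image.reshape(batch, ch, in_sz * in_sz)
print(image_flat.shape)
# linear(x, W) computes x @ W.T, applying the transform to each flattened channel
print(torch.nn.functional.linear(image_flat, denser).reshape(batch, ch, out_sz, out_sz))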


torch.Size([16, 4])
torch.Size([1, 3, 2, 2])
torch.Size([1, 3, 4])
tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
          [-0.4457, -0.0052, -0.8226, -0.6341],
          [-0.4457, -0.8226, -0.4457, -0.6341],
          [-0.4510, -0.3382, -0.4457, -0.0424]],

         [[-1.0090, -1.6029, -1.3813, -0.1212],
          [-1.6029, -2.7920, -1.0090, -1.3060],
          [-1.6029, -1.0090, -1.6029, -1.3060],
          [-0.5651, -1.3813, -1.6029, -1.4566]],

         [[ 0.1482,  0.7313,  0.8916,  1.8723],
          [ 0.7313,  0.8144,  0.1482,  0.4398],
          [ 0.7313,  0.1482,  0.7313,  0.4398],
          [ 1.0103,  0.8916,  0.7313,  1.3434]]]])
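Since the grid has to be static anyway, one way to package this for export is to compute the matrix once and store it as a buffer, so the traced graph only contains reshape and linear ops and no grid_sampler. StaticGridSample is a hypothetical module name, and whether coremltools then converts the traced module without further changes is an assumption that has not been verified here.

```python
import torch
import torch.nn.functional as F


class StaticGridSample(torch.nn.Module):
    """Hypothetical module: grid_sample with a fixed grid, applied as one linear op."""

    def __init__(self, grid, in_h, in_w, align_corners=True):
        super().__init__()
        assert grid.shape[0] == 1, "a single static grid"
        self.in_h, self.in_w = in_h, in_w
        self.out_h, self.out_w = grid.shape[1], grid.shape[2]
        n_in = in_h * in_w
        # Computed eagerly here, so grid_sample never appears in the traced forward
        eye = torch.eye(n_in).reshape(1, n_in, in_h, in_w)
        sampled = F.grid_sample(eye, grid, align_corners=align_corners)
        self.register_buffer("matrix", sampled.reshape(n_in, -1).t().contiguous())

    def forward(self, image):
        n, c = image.shape[0], image.shape[1]
        flat = image.reshape(n, c, self.in_h * self.in_w)
        return F.linear(flat, self.matrix).reshape(n, c, self.out_h, self.out_w)


grid = torch.randint(low=-1, high=2, size=(1, 4, 4, 2)).float()
image = torch.randn(1, 3, 2, 2)
traced = torch.jit.trace(StaticGridSample(grid, 2, 2), image)
print("grid_sampler" in str(traced.graph))
```

The traced module can then be handed to the converter in place of the original layer; the trade-off is that a new module has to be built whenever the grid changes.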

Well, it may not be very efficient, but I hope it is at least amusing.

Upvotes: 2
