Mead

How to vectorize a torch function?

When using NumPy, I can use np.vectorize to vectorize a function that contains if statements so that it accepts array arguments. How can I do the same in PyTorch so that a function accepts tensor arguments?

For example, the final print statement in the code below fails with "RuntimeError: Boolean value of Tensor with more than one value is ambiguous". How can I make this work?

import numpy as np
import torch as tc

def numpy_func(x):
    return x if x > 0. else 0.
numpy_func = np.vectorize(numpy_func)

print('numpy function (scalar):', numpy_func(-1.))
print('numpy function (array):', numpy_func(np.array([-1., 0., 1.])))

def torch_func(x):
    return x if x > 0. else 0.

print('torch function (scalar):', torch_func(-1.))
print('torch function (tensor):', torch_func(tc.tensor([-1., 0., 1.])))
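Why the last call fails: x > 0. on a three-element tensor yields a boolean tensor, and the if statement must reduce that to a single True/False, which PyTorch refuses to do. A minimal sketch reproducing this (it reuses the question's torch_func; the other variable names are illustrative):

```python
import torch

def torch_func(x):
    return x if x > 0. else 0.

t = torch.tensor([-1., 0., 1.])
mask = t > 0.  # elementwise comparison gives a bool tensor, not a bool

try:
    torch_func(t)  # `if` calls bool() on the 3-element mask
    failed = False
    message = ""
except RuntimeError as err:
    failed = True
    message = str(err)

print(mask, failed, message)
```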

Answers (1)

dx2-66

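You can use .apply_() for CPU tensors: it calls a Python function once per element and writes the results back in place. For CUDA tensors the task is problematic, because if statements don't map easily onto SIMD execution.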
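A minimal sketch of the .apply_() route, using the question's torch_func. Per the PyTorch docs, apply_ only works on CPU tensors and is slow because it calls back into Python for every element; since it overwrites the tensor in place, the example clones the input first:

```python
import torch

def torch_func(x):
    # Same scalar-only logic as in the question; apply_ passes Python numbers
    return x if x > 0. else 0.

t = torch.tensor([-1., 0., 1.])

result = t.clone()         # keep the original tensor untouched
result.apply_(torch_func)  # per-element Python call, CPU only, in place
print(result)  # tensor([0., 0., 1.])
```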

For functorch.vmap (torch.func.vmap since PyTorch 2.0), you can use the same workaround video drivers used to apply to shaders: evaluate both branches of the condition and combine them with arithmetic instead of an if.
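A minimal sketch of that workaround, assuming PyTorch 2.0 or later, where functorch.vmap is available as torch.func.vmap. torch.where computes both alternatives and selects one per element, so there is no data-dependent if for vmap to trip over; the function name torch_func_branchless is illustrative:

```python
import torch
from torch.func import vmap  # assumes PyTorch >= 2.0 (formerly functorch.vmap)

def torch_func_branchless(x):
    # Both "branches" (x and 0) are evaluated; torch.where picks elementwise.
    # Works for a 0-d tensor (as seen inside vmap) and for any larger shape.
    return torch.where(x > 0., x, torch.zeros_like(x))

t = torch.tensor([-1., 0., 1.])

direct = torch_func_branchless(t)         # already works on the whole tensor
mapped = vmap(torch_func_branchless)(t)   # per-element function mapped by vmap
print(direct, mapped)
```

For this particular function, built-ins such as torch.relu(t) or t.clamp(min=0.) already express the same thing without any per-element Python logic.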

Otherwise, just use a for loop: according to the NumPy docs, np.vectorize() is essentially a for loop anyway.

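import torch

def torch_vectorize(f, inplace=False):
    def wrapper(tensor):
        # Use a contiguous copy so .view(-1) is guaranteed to alias it
        # (flatten() silently copies non-contiguous tensors, losing the writes)
        out = tensor.clone(memory_format=torch.contiguous_format)
        flat = out.view(-1)
        for i, x in enumerate(flat):
            flat[i] = f(x)  # f receives a 0-d tensor
        if inplace:
            tensor.copy_(out)  # write back; also works for non-contiguous inputs
            return tensor
        return out
    return wrapper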
