Reputation: 199
In PyTorch, the built-in torch.roll function can only shift all columns (or rows) by the same offset. But I want to shift each column by a different offset. Suppose the input tensor is
[[1,2,3],
[4,5,6],
[7,8,9]]
Let's say I want to shift the i-th column by offset i. Thus, the expected output is
[[1,8,6],
[4,2,9],
[7,5,3]]
One option is to shift each column separately using torch.roll and then concatenate the results, as sketched below. But for efficiency and code compactness I'd rather not introduce a loop. Is there a better way?
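For reference, the loop-based baseline I'd like to avoid looks something like this:
import torch

mat = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

# Roll each column down by its own offset, then stack the columns back together.
cols = [torch.roll(mat[:, i], shifts=i) for i in range(mat.shape[1])]
print(torch.stack(cols, dim=1))
# tensor([[1, 8, 6],
#         [4, 2, 9],
#         [7, 5, 3]])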
Upvotes: 6
Views: 4847
Reputation: 27191
A generic version of @DanielM's solution. Given:
mat = torch.tensor(
[[1,2,3],
[4,5,6],
[7,8,9]]
)
shifts = torch.tensor([0, 1, 2])
# Roll each column down by its shift (gather along dim 0):
indices = (torch.arange(mat.shape[0])[:, None] - shifts[None, :]) % mat.shape[0]
torch.gather(mat, 0, indices)
# Roll each row to the right by its shift (gather along dim 1):
indices = (torch.arange(mat.shape[1])[None, :] - shifts[:, None]) % mat.shape[1]
torch.gather(mat, 1, indices)
def roll_along(arr, shifts, dim):
    # Roll each 1-D slice along `dim` by its own shift.
    assert arr.ndim - 1 == shifts.ndim
    dim %= arr.ndim  # support negative dims
    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
    dim_indices = torch.arange(arr.shape[dim], device=arr.device).reshape(shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
    return torch.gather(arr, dim, indices)
roll_along(mat, shifts, dim=0)  # roll each column down by its shift
roll_along(mat, shifts, dim=1)  # roll each row right by its shift
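As a quick check on the example above, dim=0 reproduces exactly the output the question asks for:
print(roll_along(mat, shifts, dim=0))
# tensor([[1, 8, 6],
#         [4, 2, 9],
#         [7, 5, 3]])
print(roll_along(mat, shifts, dim=1))
# tensor([[1, 2, 3],
#         [6, 4, 5],
#         [8, 9, 7]])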
Upvotes: 0
Reputation: 368
I was sceptical about the performance of torch.gather, so I searched for similar NumPy questions and found this post. I took the solution from @Andy L and translated it into PyTorch. However, take it with a grain of salt, because I don't fully understand how the strides work:
import numpy as np
from numpy.lib.stride_tricks import as_strided

# NumPy solution:
def custom_roll(arr, r_tup):
    m = np.asarray(r_tup).reshape(-1)  # flatten so each row pairs with one shift
    # Append all but the last column so every cyclic window is contiguous in memory.
    arr_roll = arr[:, [*range(arr.shape[1]), *range(arr.shape[1] - 1)]].copy()  # need a real copy for as_strided
    strd_0, strd_1 = arr_roll.strides
    n = arr.shape[1]
    # result[i, k] is row i cyclically shifted left by k.
    result = as_strided(arr_roll, (*arr.shape, n), (strd_0, strd_1, strd_1))
    return result[np.arange(arr.shape[0]), (n - m) % n]  # a left shift by (n - m) is a right shift by m
# Translated to PyTorch:
def pcustom_roll(arr, r_tup):
    m = torch.as_tensor(r_tup).reshape(-1)  # flatten so each row pairs with one shift
    arr_roll = arr[:, [*range(arr.shape[1]), *range(arr.shape[1] - 1)]].clone()  # need a real copy (clone)
    strd_0, strd_1 = arr_roll.stride()
    n = arr.shape[1]
    result = torch.as_strided(arr_roll, (*arr.shape, n), (strd_0, strd_1, strd_1))
    return result[torch.arange(arr.shape[0]), (n - m) % n]
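To build some intuition for what as_strided does here (my own reading, illustrated on a tiny example): the strided view stacks every cyclic window of a row, so result[i, k] is row i shifted left by k, and selecting k = (n - m) % n turns that into a right shift by m:
import numpy as np
from numpy.lib.stride_tricks import as_strided

row = np.array([[1, 2, 3]])
doubled = row[:, [0, 1, 2, 0, 1]].copy()  # row extended to [1, 2, 3, 1, 2]
s0, s1 = doubled.strides
windows = as_strided(doubled, (1, 3, 3), (s0, s1, s1))
print(windows[0])
# [[1 2 3]   k=0: left shift by 0
#  [2 3 1]   k=1: left shift by 1
#  [3 1 2]]  k=2: left shift by 2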
Here is also the solution from @DanielM, ready to plug and play:
def roll_by_gather(mat, dim, shifts: torch.LongTensor):
    # assumes a 2D tensor
    n_rows, n_cols = mat.shape
    if dim == 0:
        arange1 = torch.arange(n_rows).view((n_rows, 1)).repeat((1, n_cols))
        arange2 = (arange1 - shifts) % n_rows
        return torch.gather(mat, 0, arange2)
    elif dim == 1:
        arange1 = torch.arange(n_cols).view((1, n_cols)).repeat((n_rows, 1))
        arange2 = (arange1 - shifts) % n_cols
        return torch.gather(mat, 1, arange2)
First, I ran the methods on the CPU. Surprisingly, the gather solution from above is the fastest:
n_cols = 10000
n_rows = 100
shifts = torch.randint(-100, 100, size=[n_rows, 1])
data = torch.arange(n_rows * n_cols).reshape(n_rows, n_cols)
npdata = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)
npshifts = shifts.numpy()
%timeit roll_by_gather(data, 1, shifts)
%timeit pcustom_roll(data, shifts)
%timeit custom_roll(npdata, npshifts)
>> 2.41 ms ± 68.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>> 90.4 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>> 247 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Running the code on the GPU shows similar results:
%timeit roll_by_gather(data, 1, shifts)
%timeit pcustom_roll(data, shifts)
131 µs ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.29 ms ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(Note: you need torch.arange(..., device='cuda:0') within the roll_by_gather method.)
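For completeness, the GPU timings above assume a setup roughly like this (variable names as in the CPU benchmark; requires a CUDA device):
data = data.to('cuda:0')
shifts = shifts.to('cuda:0')
# and inside roll_by_gather, create the index tensor on the same device, e.g.:
# arange1 = torch.arange(n_cols, device=mat.device).view((1, n_cols)).repeat((n_rows, 1))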
Upvotes: 3
Reputation: 1433
Let's define some names:
import torch
mat = torch.Tensor(
[[1,2,3],
[4,5,6],
[7,8,9]])
indices = torch.LongTensor([0, 1, 2]) # Could also use arange in this specific scenario
First, you can make a tensor like
[[0, 0, 0],
[1, 1, 1],
[2, 2, 2]]
using
arange1 = torch.arange(3).view((3, 1)).repeat((1, 3))
Now, let's make a tensor of source indices (for each output position, the row to gather from)
[[0, 2, 1],
[1, 0, 2],
[2, 1, 0]]
with
arange2 = (arange1 - indices) % 3
Lastly, we get the expected output with
torch.gather(mat, 0, arange2)
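This reproduces the expected output from the question (shown as floats, since torch.Tensor creates a float32 tensor):
tensor([[1., 8., 6.],
        [4., 2., 9.],
        [7., 5., 3.]])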
Upvotes: 3