ItM

Reputation: 331

Fastest way to shift rows of matrix in Python

I have a 4x4 matrix like this:

1  2  3  4
5  6  7  8
9  10 11 12
13 14 15 16

I want to shift each row left (left circular shift), by the amount of the row index. I.e. row 0 stays as is, row 1 shifts left 1, row 2 shifts left 2, etc.

So we get this:

1  2  3  4
6  7  8  5
11 12 9  10
16 13 14 15
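As index arithmetic, this rule says that element `j` of output row `i` comes from position `(i + j) % n` of input row `i`. The minimal pure-Python sketch below assumes a square list of lists, and the name `reference_shift` is hypothetical. It spells out that rule and can serve as a correctness check for faster versions:

```python
def reference_shift(m):
    """Return a new matrix where row i is rotated left by i positions.

    Assumes m is a square list of lists; the input is not modified.
    """
    n = len(m)
    # output[i][j] takes input element (i, (i + j) % n)
    return [[m[i][(i + j) % n] for j in range(n)] for i in range(n)]


matrix = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(reference_shift(matrix))
# [[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
```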

The fastest way I've come up with to do this in Python is the following:

import numpy as np
def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

I need to run this function on thousands of 4x4 matrices, so speed is important (to the extent possible in Python). I don't mind using other modules such as numpy; I'm only concerned with speed.

Any help would really be appreciated!

Thank you!

Upvotes: 2

Views: 1802

Answers (4)

Jan Christoph Terasa

Reputation: 5945

First improvement: get rid of the list comprehension

I assume that your input will always be a 4x4 ndarray. If not, you need to modify the functions accordingly (e.g. add np.asarray, check dimensions, etc.). Removing the list comprehension already gives a nice speedup:

import numpy as np

a = np.arange(16).reshape(4, 4)

def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

def shift(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x

%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Keep in mind that both variants modify the array in place. If that is not what you want, add x = x.copy() at the beginning of both functions.
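Here is a minimal sketch of that non-mutating variant. The only change from `shift` is the `x = x.copy()` line, so the caller's array is left untouched. The name `shift_copy` is just for illustration:

```python
import numpy as np


def shift_copy(x):
    # work on a copy so the caller's array is not modified
    x = x.copy()
    for i in range(1, 4):
        # rotate row i left by i positions
        x[i] = np.append(x[i, i:], x[i, :i])
    return x


a = np.arange(16).reshape(4, 4)
b = shift_copy(a)
print(a[1])  # original row is unchanged: [4 5 6 7]
print(b[1])  # shifted row: [5 6 7 4]
```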

In my testing, numpy.roll is much slower than either version.

Second improvement: use numba

Now, the real speedup comes when we use numba:

import numba

@numba.njit
def shift_numba(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x    

%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

That is about 15 times faster than your current version. Using parallel mode does not improve performance, probably because the array is so small.


Test: Unrolling the loop

At Patrick Artner's request, I unrolled the loop (easy to do for a 4x4 matrix):

@numba.njit
def shift_numba_unrolled(x):
    x[1] = np.append(x[1, 1:], x[1, :1])
    x[2] = np.append(x[2, 2:], x[2, :2])
    x[3] = np.append(x[3, 3:], x[3, :3])
    return x

%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Unrolling seems to produce the same timing as the loop.


EDIT: Fixed a big problem; the speedup is much smaller now.

Upvotes: 1

Patrick Artner

Reputation: 51683

A base solution using plain lists of lists (for those who find this question without numpy in mind):

from timeit import timeit

def SimpleShift(x):
    for i in range(1, 4):
        # in-place slice assignment rotates row i left by i
        x[i][:] = x[i][i:] + x[i][:i]
    return x

def EvenSimplerShift(x):
    # manually unrolled loop
    x[1][:] = x[1][1:] + x[1][:1]
    x[2][:] = x[2][2:] + x[2][:2]
    x[3][:] = x[3][3:] + x[3][:3]
    return x

data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

print(data)
# both functions modify their argument, so give each one a fresh copy
print(SimpleShift([row[:] for row in data]))
print(EvenSimplerShift([row[:] for row in data]))

print(timeit(lambda: SimpleShift(data)))
print(timeit(lambda: EvenSimplerShift(data)))

which prints the following (timings are total seconds for timeit's default of 1,000,000 calls):

[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]

4.8055571                 # timing with for loop
4.098531100000001         # timing with unrolled loop

So manually unrolling the loop seems to be faster. You might want to try the same with numpy as well.
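Following that suggestion, an unrolled numpy version might look like the sketch below. It uses `np.concatenate` instead of `np.append` (for 1-D inputs `np.append` just concatenates). It assumes a 4x4 ndarray and modifies it in place like the answers above. The name `shift_unrolled_np` is hypothetical, and whether it is actually faster would need to be measured with `%timeit` on your own machine:

```python
import numpy as np


def shift_unrolled_np(x):
    # assumes x is a 4x4 ndarray; rows are rotated in place
    x[1] = np.concatenate((x[1, 1:], x[1, :1]))
    x[2] = np.concatenate((x[2, 2:], x[2, :2]))
    x[3] = np.concatenate((x[3, 3:], x[3, :3]))
    return x


m = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
print(shift_unrolled_np(m))
```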

Upvotes: 0

Karl Knechtel

Reputation: 61643

If you don't mind hard-coding an array size, in my testing it's about 6x as fast to just hard-code the rearrangement pattern:

def rot(a):
    return a.take((0,1,2,3,5,6,7,4,10,11,8,9,15,12,13,14)).reshape(4, 4)
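The hard-coded tuple is just the row-shift rule flattened: position `j` of row `i` reads flat index `i * n + (i + j) % n`. The sketch below builds the same pattern for any `n`; the helper name `shift_indices` is hypothetical. Assuming the matrices can be stacked into one `(N, 4, 4)` array, it then applies the pattern to the whole batch in one indexing call instead of calling a function thousands of times:

```python
import numpy as np


def shift_indices(n):
    # flat index i * n + (i + j) % n for output position (i, j)
    rows = np.arange(n)[:, None]
    cols = np.arange(n)[None, :]
    return (rows * n + (rows + cols) % n).ravel()


idx = shift_indices(4)
print(tuple(idx.tolist()))
# (0, 1, 2, 3, 5, 6, 7, 4, 10, 11, 8, 9, 15, 12, 13, 14)

# a batch of N = 3 matrices, shape (N, 4, 4)
batch = np.arange(3 * 16).reshape(3, 4, 4)
# flatten each matrix, gather with the same index pattern, restore the shape
shifted = batch.reshape(-1, 16)[:, idx].reshape(-1, 4, 4)
print(shifted[1])
```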

Upvotes: 4

Grismar
Grismar

Reputation: 31416

This works:

import numpy as np


def stepped_roll(arr):
    return np.array([np.roll(row, -n) for n, row in enumerate(arr)])


print(stepped_roll(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])))

I'd favour np.roll because numpy routines tend to be faster than what you can do in pure Python. Sadly, np.apply_along_axis doesn't work here, because you need the index of each row as you go.
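One way to supply the per-row shift without a Python loop is to build a column-index array from the row index and gather with `np.take_along_axis` (available since NumPy 1.15). Here is a minimal sketch, assuming a square 2-D array; `stepped_roll_vectorized` is a hypothetical name:

```python
import numpy as np


def stepped_roll_vectorized(arr):
    # assumes arr is a square 2-D array; returns a new array
    n = arr.shape[0]
    rows = np.arange(n)[:, None]
    # column index (i + j) % n rotates row i left by i positions
    cols = (rows + np.arange(n)) % n
    return np.take_along_axis(arr, cols, axis=1)


m = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
print(stepped_roll_vectorized(m))
```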

In your case, though, the operation is so trivial and the data set so small that something like the shift() function suggested in @JanChristophTerasa's answer will be much faster.

Upvotes: 0
