Mike

Reputation: 1827

np.roll vs scipy.ndimage.shift: discrepancy for integer shift values

I wrote some code to shift an array and was trying to generalize it to handle non-integer shifts using the shift function in scipy.ndimage. The data is circular, so the result should wrap around exactly as np.roll does.

However, scipy.ndimage.shift does not appear to wrap integer shifts properly. The following code snippet shows the discrepancy:

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt 

def shiftfunc(data, amt):
    # scipy.ndimage.shift is the public name; the scipy.ndimage.interpolation namespace is deprecated
    return sciim.shift(data, amt, mode='wrap', order=3)

if __name__ == "__main__":
    xvals = np.arange(100)*1.0

    yvals = np.sin(xvals*0.1)

    rollshift   = np.roll(yvals, 2)

    interpshift = shiftfunc(yvals, 2)

    plt.plot(xvals, rollshift, label = 'np.roll', alpha = 0.5)
    plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5)
    plt.legend()
    plt.show()

[Figure: roll vs shift]
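To see which samples disagree without relying on the plot, the same comparison can be checked numerically. A small sketch, reusing the question's data and calling scipy.ndimage.shift directly (the non-deprecated name for interpolation.shift); the variable name `bad` is just illustrative:

```python
import numpy as np
import scipy.ndimage as ndi

xvals = np.arange(100) * 1.0
yvals = np.sin(xvals * 0.1)

rollshift = np.roll(yvals, 2)
interpshift = ndi.shift(yvals, 2, mode='wrap', order=3)

# Indices where the two circular shifts disagree beyond np.isclose's default tolerance
bad = np.flatnonzero(~np.isclose(rollshift, interpshift))
print(bad)
# Largest absolute disagreement, for scale
print(np.max(np.abs(rollshift - interpshift)))
```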

It can be seen that the first couple of values are highly discrepant, while the rest are fine. I suspect this is an implementation error of the prefiltering and interpolation operation when using the wrap option. A way around this would be to modify shiftfunc to revert to np.roll when the shift value is an integer, but this is unsatisfying.
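The fallback described above (np.roll for integer shifts, spline interpolation otherwise) might look roughly like this. It is a sketch for 1-D arrays only, and `shift_or_roll` is a hypothetical name, not a SciPy function:

```python
import numpy as np
import scipy.ndimage as ndi


def shift_or_roll(data, amt, order=3):
    """Hypothetical helper: circularly shift a 1-D array by amt samples.

    Integer-valued shifts use np.roll (period len(data)); any other shift
    falls back to spline interpolation with mode='wrap'.
    """
    if float(amt).is_integer():
        return np.roll(data, int(amt))
    return ndi.shift(data, amt, mode='wrap', order=order)


yvals = np.sin(np.arange(100) * 0.1)
integer_shift = shift_or_roll(yvals, 2)       # identical to np.roll(yvals, 2)
fractional_shift = shift_or_roll(yvals, 2.5)  # spline-interpolated
```

The two branches do not share a wrap convention (np.roll wraps with period len(data), while mode='wrap' treats the first and last samples as overlapping), which is part of why this workaround feels unsatisfying.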

Am I missing something obvious here?

Is there a way to make ndimage.shift coincide with np.roll?

Upvotes: 1

Views: 1817

Answers (1)

plasmon360

Reputation: 4199

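I don't think there is anything wrong with the shift function. With mode='wrap', scipy.ndimage.shift treats the first and last samples as the same point, so it wraps with a period of n - 1 samples, whereas np.roll wraps with a period of n. For a fair comparison you need to drop one extra element from the rolled array. Please see the code below.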

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt 


def shiftfunc(data, amt):
    return sciim.shift(data, amt, mode='wrap', order=3)

def rollfunc(data, amt):
    rollshift = np.roll(data, amt)
    # Remove the element that was data[0] before rolling (it sits at index amt
    # afterwards), leaving n - 1 samples to compare against the 'wrap' result
    return np.concatenate((rollshift[:amt], rollshift[amt+1:]))

if __name__ == "__main__":
    shift_by = 5
    xvals = np.linspace(0,2*np.pi,20)
    yvals = np.sin(xvals)
    rollshift   = rollfunc(yvals, shift_by)
    interpshift = shiftfunc(yvals,shift_by)
    plt.plot(xvals, yvals, label = 'original', alpha = 0.5)
    plt.plot(xvals[1:], rollshift, label = 'np.roll', alpha = 0.5,marker='s')
    plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5,marker='o') 
    plt.legend()
    plt.show()

This results in:

[Figure: 'original', 'np.roll' and 'interpolation.shift' curves]
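If your SciPy is 1.6 or newer (an assumption; older releases do not have this mode), scipy.ndimage.shift also accepts mode='grid-wrap', which extends the input as (a b c d | a b c d | a b c d), the same period-n convention np.roll uses. A minimal sketch of using it in place of 'wrap'; `grid_wrap_shift` is a hypothetical helper name:

```python
import numpy as np
import scipy.ndimage as ndi  # mode='grid-wrap' requires SciPy >= 1.6


def grid_wrap_shift(data, amt, order=3):
    """Hypothetical helper: circular spline shift with period len(data).

    Unlike 'wrap', 'grid-wrap' does not treat the first and last samples
    as the same point, so it wraps the way np.roll does.
    """
    return ndi.shift(data, amt, mode='grid-wrap', order=order)


y = np.sin(np.arange(100) * 0.1)
for amt in (1, 2, -3):
    # Integer shifts should reproduce np.roll up to floating-point error
    print(amt, np.allclose(grid_wrap_shift(y, amt), np.roll(y, amt)))

# Non-integer shifts interpolate with the same periodic convention
half = grid_wrap_shift(y, 2.5)
```

With this mode there is no need to chop an element from the rolled array before comparing.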

Upvotes: 1
