Reputation: 79
Background: This is one of the exercise problems in the textbook Hands-On Machine Learning by Aurélien Géron.
The question is: Write a function that can shift an MNIST image in any direction (left, right, up, down) by one pixel. Then for each image in the training set, create four shifted copies (one per direction) and add them to the training set.
My thought process:
My code:
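import numpy as np
from scipy.ndimage import shift  # scipy.ndimage.interpolation is a deprecated namespace

def shift_and_append(X, n):
    # dummy first row so np.append has something to stack onto; dropped at the end
    x_arr = np.zeros((1, 784))
    for i in range(n):
        for j in range(-1, 2):
            for k in range(-1, 2):
                # keeps only (-1, 0), (0, -1), (0, 1), (1, 0): the four one-pixel shifts
                if j != k and j != -k:
                    x_arr = np.append(x_arr, shift(X[i, :].reshape(28, 28), [j, k]).reshape(1, 784), axis=0)
    return np.append(X, x_arr[1:, :], axis=0)

X_train_new = shift_and_append(X_train, X_train.shape[0])
# the four shifted copies of each image are stored consecutively, hence np.repeat
y_train_new = np.append(y_train, np.repeat(y_train, 4), axis=0)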
It takes a long time to run, and it feels like brute force. Is there an efficient, vectorized way to achieve this?
Upvotes: 1
Views: 767
Reputation: 60400
Three nested for loops with an if condition, reshaping and appending on every iteration, are clearly not a good idea; numpy.roll does the job beautifully in a vectorized way:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train.shape
# (60000, 28, 28)
# plot an original image
plt.gray()
plt.matshow(x_train[0])
plt.show()
Let's first demonstrate the operations:
# one pixel down:
x_down = np.roll(x_train[0], 1, axis=0)
plt.gray()
plt.matshow(x_down)
plt.show()
# one pixel up:
x_up = np.roll(x_train[0], -1, axis=0)
plt.gray()
plt.matshow(x_up)
plt.show()
# one pixel left:
x_left = np.roll(x_train[0], -1, axis=1)
plt.gray()
plt.matshow(x_left)
plt.show()
# one pixel right:
x_right = np.roll(x_train[0], 1, axis=1)
plt.gray()
plt.matshow(x_right)
plt.show()
Having established that, we can generate, say, "right" versions of all the training images simply by
x_all_right = [np.roll(x, 1, axis=1) for x in x_train]
and similarly for the other 3 directions.
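For completeness, here is a minimal sketch covering all four directions with the same list-comprehension approach. The `DIRECTIONS` table and the `shift_image` helper are hypothetical names introduced for illustration (they are not part of the answer above), and `x_demo` is a small random stand-in for `x_train` so the snippet runs without downloading MNIST:

```python
import numpy as np

# (shift, axis) for a single 2-D image: axis 0 = rows, axis 1 = columns.
# Hypothetical lookup table mirroring the four np.roll calls shown above.
DIRECTIONS = {
    "down": (1, 0),
    "up": (-1, 0),
    "left": (-1, 1),
    "right": (1, 1),
}

def shift_image(img, direction):
    """Shift one 2-D image by one pixel; np.roll wraps the edge around."""
    step, axis = DIRECTIONS[direction]
    return np.roll(img, step, axis=axis)

# Stand-in data with the same layout as x_train: (n_images, 28, 28)
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# One list of shifted images per direction
shifted = {d: [shift_image(x, d) for x in x_demo] for d in DIRECTIONS}
```

With the real data, replacing `x_demo` by `x_train` should give the same four lists as the individual commands above.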
Let's confirm that the first image in x_all_right is indeed what we want:
plt.gray()
plt.matshow(x_all_right[0])
plt.show()
You can even avoid the last list comprehension in favor of pure NumPy code, as
x_all_right = np.roll(x_train, 1, axis=2)
which is more efficient, although slightly less intuitive: since axis 0 of x_train now indexes the images, just take the respective single-image command and increase axis by 1.
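Putting the pieces together for the original exercise, the sketch below builds the augmented training set in one pass. The function name `augment_with_shifts` and its `zero_wrapped` flag are assumptions made for illustration. Two points go beyond the answer as written: `np.roll` wraps pixels around to the opposite edge (unlike `scipy.ndimage.shift`, which fills with zeros by default), so the sketch optionally blanks the wrapped row or column; and because the shifted copies are concatenated block by block, the labels are repeated with `np.tile` rather than the `np.repeat` used in the question. Random stand-in arrays replace MNIST so the snippet is self-contained:

```python
import numpy as np

def augment_with_shifts(X, y, zero_wrapped=True):
    """Return X plus four one-pixel-shifted copies, and the matching labels.

    X: array of shape (n, height, width); y: array of shape (n,).
    """
    # (shift, axis) for a batch: axis 1 = rows, axis 2 = columns
    moves = [(1, 1), (-1, 1), (-1, 2), (1, 2)]  # down, up, left, right
    copies = [X]
    for step, axis in moves:
        shifted = np.roll(X, step, axis=axis)  # new array, X is untouched
        if zero_wrapped:
            # Blank the single row/column that np.roll wrapped around
            edge = 0 if step > 0 else -1
            index = [slice(None)] * X.ndim
            index[axis] = edge
            shifted[tuple(index)] = 0
        copies.append(shifted)
    X_aug = np.concatenate(copies, axis=0)
    # Layout is [originals, down, up, left, right], so the labels repeat as a block
    y_aug = np.tile(y, len(copies))
    return X_aug, y_aug

# Stand-in data shaped like a tiny MNIST subset
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8)
y_demo = rng.integers(0, 10, size=6)

X_aug, y_aug = augment_with_shifts(X_demo, y_demo)
```

If the images are stored flat as in the question, with shape (n, 784), reshaping with `X.reshape(-1, 28, 28)` before the call and `X_aug.reshape(len(X_aug), -1)` afterwards should restore the expected layout.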
Upvotes: 2