I'm trying to write some logic to return an array shifted one step to the right, with wrap around. I was relying on receiving an IndexError to implement the wrap-around, but no error is thrown!
    import numpy as np

    def get_batches(arr, batch_size, seq_length):
        """
        Return arr data as batches of shape (batch_size, seq_length)
        """
        n_chars = batch_size * seq_length
        n_batches = int(np.floor(len(arr) / n_chars))
        n_keep = n_chars * n_batches
        arr = arr[:n_keep].reshape(batch_size, -1)
        for b in range(n_batches):
            start = b * seq_length
            stop = start + seq_length
            x = arr[:, start:stop]
            try:
                y = arr[:, start + 1: stop + 1]
            except IndexError:
                y = np.concatenate(x[:, 1:], arr[:, 0], axis=1)
            yield x, y
So this code works great, except when the last y array is yielded... I get a (2, 2) array instead of the expected (2, 3). That's because an IndexError is never thrown.
    test = np.arange(12)
    batches = get_batches(test, 2, 3)
    for x, y in batches:
        print('x=', x)
        print('y=', y, '\n')
yields
    x=
    [[0 1 2]
     [6 7 8]]
    y= # as expected
    [[1 2 3]
     [7 8 9]]
    x=
    [[ 3  4  5]
     [ 9 10 11]]
    y= # truncated :(
    [[ 4  5]
     [10 11]]
Does anyone have an alternative suggestion about how to get this done? Preferably something as simple as my failed solution?
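For context on why the except branch never runs: as far as I know, NumPy slicing (like slicing a Python list) silently clips out-of-range bounds, and only an out-of-range integer index raises. A minimal sketch of that behaviour, using the same 2x6 array as the example above (the names `clipped` and `raised` are just for illustration):

```python
import numpy as np

arr = np.arange(12).reshape(2, 6)

# A slice that runs past the last column is clipped, not rejected:
# only columns 4 and 5 exist, so the result has 2 columns instead of 3.
clipped = arr[:, 4:7]
print(clipped.shape)  # (2, 2)

# An out-of-range integer index, by contrast, does raise IndexError.
try:
    arr[:, 6]
    raised = False
except IndexError:
    raised = True
print(raised)  # True
```

So a try/except around a slice cannot detect the end of the array; the length of the result has to be checked, or the wrap-around built in before slicing.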
Try this:
    import numpy as np
    from skimage.util.shape import view_as_windows

    def get_batches2(arr, batch_size, seq_length):
        """
        Return arr data as stacked batches, each of shape (batch_size, seq_length)
        """
        n_chars = batch_size * seq_length
        n_batches = int(np.floor(len(arr) / n_chars))
        n_keep = n_chars * n_batches
        arr = arr[:n_keep].reshape(batch_size, -1)
        # Non-overlapping windows: the step equals the window width (seq_length).
        x = view_as_windows(arr, (batch_size, seq_length), seq_length)[0]
        # Roll each row left by one so the targets wrap around, then window the same way.
        y = view_as_windows(np.roll(arr, -1, axis=1), (batch_size, seq_length), seq_length)[0]
        return x, y
view_as_windows returns a view that shares memory with its input (you can check this with np.shares_memory), so it does not matter whether you yield the batches in a loop or return them all at once. If memory is the concern, the windowing itself uses no extra memory, especially since your windows do not overlap; note that np.roll does make one shifted copy of arr for y. It should also be much faster than a generator. You can probably even achieve this by simple reshaping, without view_as_windows.
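The reshaping idea mentioned above might look something like the following. This is only a sketch using plain NumPy: `get_batches_reshape` and its inner helper `split` are hypothetical names, not part of any library, and it assumes the same per-row wrap-around via np.roll as get_batches2.

```python
import numpy as np

def get_batches_reshape(arr, batch_size, seq_length):
    """Hypothetical reshape-only variant.

    Returns x, y, each of shape (n_batches, batch_size, seq_length).
    """
    n_chars = batch_size * seq_length
    n_batches = len(arr) // n_chars
    arr = arr[:n_chars * n_batches].reshape(batch_size, -1)
    # Targets: each row shifted left by one with wrap-around (np.roll returns a copy).
    shifted = np.roll(arr, -1, axis=1)

    def split(a):
        # (batch_size, n_batches * seq_length) -> (n_batches, batch_size, seq_length)
        return a.reshape(batch_size, n_batches, seq_length).swapaxes(0, 1)

    return split(arr), split(shifted)

data = np.arange(12)
x, y = get_batches_reshape(data, 2, 3)
for xb, yb in zip(x, y):
    print('x=', xb)
    print('y=', yb, '\n')

# x is a view on the original data; y is a view on the rolled copy.
print(np.shares_memory(x, data))  # True
print(np.shares_memory(y, data))  # False
```

With the 12-element example, the last y batch comes out as [[4, 5, 0], [10, 11, 6]], so each row wraps back to its own first element. Iterating over the first axis gives the same (x, y) pairs the generator was meant to yield.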