I am writing a data pipeline that feeds batches of time-series sequences and their labels into an LSTM model, which requires 3D input of shape (batch, timesteps, features). I currently have the following:
import tensorflow as tf

def split(window):
    # Inputs: every step except the last label_length; label: the step at -label_length
    return window[:-label_length], window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
The shape reported by
for x, y in dataset.take(1): print(x.shape)
is (32, 20), where 32 is the batch size and 20 is the sequence length. However, I need a shape of (32, 20, 1), where the additional dimension is the feature axis.
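To illustrate why the trailing axis matters, here is a minimal sketch using a dummy zero tensor in place of a real batch (the LSTM width of 8 is an arbitrary, hypothetical choice). A Keras LSTM layer expects rank-3 input (batch, timesteps, features), so a (32, 20) batch needs a feature axis of size 1 before it can be passed in.

```python
import tensorflow as tf

# Dummy batch shaped like the question's x: 32 sequences of 20 steps, no feature axis.
x = tf.zeros((32, 20))

# Append a trailing feature axis of size 1 -> (batch, timesteps, features).
x3d = tf.expand_dims(x, axis=-1)
print(x3d.shape)  # (32, 20, 1)

# Hypothetical LSTM with 8 units; with return_sequences=False it returns (batch, units).
lstm = tf.keras.layers.LSTM(8)
out = lstm(x3d)
print(out.shape)  # (32, 8)
```

Passing the rank-2 tensor x directly to the layer would instead fail with a rank (ndim) mismatch error.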
How can I add this dimension, ideally inside the split function that is passed to dataset.map, so that the reshape happens before the data is cached?
That's easy. Add the extra axis with tf.newaxis in your split function:
def split(window):
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]
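The answer's split function can be checked end to end with a sketch like the following. It assumes a synthetic sine series in place of data.sin and hypothetical parameter values (input_length=20, label_length=1, label_shift=1, batch_size=32, shuffle_buffer=1000, shuffle_seed=42), none of which are given in the question. Slicing with a trailing tf.newaxis turns each input window from (20,) into (20, 1); the label, indexed down to a scalar and then given two new axes, becomes (1, 1) per sample.

```python
import numpy as np
import tensorflow as tf

# Hypothetical parameter values; the question does not state them.
input_length = 20
label_length = 1
label_shift = 1
batch_size = 32
shuffle_buffer = 1000
shuffle_seed = 42

# Hypothetical stand-in for data.sin: a 1-D float32 sine series.
series = np.sin(np.linspace(0.0, 100.0, 1000)).astype(np.float32)

def split(window):
    # Slice + tf.newaxis: (input_length,) -> (input_length, 1).
    # Scalar index + two tf.newaxis: () -> (1, 1).
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]

dataset = tf.data.Dataset.from_tensor_slices(series)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()  # the reshape happens before caching
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

x, y = next(iter(dataset))
print(x.shape, y.shape)  # x: (32, 20, 1), y: (32, 1, 1)
```

If a (32, 1) label batch is preferred, for example to match a Dense(1) output, a single tf.newaxis on the label (window[-label_length, tf.newaxis]) should be enough; tf.expand_dims(window[:-label_length], -1) is an equivalent way to add the feature axis to the inputs.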