Matt

Reputation: 7254

How to reshape data in Tensorflow dataset?

I am writing a data pipeline to feed batches of time-series sequences and corresponding labels into an LSTM model which requires a 3D input shape. I currently have the following:

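import tensorflow as tf

def split(window):
    # Inputs: every value except the last label_length; label: the single value at index -label_length
    return window[:-label_length], window[-label_length]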

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

Inspecting the first batch with for x, y in dataset.take(1): print(x.shape) gives (32, 20), where 32 is the batch size and 20 is the sequence length. However, I need a shape of (32, 20, 1), where the additional dimension denotes the feature.
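For anyone who wants to reproduce the shapes, here is a minimal self-contained sketch of the same pipeline. The values of input_length, label_length, label_shift, batch_size, shuffle_buffer and shuffle_seed are assumptions chosen to match the shapes described above, and the synthetic sine array sin is a hypothetical stand-in for data.sin.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the variables the question leaves undefined.
input_length = 20
label_length = 1
label_shift = 1
batch_size = 32
shuffle_buffer = 1000
shuffle_seed = 42
sin = np.sin(np.linspace(0, 50, 1000)).astype("float32")  # replaces data.sin

def split(window):
    # Inputs: first input_length values; label: the single value at index -label_length.
    return window[:-label_length], window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for x, y in dataset.take(1):
    # Each window is 1-D, so batching only adds the batch axis: x is (32, 20), y is (32,).
    print(x.shape, y.shape)
```

The inputs come out 2-D because each window is a flat 1-D tensor of scalars, so there is no feature axis for batch() to preserve.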

How can I reshape the data, ideally inside the split function that is passed to dataset.map, so that the reshaping happens before the data is cached?

Upvotes: 1

Views: 859

Answers (1)

thushv89

Reputation: 11333

That's easy. Add the extra axis with tf.newaxis in your split function:

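def split(window):
    # tf.newaxis appends a length-1 feature axis: inputs go from (input_length,) to (input_length, 1);
    # the scalar label gets two new axes and becomes shape (1, 1)
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]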
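To check the resulting shapes, here is a minimal self-contained sketch that plugs this split function into a shortened version of the question's pipeline. The parameter values and the synthetic sine array are assumptions, not taken from the question, and shuffling and prefetching are left out because they do not affect shapes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical values matching the shapes in the question.
input_length = 20
label_length = 1
batch_size = 32
sin = np.sin(np.linspace(0, 50, 1000)).astype("float32")  # stand-in for data.sin

def split(window):
    # Adds a trailing feature axis to the inputs and two axes to the scalar label.
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]

dataset = tf.data.Dataset.from_tensor_slices(sin)
dataset = dataset.window(input_length + label_length, shift=1, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda w: w.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()  # cached elements already carry the extra axes
dataset = dataset.batch(batch_size, drop_remainder=True)

x, y = next(iter(dataset))
# x is now (batch, timesteps, features) = (32, 20, 1); y is (32, 1, 1).
print(x.shape, y.shape)
```

tf.expand_dims(window[:-label_length], axis=-1) would give the same (input_length, 1) input shape. If the model ends in a single-unit Dense layer and you want labels of shape (32, 1) instead of (32, 1, 1), slicing the label as window[-label_length:] keeps a length-1 axis without adding a second one.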

Upvotes: 1
