Reputation: 683
I have a dataset of variable-length sequences (a TensorFlow TFRecord dataset) to feed an LSTM network, and I want to try and compare pre- and post-padding in the batches, but the current padded_batch function only pads at the end of the sequences. I know we have the tf.keras.preprocessing.sequence.pad_sequences
function in the API, but I don't know how to apply it inside the dataset batching pipeline. The padded_batch function in TensorFlow does both padding and batching, and it finds the required padding size per batch dynamically. How can I implement this myself? My current code looks like this; I am reading multiple TFRecord files and interleaving them to make my mixed dataset:
featuresDict = {'data': tf.FixedLenFeature([], dtype=tf.string),
                'rows': tf.FixedLenFeature([], dtype=tf.int64),
                'label': tf.FixedLenFeature([], dtype=tf.int64)
               }

def parse_tfrecord(example):
    features = tf.parse_single_example(example, featuresDict)
    label = tf.one_hot(features['label'], N)
    rows = features['rows']
    data = tf.decode_raw(features['data'], tf.int64)
    data = tf.reshape(data, (rows, num_features))
    return data, label

def read_datasets(pattern, numFiles, numEpochs=None, batchSize=None):
    files = tf.data.Dataset.list_files(pattern)

    def _parse(x):
        x = tf.data.TFRecordDataset(x, compression_type='GZIP')
        return x

    dataset = files.interleave(_parse, cycle_length=numFiles, block_length=1).map(parse_tfrecord)
    padded_shapes = (tf.TensorShape([None, num_features]), tf.TensorShape([N]))
    dataset = dataset.padded_batch(batchSize, padded_shapes)
    dataset = dataset.prefetch(buffer_size=batchSize)
    dataset = dataset.repeat(numEpochs)
    return dataset
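For reference, this is roughly how padded_batch's default (post-)padding behaves on a toy dataset, padding each batch to its own longest sequence (a minimal TF 2.x eager-style sketch, not my actual pipeline):

import tensorflow as tf

# Toy illustration: padded_batch pads at the END, up to the longest sequence in each batch
toy = tf.data.Dataset.from_generator(lambda: iter([[1, 2], [3, 4, 5], [6]]),
                                     output_types=tf.int64,
                                     output_shapes=tf.TensorShape([None]))
for batch in toy.padded_batch(3, padded_shapes=[None]):
    print(batch.numpy())
# [[1 2 0]
#  [3 4 5]
#  [6 0 0]]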
Upvotes: 2
Views: 1358
Reputation: 21
I had the same problem as you, and I also noticed that you raised an issue on the TensorFlow repository. I managed to solve it with pad_sequences
, and I think it works correctly!
import numpy as np
import tensorflow as tf

# Code snippet from https://www.tensorflow.org/guide/data
# Generator that generates the data
def gen_series():
    i = 0
    np.random.seed(0)
    while True:
        size = np.random.randint(0, 10)
        yield np.random.normal(size=(size,))
        i += 1

# Transform the generator into a Dataset
ds_series = tf.data.Dataset.from_generator(gen_series,
                                           output_types=(tf.float32),
                                           output_shapes=((None,)))
# output_shapes is (None,) because the vector has unknown size

# Take the first 5 samples
print("Before padding")
for vector in ds_series.take(5):
    print(vector)

print("*" * 10)

# Start the transformation
def pad_session(session):
    """
    Pre-pad the sequence with maxlen=5.
    If a vector is longer than 5, truncate it from the front.
    """
    return tf.keras.preprocessing.sequence.pad_sequences(
        [session.numpy()],
        maxlen=5,
        truncating="pre",
        padding='pre',
        value=0.0,
        dtype=np.float32).squeeze()

def pad_map_fn(session):
    return tf.py_function(pad_session, inp=[session], Tout=(tf.float32))

padded_dataset = ds_series.map(pad_map_fn)

print("After padding")
for pre_padded_vector in padded_dataset.take(5):
    print(pre_padded_vector)
and it will generate the following output:
Before padding
tf.Tensor([ 0.11849646 0.1139678 0.37025538 1.0405308 -1.5169828 ], shape=(5,), dtype=float32)
tf.Tensor(
[-0.8662762 -0.10321885 0.41059852 0.14404356 1.4542735 0.7610377
0.12167501 0.44386324], shape=(8,), dtype=float32)
tf.Tensor([0.33367434 1.4143772 ], shape=(2,), dtype=float32)
tf.Tensor([-0.12405066 1.1682731 0.94718593], shape=(3,), dtype=float32)
tf.Tensor([-2.5529897], shape=(1,), dtype=float32)
**********
After padding
tf.Tensor([ 0.11849646 0.1139678 0.37025538 1.0405308 -1.5169828 ], shape=(5,), dtype=float32)
tf.Tensor([0.14404356 1.4542735 0.7610377 0.12167501 0.44386324], shape=(5,), dtype=float32)
tf.Tensor([0. 0. 0. 0.33367434 1.4143772 ], shape=(5,), dtype=float32)
tf.Tensor([ 0. 0. -0.12405066 1.1682731 0.94718593], shape=(5,), dtype=float32)
tf.Tensor([ 0. 0. 0. 0. -2.5529897], shape=(5,), dtype=float32)
I want to draw your attention to pad_session
: we pass [session.numpy()]
into pad_sequences
because it expects a 2-D array.
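For example, calling pad_sequences directly on a plain Python list of lists (outside of tf.data) shows both the 2-D input it expects and the pre-padding / pre-truncating behaviour:

from tensorflow.keras.preprocessing.sequence import pad_sequences

# A list of sequences (2-D structure); maxlen=3 pre-pads the short one
# and pre-truncates the long one
print(pad_sequences([[1, 2], [3, 4, 5, 6]], maxlen=3, padding='pre', truncating='pre'))
# [[0 1 2]
#  [4 5 6]]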
Maybe there is a better way to solve it, but this is the answer I came up with.
Hope it will help you!!
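If you also want to keep the dynamic per-batch padding that padded_batch gives (instead of a fixed maxlen), one possible sketch, reusing the N and num_features from your question, is to reverse each sequence, let padded_batch pad at the end, and then reverse back after batching:

def pre_padded_batch(dataset, batch_size):
    # Reverse each sequence along the time axis so that padded_batch's
    # end-padding ends up at the front once we reverse back
    dataset = dataset.map(lambda data, label: (tf.reverse(data, axis=[0]), label))
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=(tf.TensorShape([None, num_features]),
                                                  tf.TensorShape([N])))
    # After batching, the time dimension is axis 1
    dataset = dataset.map(lambda data, label: (tf.reverse(data, axis=[1]), label))
    return dataset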
Upvotes: 1