had

Reputation: 327

How to pad to the maximum sequence length within each batch with the TensorFlow Dataset API?

For example, I have these sequences:

123
1234
12345
1234556
1234567890

It is easy to pad globally, to the longest sequence in the whole dataset, like this:

0000000123
0000001234
0000012345
0001234556
1234567890

But I want to pad within each batch generated by the Dataset API. For example, with a batch size of 3 it takes 3 random samples:

123
1234
12345

and pads them to the longest sequence in that batch, like this:

00123
01234
12345

I can do this in NumPy, but this is how the batches are constructed with the tf.data API:

data = tf.data.Dataset.from_tensor_slices((X, y))
data = data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=len(y)))
data = data.batch(batch_size, drop_remainder=False)
data = data.prefetch(2)

Upvotes: 1

Views: 2284

Answers (3)

Alex_Y

Reputation: 608

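If I understand correctly, you can use Keras pad_sequences:

import tensorflow as tf

# Ragged input as a plain list of lists (np.array on ragged rows raises an error in recent NumPy versions)
sequence = [[1, 2], [1, 2, 3, 4], [1, 2, 3, 4, 5, 6]]

tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='pre', value=0)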

array([[0, 0, 0, 0, 1, 2],  
       [0, 0, 1, 2, 3, 4],  
       [1, 2, 3, 4, 5, 6]])

Upvotes: 0

tlitfin

Reputation: 96

You can use the padded_batch method.

data.padded_batch(batch_size, padded_shapes=max_shape)

where max_shape is the shape of the padded tensors you want. A dimension given as None is padded to the longest element in each batch.

This appends trailing zeros rather than leading zeros, but it is probably still suitable for your purpose.

EDIT

Complete working example (TensorFlow 1.x, since it uses tf.contrib and tf.Session):

import tensorflow as tf
import numpy as np

def gen():
    yield (np.array([1,2,3]), np.array(1))
    yield (np.array([1,2,3,4]), np.array(0))

data = tf.data.Dataset.from_generator(gen, output_types=(tf.int32, tf.int32))
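data = data.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=2))
# [None] pads each sequence to the longest one in its batch; [] leaves the scalar label as is
data = data.padded_batch(10, padded_shapes=([None], []))
# Graph mode: build a reinitializable iterator and fetch batches inside a session
iterator = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
batch = iterator.get_next()
init_op = iterator.make_initializer(data)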

with tf.Session() as sess:
    sess.run(init_op)
    batch_out = sess.run(batch)
    print(batch_out)
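
For TensorFlow 2.x, where tf.contrib and tf.Session are no longer available, a rough equivalent might look like the sketch below. It also shows one possible way to get leading rather than trailing zeros: reverse each sequence, let padded_batch pad at the end, then reverse each padded row back. The sample sequences, labels, and the batch size of 3 are assumptions chosen for illustration, and shuffling is left out so the result is deterministic.

```python
import tensorflow as tf

# Illustrative data (assumed, not from the question): variable-length sequences with scalar labels
sequences = [[1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
labels = [1, 0, 1]

def gen():
    for seq, label in zip(sequences, labels):
        yield seq, label

data = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None,), dtype=tf.int32),  # variable-length sequence
        tf.TensorSpec(shape=(), dtype=tf.int32),       # scalar label
    ),
)

# A data.shuffle(buffer_size) call would normally go here, before batching.

# Reverse each sequence so that the trailing padding ends up in front after flipping back
data = data.map(lambda x, y: (tf.reverse(x, axis=[0]), y))
# [None] pads to the longest sequence in each batch; padded_batch pads at the end with zeros
data = data.padded_batch(3, padded_shapes=([None], []))
# Flip every padded row back, turning trailing zeros into leading zeros
data = data.map(lambda xb, yb: (tf.reverse(xb, axis=[1]), yb))

first_x, first_y = next(iter(data))
left_padded = first_x.numpy().tolist()
batch_labels = first_y.numpy().tolist()
print(left_padded)
```

If trailing zeros are acceptable, the two map calls can simply be dropped.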


Upvotes: 3

Vinicius Cainelli

Reputation: 847

If I understood correctly, you could do:

data = """123
1234
12345"""

lines = data.splitlines()
# Width of the longest line in this block
max_len = max(len(line) for line in lines)

# Left-pad every line with zeros to that width
padded = "\n".join(line.rjust(max_len, "0") for line in lines)

print(padded)

Output:

00123
01234
12345
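
The same idea can be applied per batch to lists of integers rather than to one block of text. This is a minimal pure-Python sketch, not a tf.data solution; `left_pad_batches` is a hypothetical helper name, and the input sequences are made up for illustration.

```python
def left_pad_batches(sequences, batch_size, pad_value=0):
    """Yield batches whose rows are left-padded to the longest row in that batch."""
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        max_len = max(len(seq) for seq in batch)
        # Prepend as many pad values as each row is short of the batch maximum
        yield [[pad_value] * (max_len - len(seq)) + list(seq) for seq in batch]

# Illustrative input (assumed): five sequences, batched three at a time
sequences = [[1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], [1, 2], [1]]
batches = list(left_pad_batches(sequences, batch_size=3))
for batch in batches:
    print(batch)
```

Each batch is padded only to its own longest row, so the second batch here is two columns wide while the first is five.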

Upvotes: 0
