Reputation: 327
For example, I have these sequences:
123
1234
12345
1234556
1234567890
It is easy to pad them globally, to the length of the longest sequence in the whole dataset:
0000000123
0000001234
0000012345
0001234556
1234567890
But I want the padding applied separately within each batch generated by the Dataset API. For example, with a batch size of 3 it takes 3 random samples:
123
1234
12345
And pads them only to the longest sequence in that batch:
00123
01234
12345
I can do this in NumPy, for example, but this is how the batches are constructed with the tf.data API:
data = tf.data.Dataset.from_tensor_slices((X, y))
data = data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=len(y)))
data = data.batch(batch_size, drop_remainder=False)
data = data.prefetch(2)
Upvotes: 1
Views: 2284
Reputation: 608
If I understand correctly, you can use Keras pad_sequences:
import tensorflow as tf
# Pass a plain list of lists: recent NumPy versions reject ragged input to np.array.
sequence = [[1, 2], [1, 2, 3, 4], [1, 2, 3, 4, 5, 6]]
tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='pre', value=0)
array([[0, 0, 0, 0, 1, 2],
       [0, 0, 1, 2, 3, 4],
       [1, 2, 3, 4, 5, 6]])
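To get the per-batch behaviour the question asks for, the same left-padding can be applied to one batch at a time instead of the whole dataset. Below is a minimal NumPy-only sketch of that idea; the helper name pad_batch_left is hypothetical and not part of Keras or TensorFlow.

```python
import numpy as np

def pad_batch_left(batch, value=0):
    """Left-pad every sequence in `batch` to the longest length found in that batch.

    `pad_batch_left` is a hypothetical helper written for illustration only.
    """
    max_len = max(len(seq) for seq in batch)
    # Start from a matrix filled with the padding value ...
    out = np.full((len(batch), max_len), value, dtype=np.int64)
    # ... and copy each sequence into the right-hand end of its row.
    for row, seq in enumerate(batch):
        out[row, max_len - len(seq):] = seq
    return out

batch = [[1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
padded = pad_batch_left(batch)
print(padded)
```

Because the target length is computed from the batch itself, a batch of short sequences stays short rather than being padded to the dataset-wide maximum.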
Upvotes: 0
Reputation: 96
You can use the padded_batch method:
data = data.padded_batch(batch_size, padded_shapes=max_shape)
where max_shape is the shape each element should be padded to. A dimension given as None is padded to the longest element in that batch, which is the per-batch behaviour you describe.
Note that this appends trailing zeros rather than leading zeros, but it is probably still suitable for your purpose.
EDIT
Complete working example:
import tensorflow as tf  # written against the TF 1.x graph API (tf.contrib, tf.Session)
import numpy as np
def gen():
    # Two variable-length sequences, each paired with a scalar label.
    yield (np.array([1, 2, 3]), np.array(1))
    yield (np.array([1, 2, 3, 4]), np.array(0))
data = tf.data.Dataset.from_generator(gen, output_types=(tf.int32, tf.int32))
data = data.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=2))
# [None] pads each sequence to the longest one in its batch; [] leaves the scalar label as is.
data = data.padded_batch(10, padded_shapes=([None], []))
iterator = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
batch = iterator.get_next()
init_op = iterator.make_initializer(data)
with tf.Session() as sess:
    sess.run(init_op)
    batch_out = sess.run(batch)
    print(batch_out)
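If leading zeros are required, one possible workaround is to reverse each sequence, let padded_batch append the zeros, and then reverse the batch back. The following is a sketch under stated assumptions rather than a definitive recipe: it assumes a recent TensorFlow 2.x release with eager execution and the output_signature argument of from_generator, and the toy sequences and labels are made up for illustration.

```python
import tensorflow as tf

# Toy data for illustration: three variable-length sequences with scalar labels.
sequences = [[1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
labels = [1, 0, 1]

def gen():
    for seq, label in zip(sequences, labels):
        yield seq, label

data = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None,), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)
# Reverse each sequence so the trailing padding ends up in front after the second reverse.
data = data.map(lambda x, y: (tf.reverse(x, axis=[0]), y))
# [None] pads to the longest sequence in each batch; [] leaves the scalar label alone.
data = data.padded_batch(3, padded_shapes=([None], []))
# Reverse along the sequence axis again to restore the original element order.
data = data.map(lambda x, y: (tf.reverse(x, axis=[1]), y))

batch_x, batch_y = next(iter(data))
print(batch_x.numpy())
```

A shuffle step can be added before padded_batch as in the original pipeline; it is left out here so that the single batch is deterministic.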
Upvotes: 3
Reputation: 847
If I understood correctly, you could do:
import os
data = """123
1234
12345"""
lines = data.splitlines()
max_len = max((len(i) for i in lines))
lines = (i.rjust(max_len, '0') for i in lines)
data = os.linesep.join(lines)
print(data)
Output:
00123
01234
12345
Upvotes: 0