bluesummers

Reputation: 12657

tf.data.Dataset.padded_batch: pad each feature differently

I have a tf.data.Dataset instance that holds 3 different features.

I am trying to use tf.data.Dataset.padded_batch() to generate padded data as input to my model, and I want to pad every feature differently.

Example batch:

[{'label': 24,
  'sequence_feature': [1, 2],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66]]}]

Expected output:

[{'label': 24,
  'sequence_feature': [1, 2, 0],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66],
                          [0.0, 0.0]]}]

As you can see, the label feature should not be padded, while sequence_feature and seq_of_seqs_feature should each be padded to the size of the longest corresponding entry in the given batch.

Upvotes: 13

Views: 9155

Answers (1)

mrry

Reputation: 126194

The tf.data.Dataset.padded_batch() method allows you to specify padded_shapes for each component (feature) of the resulting batch. For example, if your input dataset is called ds:

padded_ds = ds.padded_batch(
    BATCH_SIZE,
    padded_shapes={
        'label': [],                          # Scalar elements, no padding.
        'sequence_feature': [None],           # Vector elements, padded to longest.
        'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest
    })                                        # in each dimension.

Notice that the padded_shapes argument has the same structure as your input dataset's elements, so in this case it takes a dictionary with keys that match your feature names.
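To check this against the example batch from the question, here is a minimal end-to-end sketch. It assumes TensorFlow 2.4 or later (for the output_signature argument of from_generator); the names examples and gen, the dtypes, and the batch size of 2 are illustrative choices, not part of the original question.

```python
import tensorflow as tf

# The two examples from the question (illustrative data).
examples = [
    {'label': 24,
     'sequence_feature': [1, 2],
     'seq_of_seqs_feature': [[11.1, 22.2], [33.3, 44.4]]},
    {'label': 32,
     'sequence_feature': [3, 4, 5],
     'seq_of_seqs_feature': [[55.55, 66.66]]},
]


def gen():
    # Hypothetical generator standing in for the real input pipeline.
    for example in examples:
        yield example


# Variable-length dimensions are declared as None in the signature.
ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'label': tf.TensorSpec(shape=(), dtype=tf.int32),
        'sequence_feature': tf.TensorSpec(shape=(None,), dtype=tf.int32),
        'seq_of_seqs_feature': tf.TensorSpec(shape=(None, None),
                                             dtype=tf.float32),
    })

padded_ds = ds.padded_batch(
    2,  # Batch size chosen to match the two-example batch above.
    padded_shapes={
        'label': [],                          # Scalar, no padding.
        'sequence_feature': [None],           # Pad to longest in batch.
        'seq_of_seqs_feature': [None, None],  # Pad each dimension.
    })

batch = next(iter(padded_ds))
print(batch['label'].numpy())               # shape (2,)
print(batch['sequence_feature'].numpy())    # shape (2, 3)
print(batch['seq_of_seqs_feature'].shape)   # (2, 2, 2)
```

With this input, the shorter sequence_feature gains one trailing 0, and the one-row seq_of_seqs_feature gains a second row of zeros, since the default padding value for numeric types is 0.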

Upvotes: 21
