TimZaman

Reputation: 2707

Take from a tf.data.Dataset with a predicate (like filter)

TensorFlow's excellent Dataset abstraction supports filtering with a predicate:

filter(predicate): Filters this dataset according to predicate.

Args: predicate: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.bool tensor.

This is very powerful, as the predicate allows you to filter on the dataset's contents.
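For reference, here is a minimal sketch of filter, assuming TensorFlow 2.x with eager execution (the question itself predates TF 2). The predicate receives each element and returns a scalar tf.bool tensor, so the decision depends on the element's value.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# A dataset of the integers 0..9 (tf.int64 scalars).
dataset = tf.data.Dataset.range(10)

# The predicate inspects each element's contents and returns a scalar tf.bool.
evens = dataset.filter(lambda x: tf.equal(x % 2, 0))

# In eager mode the dataset can be iterated directly.
evens_list = [int(v) for v in evens.as_numpy_iterator()]
print(evens_list)
```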

The question is: is it possible to do the 'opposite' of filtering, i.e. oversampling elements based on their contents?

It does not seem possible with take(), as take() does not depend on the dataset's contents:

take(count): Creates a Dataset with at most count elements from this dataset.

Args: count: A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be taken to form the new dataset. If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
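A quick illustration of that contrast, again assuming TensorFlow 2.x: take() selects elements purely by position, so the values in the dataset never influence which elements survive.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

dataset = tf.data.Dataset.range(10)

# take(count) keeps the first `count` elements, whatever their values are.
first_three = [int(v) for v in dataset.take(3).as_numpy_iterator()]

# count=-1 (or a count larger than the dataset) keeps every element.
everything = [int(v) for v in dataset.take(-1).as_numpy_iterator()]
too_many = [int(v) for v in dataset.take(100).as_numpy_iterator()]
```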

Upvotes: 0

Views: 1550

Answers (1)

mirandaconrado

Reputation: 36

TensorFlow doesn't currently expose such functionality directly, but you can achieve the result you want with flat_map. For each element of the input dataset, you create a new single-element dataset (tf.data.Dataset.from_tensors) and repeat it as many times as needed (.repeat); flat_map then concatenates these per-element datasets in order.

For example:

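import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

def run(dataset):
    """Drain a dataset and return all of its elements as a list."""
    el = dataset.make_one_shot_iterator().get_next()
    vals = []
    with tf.Session() as sess:
        try:
            while True:
                vals.append(sess.run(el))
        except tf.errors.OutOfRangeError:
            # Raised once the iterator is exhausted.
            pass

    return vals

# Each element is a (value, number_of_repeats) pair.
dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
print('Original dataset with repeats')
print(run(dataset))

# Replace each (v, r) pair with a dataset that yields v exactly r times,
# then concatenate those datasets in input order.
dataset = dataset.flat_map(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))
print('Repeats flattened')
print(run(dataset))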

will print

Original dataset with repeats
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
Repeats flattened
[1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
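If you are on TensorFlow 2.x (an assumption; the code above uses the 1.x Session API), the same flat_map pipeline can be sketched without a Session, since datasets are directly iterable in eager mode. The arrays are given an explicit int64 dtype because .repeat expects a tf.int64 count.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

values = np.array([1, 2, 3, 4, 5], dtype=np.int64)
repeats = np.array([5, 4, 3, 2, 1], dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((values, repeats))

# Each (v, r) pair becomes a tiny dataset that yields v exactly r times;
# flat_map concatenates these sub-datasets in input order.
oversampled = dataset.flat_map(
    lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))

# No Session or iterator boilerplate is needed in eager mode.
result = [int(v) for v in oversampled.as_numpy_iterator()]
print(result)
```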

Alternatively, you can use .interleave to produce the same copies while mixing copies of different samples (.flat_map is a special case of .interleave with cycle_length=1). For instance:

dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
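# Cycle through 4 input elements at a time, taking 1 copy from each per turn.
dataset = dataset.interleave(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r), cycle_length=4, block_length=1)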
print('Repeats flattened with a little bit of deterministic mixing')
print(run(dataset))

will print

Repeats flattened with a little bit of deterministic mixing
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 1, 2, 5, 1]
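Because .repeat accepts any scalar tf.int64 tensor, the repeat count can also be computed from the element itself, which gives a predicate-driven counterpart to filter. The sketch below assumes TensorFlow 2.x; oversample is a hypothetical helper name for illustration, not part of the tf.data API.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def oversample(dataset, predicate, factor):
    """Hypothetical helper (not a tf.data API): repeat each element
    `factor` times when predicate(element) is True, and once otherwise."""
    def expand(x):
        # The repeat count is derived from the element's own contents.
        count = tf.where(predicate(x),
                         tf.constant(factor, tf.int64),
                         tf.constant(1, tf.int64))
        return tf.data.Dataset.from_tensors(x).repeat(count)
    return dataset.flat_map(expand)

dataset = tf.data.Dataset.range(6)
# Triple every multiple of 3; leave other elements untouched.
boosted = oversample(dataset, lambda x: tf.equal(x % 3, 0), factor=3)
result = [int(v) for v in boosted.as_numpy_iterator()]
print(result)
```

Passing factor=0 would drop matching elements instead (repeat(0) yields an empty dataset). Following the pipeline with .shuffle(...) or using .interleave as above is one way to spread the copies out rather than keeping them adjacent.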

Upvotes: 2
