Reputation: 2707
TensorFlow's excellent Dataset abstraction supports filtering with a predicate:

filter(predicate)
Filters this dataset according to predicate.
Args: predicate: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.bool tensor.

This is very powerful, as the predicate allows you to filter on dataset contents.
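For illustration, here is a minimal sketch of a predicate that looks at element values, assuming TensorFlow 2.x with eager execution so that as_numpy_iterator can be used to inspect the result. The toy range dataset and the even-number predicate are made up for the example:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Toy dataset 0..9 (hypothetical example data).
dataset = tf.data.Dataset.range(10)

# The predicate inspects each element's value and returns a scalar tf.bool.
evens = dataset.filter(lambda x: tf.equal(x % 2, 0))

result = [int(x) for x in evens.as_numpy_iterator()]
print(result)
```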
The question is: is it possible to have the 'opposite' of filtering, e.g. oversampling?
It does not seem possible with take(), as that does not depend on dataset contents:

take(count)
Creates a Dataset with at most count elements from this dataset.
Args: count: A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be taken to form the new dataset. If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
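To make the limitation concrete, a small sketch (again assuming TensorFlow 2.x eager execution and a made-up range dataset): take() selects by position only, and can never emit more elements than the input has, so it cannot duplicate anything.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

dataset = tf.data.Dataset.range(10)

# take() keeps the first `count` elements, regardless of their values.
first_three = [int(x) for x in dataset.take(3).as_numpy_iterator()]

# count=-1 (or any count larger than the dataset) keeps everything, but never more.
everything = [int(x) for x in dataset.take(-1).as_numpy_iterator()]
too_many = [int(x) for x in dataset.take(20).as_numpy_iterator()]
```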
Upvotes: 0
Views: 1550
Reputation: 36
TensorFlow doesn't currently expose such functionality directly, but you can achieve the result you want with `flat_map`. For each element of the input dataset, you create a new dataset (`tf.data.Dataset.from_tensors`) that produces multiple copies of that single sample (`.repeat`); `flat_map` then concatenates these per-element datasets into one.
For example:
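```python
import numpy as np
import tensorflow as tf

# Note: this uses the TensorFlow 1.x graph-mode API (tf.Session, make_one_shot_iterator).
def run(dataset):
    # Drain the dataset and collect every element it produces.
    el = dataset.make_one_shot_iterator().get_next()
    vals = []
    with tf.Session() as sess:
        try:
            while True:
                vals.append(sess.run(el))
        except tf.errors.OutOfRangeError:
            pass
    return vals

# Each element is a (value, repeat_count) pair.
dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
print('Original dataset with repeats')
print(run(dataset))

# Replace each element by a dataset holding `r` copies of `v`, then flatten.
dataset = dataset.flat_map(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))
print('Repeats flattened')
print(run(dataset))
```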
This prints:

```
Original dataset with repeats
[(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]
Repeats flattened
[1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5]
```
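Since the repeat count is just a tensor computed from the element, the same idea gives content-dependent oversampling. Below is a minimal sketch assuming TensorFlow 2.x eager execution; the toy features, the labels, the choice of label 1 as the minority class, and the hypothetical MINORITY_REPEATS factor are all made up for illustration:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical toy data: label 1 is treated as the minority class.
features = np.array([10, 20, 30, 40, 50], dtype=np.int64)
labels = np.array([0, 1, 0, 0, 1], dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

MINORITY_REPEATS = 3  # hypothetical oversampling factor

def oversample(x, y):
    # The repeat count depends on the element's label; repeat() expects tf.int64.
    count = tf.where(tf.equal(y, 1),
                     tf.constant(MINORITY_REPEATS, tf.int64),
                     tf.constant(1, tf.int64))
    return tf.data.Dataset.from_tensors((x, y)).repeat(count)

oversampled = dataset.flat_map(oversample)
pairs = [(int(x), int(y)) for x, y in oversampled.as_numpy_iterator()]
print(pairs)
```

Each minority-class element now appears three times in a row; a shuffle afterwards would spread the copies out.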
Alternatively, you can use `.interleave` to achieve the same result while mixing copies of multiple samples (`.flat_map` is a special case of `.interleave`, equivalent to `cycle_length=1`). For instance:
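```python
dataset = tf.data.Dataset.from_tensor_slices((np.array([1,2,3,4,5]), np.array([5,4,3,2,1])))
# cycle_length=4: up to 4 input elements are expanded concurrently;
# block_length=1: take one copy from each before moving to the next.
dataset = dataset.interleave(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r),
                             cycle_length=4, block_length=1)
print('Repeats flattened with a little bit of deterministic mixing')
print(run(dataset))
```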
This prints:

```
Repeats flattened with a little bit of deterministic mixing
[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 1, 2, 5, 1]
```
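If random rather than deterministic mixing is wanted, one option (a different technique from interleave, shown here only as a sketch assuming TensorFlow 2.x eager execution) is to follow the flat_map with a shuffle. The buffer size of 15 is chosen here because it covers the whole expanded dataset (5+4+3+2+1 elements); the seed is arbitrary:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

values = np.array([1, 2, 3, 4, 5], dtype=np.int64)
repeats = np.array([5, 4, 3, 2, 1], dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((values, repeats))

# Expand each value into `r` copies, then shuffle all copies together.
mixed = (dataset
         .flat_map(lambda v, r: tf.data.Dataset.from_tensors(v).repeat(r))
         .shuffle(buffer_size=15, seed=42))

result = [int(v) for v in mixed.as_numpy_iterator()]
print(result)
```

The multiset of values is the same as in the flattened example; only the order differs. A smaller buffer_size is cheaper in memory but only mixes elements that are close together.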
Upvotes: 2