Reputation: 4032
I'm using the dataset API, reading data as follows:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
I now want to use flat_map in order to filter out some samples while duplicating others dynamically at training time (this is the input function feeding my model). The flat_map API requires the mapped function to return a Dataset object, but I don't know how to create one. Here's a pseudo-code implementation of what I want to achieve:
def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

dataset = dataset.flat_map(flat_map_impl)
How can I implement this in the flat_map function?
NOTE: I guess it's possible to implement this via a py_func, but I'd prefer to avoid that.
Upvotes: 6
Views: 10193
Reputation: 126154
Perhaps the most common way to create a tf.data.Dataset to return from a Dataset.flat_map() is to use Dataset.from_tensors() or Dataset.from_tensor_slices(). In this case, because tf_example is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors() and Dataset.repeat(count), where a conditional expression computes count:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
    # Emit the example zero times when tf_example["a"] == 1, otherwise twice.
    count = tf.cond(tf.equal(tf_example["a"], 1),
                    lambda: tf.constant(0, dtype=tf.int64),
                    lambda: tf.constant(2, dtype=tf.int64))
    return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)
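As a quick sanity check of the repeat-with-count pattern, here is a minimal, self-contained sketch on a toy in-memory dataset; the feature name "a", the toy values, and the variable name toy are placeholders standing in for your parsed TFRecord examples, not part of your pipeline:

import tensorflow as tf

# Toy stand-in for the parsed examples: a dataset of {"a": ...} dictionaries.
toy = tf.data.Dataset.from_tensor_slices({"a": [1, 2, 1, 3]})

def flat_map_impl(example):
    # Drop the example when a == 1, otherwise emit it twice.
    count = tf.cond(tf.equal(example["a"], 1),
                    lambda: tf.constant(0, dtype=tf.int64),
                    lambda: tf.constant(2, dtype=tf.int64))
    return tf.data.Dataset.from_tensors(example).repeat(count)

toy = toy.flat_map(flat_map_impl)
# Expected elements: {"a": 2}, {"a": 2}, {"a": 3}, {"a": 3}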
Upvotes: 8