knub


Using flat_map in Tensorflow's Dataset API

I'm using the dataset API, reading data as follows:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda record: tf.io.parse_single_example(record, feature_schema))
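
A minimal round-trip sketch of this reading pipeline, assuming TF 2.x and a hypothetical `feature_schema` in which `"a"` is a scalar int64 feature. The helper `make_example` and the temporary file path are illustrative and not part of the original question:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical schema: "a" is a scalar int64 feature (shape []).
feature_schema = {"a": tf.io.FixedLenFeature([], tf.int64)}

def make_example(a_value):
    """Serialize a tf.train.Example holding a single int64 feature "a"."""
    feature = {"a": tf.train.Feature(int64_list=tf.train.Int64List(value=[a_value]))}
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Write a small GZIP-compressed TFRecord file into a temporary directory.
filename = os.path.join(tempfile.mkdtemp(), "data.tfrecord.gz")
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename, options=options) as writer:
    for a_value in [0, 1, 2]:
        writer.write(make_example(a_value))

# Read it back with the same two-step pipeline (TF 2.x names).
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda record: tf.io.parse_single_example(record, feature_schema))

# Each element is a dict of scalar tensors, e.g. {"a": 0}.
values = [int(ex["a"]) for ex in dataset.as_numpy_iterator()]
```

Declaring `"a"` with shape `[]` keeps it a scalar after parsing, which matters once it is used in a condition.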

I now want to use flat_map to filter out some samples and duplicate others dynamically at training time (this is the input function that feeds my model).

flat_map requires the mapped function to return a Dataset object, but I don't know how to create one. Here's pseudo-code for what I want to achieve:

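def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []                        # drop the example
    # else:
    #     return [tf_example, tf_example]  # emit it twice
    # (flat_map actually needs a tf.data.Dataset back, not a Python list.)
    raise NotImplementedError

dataset = dataset.flat_map(flat_map_impl)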

How can I implement this in the flat_map function?

NOTE: I guess it's possible to implement this via a py_func, but I'd prefer to avoid this.

Answers (1)

mrry

Perhaps the most common way to create the tf.data.Dataset returned from a Dataset.flat_map() function is Dataset.from_tensors() or Dataset.from_tensor_slices(). Because tf_example is a dictionary, the easiest option here is probably to combine Dataset.from_tensors() with Dataset.repeat(count), where a conditional expression computes count (a count of 0 produces an empty dataset, which drops the example):

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda record: tf.io.parse_single_example(record, feature_schema))

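def flat_map_impl(tf_example):
  # Assumes "a" was parsed as a scalar, e.g. tf.io.FixedLenFeature([], tf.int64);
  # tf.cond requires a scalar boolean predicate.
  count = tf.cond(tf.equal(tf_example["a"], 1),
                  lambda: tf.constant(0, dtype=tf.int64),  # drop the example
                  lambda: tf.constant(2, dtype=tf.int64))  # emit it twice

  # repeat(0) yields an empty dataset, so dropped examples vanish from the output.
  return tf.data.Dataset.from_tensors(tf_example).repeat(count)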

dataset = dataset.flat_map(flat_map_impl)
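
A self-contained check of this approach, assuming TF 2.x eager execution. The in-memory dictionary dataset stands in for the parsed TFRecords, and the extra feature `"b"` is hypothetical, added only to show that whole examples are kept together:

```python
import tensorflow as tf

# In-memory stand-in for the parsed TFRecords: three examples with
# a = 0, 1, 2 and a hypothetical float feature "b".
dataset = tf.data.Dataset.from_tensor_slices(
    {"a": tf.constant([0, 1, 2], dtype=tf.int64),
     "b": tf.constant([10.0, 20.0, 30.0])})

def flat_map_impl(tf_example):
    # Drop the example when a == 1, otherwise emit it twice.
    count = tf.cond(tf.equal(tf_example["a"], 1),
                    lambda: tf.constant(0, dtype=tf.int64),
                    lambda: tf.constant(2, dtype=tf.int64))
    # from_tensors wraps the whole dict as a single element, then repeats it.
    return tf.data.Dataset.from_tensors(tf_example).repeat(count)

result = dataset.flat_map(flat_map_impl)

a_values = [int(ex["a"]) for ex in result.as_numpy_iterator()]
b_values = [float(ex["b"]) for ex in result.as_numpy_iterator()]
```

The example with `a == 1` disappears and the other two each appear twice, with their `b` values travelling alongside because `from_tensors` treats the dictionary as one element.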
