lzsl


Flatten a dataset in TensorFlow

I am trying to convert a dataset in TensorFlow so that each element is a single-valued tensor. The dataset currently looks like this:

[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...

After the transformation it should look like this:

[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...
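To pin down the intended transformation independently of TensorFlow, here is a minimal plain-Python model of it. It uses only the three rows shown above (the trailing `...` is left out), and `flatten_to_singletons` is a hypothetical helper name, not part of any library.

```python
from itertools import chain
from typing import Iterable, List


def flatten_to_singletons(rows: Iterable[List[int]]) -> List[List[int]]:
    """Turn variable-length rows into one single-element list per value."""
    # chain.from_iterable concatenates the rows in order, which is the
    # behaviour the question wants from the dataset transformation.
    return [[value] for value in chain.from_iterable(rows)]


rows = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]
print(flatten_to_singletons(rows)[:5])  # [[12], [43], [64], [34], [45]]
```

The rows have different lengths (8, 8 and 5), which is exactly why a fixed-shape operation such as `unstack` cannot be applied to them in TensorFlow.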

My initial idea was to use flat_map on the dataset and convert each tensor into a list of tensors with reshape and unstack:

output_labels = self.dataset.flat_map(convert_labels)

...

def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

However, the shape of each tensor is only partially known (i.e. (?, 1)), which is why the unstack operation fails. Is there any way to "concatenate" the different tensors without explicitly iterating over them?

Answers (1)

mrry

Your solution is very close, but Dataset.flat_map() takes a function that returns a tf.data.Dataset object, rather than a list of tensors. Fortunately, the Dataset.from_tensor_slices() method works for exactly your use case, because it can split a tensor into a variable number of elements:

output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)
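A minimal runnable sketch of this approach, assuming TensorFlow 2.4 or later (eager execution, `output_signature` in `from_generator`, `as_numpy_iterator`). The in-memory `rows` list and the `from_generator` setup are assumptions standing in for `self.dataset`. Note that `from_tensor_slices` on a 1-D tensor yields scalars; if you want shape-`[1]` elements such as `[12]`, reshaping each vector to `[-1, 1]` inside the mapped function does that, and unlike `unstack` it works with the unknown leading dimension.

```python
import tensorflow as tf

# Assumed stand-in for self.dataset: variable-length int32 vectors.
rows = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32),
)

# Each vector becomes its own small Dataset of scalars; flat_map
# concatenates those per-element datasets into one stream.
scalars = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

# Reshaping to [-1, 1] first makes every element a shape-[1] tensor.
singletons = dataset.flat_map(
    lambda t: tf.data.Dataset.from_tensor_slices(tf.reshape(t, [-1, 1]))
)

print([int(x) for x in scalars.as_numpy_iterator()][:3])
print([x.tolist() for x in singletons.as_numpy_iterator()][:3])  # [[12], [43], [64]]
```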

Note that the tf.contrib.data.unbatch() transformation implements the same functionality and has a slightly more efficient implementation in the current master branch of TensorFlow (it will be included in the 1.9 release):

output_labels = self.dataset.apply(tf.contrib.data.unbatch())
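`tf.contrib` no longer exists in TensorFlow 2.x. A minimal sketch of the equivalent there, assuming TensorFlow 2.4 or later, where `unbatch()` is a method on `tf.data.Dataset` (`tf.data.experimental.unbatch()` is the older spelling). The `rows` list and `from_generator` setup are again assumed stand-ins for `self.dataset`.

```python
import tensorflow as tf

# Assumed stand-in for self.dataset: variable-length int32 vectors.
rows = [
    [12, 43, 64, 34, 45, 2, 13, 54],
    [34, 65, 34, 67, 87, 12, 23, 43],
    [23, 53, 23, 1, 5],
]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32),
)

# unbatch() splits every element along its first dimension, so rows of
# different lengths are flattened into one stream of scalars.
flattened = dataset.unbatch()

print([int(x) for x in flattened.as_numpy_iterator()][:5])  # [12, 43, 64, 34, 45]
```

Only the rank of each element needs to be known statically for `unbatch()`; the length of the first dimension may vary from element to element.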
