Milad

Reputation: 5490

Returning strings in tf.data.Dataset map method

In TensorFlow 1.4.1, the map method of tf.data.Dataset could return strings, so I could return something like this from my map function:

return filename, image, one_hot_label

where filename is a string. This no longer works in TF 1.5+:

    dataset = dataset.map(self._mapper)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 838, in map
    return MapDataset(self, map_func)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1826, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 488, in add_to_graph
    self._create_definition_if_needed()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1814, in tf_map_func
    ret, [t.get_shape() for t in nest.flatten(ret)])
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1814, in <listcomp>
    ret, [t.get_shape() for t in nest.flatten(ret)])
AttributeError: 'str' object has no attribute 'get_shape'

Is this by design or a regression?

A reproducible example:

import tensorflow as tf

def map_fn(x):
    return x*2, 'foo'

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(map_fn)

Upvotes: 3

Views: 1682

Answers (1)

mikkola

Reputation: 3476

As discussed in the comments, this seems to be a bug in TF 1.5 up to at least 1.6, and likely also 1.7. I have opened a GitHub issue about it at https://github.com/tensorflow/tensorflow/issues/18355

Until the issue is addressed in a future TensorFlow version, I suggest explicitly converting the string output to a tensor:

import tensorflow as tf

def map_fn(x):
    # Explicitly convert 'foo' to tensor
    return x*2, tf.convert_to_tensor('foo')

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(map_fn)
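If several map functions return plain Python strings, the same conversion could be applied in one place with a small wrapper. This is only a sketch: as_tensors is a hypothetical helper name, not part of TensorFlow, and it assumes the map function returns a flat tuple (no nested structures or dicts).

```python
import tensorflow as tf

def as_tensors(map_func):
    # Hypothetical helper (not a TensorFlow API): wraps map_func so that
    # every component of its flat tuple result is passed through
    # tf.convert_to_tensor. Existing tensors are returned unchanged;
    # Python strings become tf.string tensors.
    def wrapped(*args):
        result = map_func(*args)
        return tuple(tf.convert_to_tensor(r) for r in result)
    return wrapped

def map_fn(x):
    # Same function as in the question: the second output is a Python str
    return x*2, 'foo'

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(as_tensors(map_fn))
```

The wrapper leaves map_fn itself untouched, so it can be dropped once the underlying issue is fixed.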

Upvotes: 1
