momo

Reputation: 1122

Slicing in tf.data causes "iterating over `tf.Tensor` is not allowed in Graph execution" error

I have a dataset created as follows, where image_train_path is a list of image file paths, e.g. [b'/content/drive/My Drive/data/folder1/im1.png', b'/content/drive/My Drive/data/folder2/im6.png',...]. I need to extract the folder paths, e.g. '/content/drive/My Drive/data/folder1', and then perform some other operations. I try to do this using the preprocessData function as follows.

dataset = tf.data.Dataset.from_tensor_slices(image_train_path)
dataset = dataset.map(preprocessData, num_parallel_calls=16)

where preprocessData is:

def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    ....

However, the slicing line causes the following error:

OperatorNotAllowedInGraphError: in user code:

    <ipython-input-21-2a9827982c16>:4 preprocessData  *
        foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:210 wrapper  **
        result = dispatch(wrapper, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:122 dispatch
        result = dispatcher.handle(args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ragged/ragged_dispatch.py:130 handle
        for elt in x:
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:524 __iter__
        self._disallow_iteration()
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:520 _disallow_iteration
        self._disallow_in_graph_mode("iterating over `tf.Tensor`")
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:500 _disallow_in_graph_mode
        " this function with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

I tried this in both TF 2.4 and tf-nightly. I also tried decorating the function with @tf.function and calling tf.data.experimental.enable_debug_mode(). Both always give the same error.

I don't quite understand which part is causing the 'iteration', though I guess the issue is the slicing. Is there an alternative way to accomplish this?

Upvotes: 1

Views: 595

Answers (1)

Lescurel

Reputation: 11631

The function tf.strings.join expects a list of Tensors, as the documentation states:

Args

inputs: A list of tf.Tensor objects of same size and tf.string dtype.

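tf.slice returns a single Tensor, not a list. tf.strings.join then tries to iterate over its inputs argument (the `for elt in x` loop in ragged_dispatch.py in the traceback), and iterating over a tf.Tensor is not allowed while Dataset.map traces preprocessData into a graph. That is what raises the error.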
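A small eager-mode sketch of the difference, assuming TensorFlow 2.x and the sample path from the question (the variable names are illustrative only): tf.slice produces one Tensor, tf.strings.join combines a list of Tensors element-wise, and tf.strings.reduce_join is the op that concatenates the elements inside a single Tensor.

```python
import tensorflow as tf

path = tf.constant("/content/drive/My Drive/data/folder1/im1.png")

# A scalar string splits into a 1-D Tensor; the leading "/" gives an empty first component.
parts = tf.strings.split(path, "/")

# tf.slice returns ONE Tensor holding the first 6 components, not a Python list.
first_six = tf.slice(parts, [0], [6])

# tf.strings.join combines several tensors element-wise, so it takes a list of them.
pair = tf.strings.join([tf.constant("a"), tf.constant("b")], separator="/")

# To concatenate the elements *inside* one tensor, reduce over it instead.
folder = tf.strings.reduce_join(first_six, separator="/")
print(folder.numpy())
```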

You can pass it a Python list of Tensors built with a simple list comprehension instead:

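import tensorflow as tf
def preprocessData(images_path):
    # images_path is a scalar string, so split returns a 1-D string Tensor
    folder = tf.strings.split(images_path, '/')
    # Build a Python list of scalar Tensors (the first 6 path components)
    foldername = tf.strings.join([folder[i] for i in range(6)], '/')
    return foldername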
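If hard-coding six components is not desirable, one possible alternative (not part of the original answer) is to drop only the last path component, the file name, and join the rest with tf.strings.reduce_join. This is a sketch assuming TensorFlow 2.x; folder_of and the sample paths are hypothetical names used for illustration.

```python
import tensorflow as tf

def folder_of(image_path):
    # For a scalar string, tf.strings.split returns a 1-D string Tensor.
    parts = tf.strings.split(image_path, "/")
    # Keep every component except the file name and join them back with "/".
    # reduce_join takes a single Tensor, so no Python iteration over a Tensor is needed.
    return tf.strings.reduce_join(parts[:-1], separator="/")

paths = [
    "/content/drive/My Drive/data/folder1/im1.png",
    "/content/drive/My Drive/data/folder2/im6.png",
]
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(folder_of)

folders = [f.numpy() for f in dataset]
print(folders)
```

Because parts[:-1] works for any path depth, this version does not assume that every image sits exactly six components deep.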

Upvotes: 1
