Monte

Reputation: 25

Is there a way to use tf.keras.Model.predict within a tf.data pipeline?

I have a trained model that I would like to employ in the tf.data pipeline for a second model. When I try to do this, I get a ValueError: Unknown graph. Aborting. I don't know quite what to make of this error message.

My code looks something like this:

def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    images = files.map(load_image_from_file) 

    def pass_image_through_model(img):
        return model.predict(img, steps=1)

    dataset = images.map(pass_image_through_model)
    return dataset

What is wrong with this? The error I get is:

    /home/.../code/dataloader.py:236 pass_image_through_model  *
        return model.predict(img, steps=1)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
        use_multiprocessing=use_multiprocessing)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
        callbacks=callbacks)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
        f = _make_execution_function(model, mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
        return model._make_execution_function(mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
        self._make_predict_function()
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
        **kwargs)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
        return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
        raise ValueError('Unknown graph. Aborting.')

    ValueError: Unknown graph. Aborting.

Upvotes: 2

Views: 1855

Answers (2)

Wolfgang

Reputation: 349

One of the simplest ways to tackle this is to call the model directly on the input rather than using the model.predict method. model.predict runs its own Python-level prediction loop and returns a numpy.ndarray, whereas the function passed to Dataset.map is traced and executed as a graph, so every operation inside it should take tensors as input AND return tensors as output.

Below is a quick working example of this.

import tensorflow as tf

# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)

def map_fn(row):
    return model(row)


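# Create some input data: two samples with one float feature each
a = tf.constant([[1.0], [2.0]])

# Create the dataset (each batched element has shape (1, 1), matching the model input)
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(map_fn)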

for el in model_mapped_ds:
    print(el)
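The difference in return types can be checked eagerly. The sketch below uses a toy model of the same kind (an illustrative stand-in, not the asker's model) and compares a direct call with model.predict; the variable names are assumptions made for the example.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in model (hypothetical, for illustration only)
inputs = tf.keras.Input((1,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# Two samples with one float feature each: shape (2, 1)
x = np.array([[1.0], [2.0]], dtype=np.float32)

called = model(tf.constant(x))            # direct call: result stays a tensor
predicted = model.predict(x, verbose=0)   # predict: result is a NumPy array

print(type(called).__name__, tuple(called.shape))
print(type(predicted).__name__, predicted.shape)
```

Only the direct call keeps the result inside TensorFlow, which is what a function passed to Dataset.map needs.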

Finally, this is roughly what it would look like in your case.


def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    images = files.map(load_image_from_file)
    images = images.batch(1)  # Don't forget the batch dimension!

    def pass_image_through_model(img):
        return model(img)  # calling the model directly returns a tensor

    # No @tf.function needed here: Dataset.map traces the function itself
    dataset = images.map(pass_image_through_model)
    return dataset
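For a version that runs end to end, here is a minimal sketch that replaces the file loading with random in-memory images. The convolutional model, the 8x8 image size, and the batch_size parameter are all assumptions standing in for the real trained model and loader.

```python
import tensorflow as tf

# Hypothetical stand-in for the trained model: one conv layer over 8x8 RGB images
inputs = tf.keras.Input((8, 8, 3))
outputs = tf.keras.layers.Conv2D(4, 3, padding="same")(inputs)
model = tf.keras.Model(inputs, outputs)


def load_data(images, model, batch_size=2):
    def pass_image_through_model(batch):
        # Direct call keeps everything as tensors; training=False selects inference mode
        return model(batch, training=False)

    # Batch first so the model receives a leading batch dimension
    dataset = tf.data.Dataset.from_tensor_slices(images).batch(batch_size)
    return dataset.map(pass_image_through_model)


# Four random "images" instead of files read from disk
images = tf.random.uniform((4, 8, 8, 3))
dataset = load_data(images, model)

shapes = [tuple(el.shape) for el in dataset]
print(shapes)
```

Each element of the resulting dataset is a batch of model outputs, ready to be consumed by the second model's pipeline.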

Upvotes: 3

Timbus Calin

Reputation: 15003

The error you are getting may seem cryptic if this is your first time dealing with a tf.data.Dataset object.

The function you pass to Dataset.map() is traced and executed in graph mode, so inside it you can only use TensorFlow operations (the functions predefined in tf.*), not arbitrary Python code such as model.predict.

To mix arbitrary Python code with your tf.data.Dataset pipeline, you have to wrap it in tf.py_function() (or tf.numpy_function()); otherwise an error will be thrown.

Please bear in mind that mixing Python code with optimised tf.data.Dataset code will slow the pipeline down.
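A minimal sketch of wrapping Python code with tf.py_function inside a map call is shown here. The function python_only_step is a hypothetical placeholder for whatever Python or NumPy code needs to run; it is not part of TensorFlow.

```python
import numpy as np
import tensorflow as tf


def python_only_step(x):
    # Runs eagerly as ordinary Python: x is an eager tensor, so .numpy() is available
    doubled = np.multiply(x.numpy(), 2.0)
    return doubled.astype(np.float32)


def wrapped(x):
    # tf.py_function embeds the Python call as a single op in the traced map function
    y = tf.py_function(func=python_only_step, inp=[x], Tout=tf.float32)
    # Static shape information is lost across py_function, so restore it by hand
    y.set_shape(x.shape)
    return y


ds = tf.data.Dataset.from_tensor_slices(tf.constant([[1.0], [2.0], [3.0]]))
ds = ds.map(wrapped)

results = [el.numpy().tolist() for el in ds]
print(results)
```

The wrapped step executes in the Python interpreter for every element, which is where the performance cost comes from.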

Another option is to keep the prediction outside of the mapping process: retrieve your dataset, use as_numpy_iterator() to fetch the data, and call model.predict on what it yields.
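As a rough sketch of predicting outside the pipeline, the example below iterates a small dataset with as_numpy_iterator(), calls model.predict on each batch, and builds a new dataset from the results. The toy model and input values are assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the trained model (hypothetical)
inputs = tf.keras.Input((1,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# Four samples with one float feature each, served in batches of two
data = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices(data).batch(2)

# Predict in plain Python, outside of any Dataset.map call
predictions = [model.predict(batch, verbose=0) for batch in ds.as_numpy_iterator()]
stacked = np.concatenate(predictions, axis=0)

# The precomputed outputs can then seed the second model's pipeline
second_ds = tf.data.Dataset.from_tensor_slices(stacked)
print(stacked.shape)
```

This holds all predictions in memory at once, so it suits datasets whose outputs fit in RAM.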

Upvotes: 0
