Reputation: 25
I have a trained model that I would like to employ in the tf.data pipeline for a second model. When I try to do this, I get a ValueError: Unknown graph. Aborting.
I don't know quite what to make of this error message.
My code looks something like this:
def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    images = files.map(load_image_from_file)

    def pass_image_through_model(img):
        return model.predict(img, steps=1)

    dataset = images.map(pass_image_through_model)
    return dataset
What is wrong with this? The error I get is:
    /home/.../code/dataloader.py:236 pass_image_through_model  *
        return model.predict(img, steps=1)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:1013 predict
        use_multiprocessing=use_multiprocessing)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:728 predict
        callbacks=callbacks)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:189 model_iteration
        f = _make_execution_function(model, mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:571 _make_execution_function
        return model._make_execution_function(mode)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2131 _make_execution_function
        self._make_predict_function()
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2121 _make_predict_function
        **kwargs)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3760 function
        return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)
    /home/.../anaconda3/envs/masters/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:3644 __init__
        raise ValueError('Unknown graph. Aborting.')

    ValueError: Unknown graph. Aborting.
Upvotes: 2
Views: 1855
Reputation: 349
One of the simplest ways to tackle this is to call the model directly on the input rather than going through the model.predict method. model.predict returns a numpy.ndarray and builds its own execution function (the _make_predict_function call in your traceback), which fails inside the dataset's graph. tf.data runs mapped functions in graph mode, so every operation inside them should take a tensor as input AND return a tensor within that graph.
Below is a quick working example of this.
import tensorflow as tf

# Create example model
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)

def map_fn(row):
    # Calling the model directly keeps the computation inside the graph
    return model(row)

# Create some input data: two float samples with one feature each
a = tf.constant([[1.0], [2.0]])

# Create the dataset; batch(1) adds the batch dimension the model expects
ds = tf.data.Dataset.from_tensor_slices(a).batch(1)
model_mapped_ds = ds.map(map_fn)

for el in model_mapped_ds:
    print(el)
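The difference in return types described above can be checked directly. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution enabled; the tiny Dense model and the input values are illustrative stand-ins, not the trained model from the question.

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-in model (not the asker's trained model)
inputs = tf.keras.Input((1,))
out = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, out)

x = tf.constant([[1.0], [2.0]])

# model.predict runs its own prediction loop and hands back a NumPy array
via_predict = model.predict(x, verbose=0)

# Calling the model is an ordinary tensor-in, tensor-out operation
via_call = model(x)

print(type(via_predict))
print(tf.is_tensor(via_call))
```

Both calls compute the same values; only the second one produces something that can flow through a graph-mode map function.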
Finally, below is what it would look like in your usage.
@tf.function
def load_data(..., model):
    # code to load an image
    files = tf.data.Dataset.from_tensor_slices(file_list)
    # Batch after loading, so that the model receives a leading batch dimension
    images = files.map(load_image_from_file).batch(1)  # Don't forget batch size!

    def pass_image_through_model(img):
        return model(img)  # this returns a tensor

    dataset = images.map(pass_image_through_model)
    return dataset
Upvotes: 3
Reputation: 15003
The error you are getting can seem cryptic if this is your first time dealing with a tf.data.Dataset() object.
All the operations in tf.data.Dataset() are actually executed in graph mode, so inside a mapped function you cannot rely on functions other than those predefined in tf.*.
The usual way to mix arbitrary Python code with your tf.data.Dataset() is to wrap it in tf.py_function(); otherwise an error will be thrown.
Please bear in mind that mixing Python code with optimised tf.data.Dataset() code will slow the pipeline down.
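As a minimal sketch of the tf.py_function() route, assuming TensorFlow 2.1 or later (for as_numpy_iterator()): the helper names sort_descending and wrapped and the sample values are hypothetical, and only stand in for whatever Python code you need to run.

```python
import numpy as np
import tensorflow as tf

def sort_descending(row):
    # Inside tf.py_function the argument is an eager tensor, so .numpy()
    # is available and arbitrary NumPy / Python code can run here.
    arr = row.numpy()
    return -np.sort(-arr)

def wrapped(row):
    # Tout declares the dtype of the tensor returned by the Python function
    result = tf.py_function(func=sort_descending, inp=[row], Tout=tf.float32)
    # py_function drops static shape information, so restore it
    result.set_shape(row.shape)
    return result

data = tf.constant([[1.0, 3.0, 2.0], [6.0, 4.0, 5.0]])
ds = tf.data.Dataset.from_tensor_slices(data).map(wrapped)

rows = list(ds.as_numpy_iterator())
print(rows)
```

The wrapped function executes in the Python interpreter rather than in the compiled graph, which is where the slowdown mentioned above comes from.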
The alternative is to keep the model out of the mapping process entirely: retrieve your dataset, use as_numpy_iterator() to fetch your data, and then predict with your model.
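Predicting outside the mapping step could look roughly like this. It is a sketch under assumptions: TensorFlow 2.1 or later, a small Dense model standing in for the trained model, and made-up input data in place of the loaded images.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the trained model
inputs = tf.keras.Input((3,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

# Made-up data in place of the loaded images: 4 samples, 3 features each
data = np.arange(12, dtype=np.float32).reshape(4, 3)
ds = tf.data.Dataset.from_tensor_slices(data).batch(2)

# Pull batches out of the pipeline as NumPy arrays and predict eagerly
predictions = [model.predict(batch, verbose=0) for batch in ds.as_numpy_iterator()]
features = np.concatenate(predictions, axis=0)

print(features.shape)
```

The resulting array can then be fed to the second model, for example through another tf.data.Dataset.from_tensor_slices call, at the cost of holding all predictions in memory.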
Upvotes: 0