Reputation: 41
Why does the following TF2 tf.keras model work when fitted with tensors, but raise a ValueError when I try to fit the same tensors wrapped in a dataset created with tf.data.Dataset.from_tensor_slices?
EDIT: Put another way: having developed, fitted and tested the model below using NumPy arrays, how do those same arrays need to be reshaped (if at all) so that they can be used to create a dataset with tf.data.Dataset.from_tensor_slices that works with the model?
# Minimal reproduction: a tf.keras model built on a TF-Hub text-embedding
# layer trains on in-memory tensors, but model.fit() raises a ValueError when
# the same data is fed through an *unbatched* tf.data.Dataset.
# NOTE(review): `tf` (tensorflow) and `hub` (tensorflow_hub) are used below but
# not imported in this snippet; presumably they are imported elsewhere.
embed = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
# Pre-trained embedding layer: one string in, a 20-dim vector out
# (output_shape=[20]); trainable=True lets fit() fine-tune its weights.
hub_layer = hub.KerasLayer(embed, output_shape=[20], input_shape=[],
dtype=tf.string, trainable=True, name='hub_layer')
# From the TF Hub docs: hub_layer takes a 1D tensor of strings, i.e. one
# scalar string per example plus a leading batch dimension.
input_tensor = tf.keras.Input(shape=(), name="input_enquiry", dtype=tf.string) # Note tf.string. Ref: https://github.com/tensorflow/hub/issues/483
hub_tensor = hub_layer(input_tensor)
x = tf.keras.layers.Dense(16, activation='relu')(hub_tensor)
# 4-way softmax head; paired with CategoricalCrossentropy, so targets must be
# one-hot vectors of length 4.
main_output = tf.keras.layers.Dense(units=4, activation='softmax', name='main_output')(x)
model = tf.keras.models.Model(inputs=[input_tensor], outputs=[main_output])
model.compile(optimizer='adam', loss=tf.losses.CategoricalCrossentropy(),metrics='acc')
# Input and target.
# X has shape (2, 1): two rows, each holding one string.
# y has shape (2, 4): one one-hot target row per example.
X = tf.constant([['The quick brown fox'], ['Hello World']])
y = tf.constant([[0,0,0,1], [0,0,1,0]])
# Works OK: given in-memory tensors, fit() splits them into batches itself,
# so every training step sees a leading batch dimension.
model.fit(X, y) # fit on tensors
# from_tensor_slices slices X along its first dimension, so each element of
# X_ds is a single row of shape (1,). There is no batch dimension.
X_ds = tf.data.Dataset.from_tensor_slices(X)
# Works OK (as reported).
# NOTE(review): predict() computes no loss, so the target/output shape check
# that fails in fit(ds) below is never reached here.
model.predict(X_ds) # predict on dataset
# Each element of y_ds is a single one-hot row of shape (4,).
y_ds = tf.data.Dataset.from_tensor_slices(y)
# Pairs the elements up as (x, y) tuples, still unbatched.
ds = tf.data.Dataset.zip((X_ds, y_ds))
# Fails with ValueError: "Shapes (4, 1) and (1, 4) are incompatible".
# fit() does not batch a Dataset itself, so each step receives one unbatched
# (x, y) pair. The length-4 target row is then presumably read as the batch
# axis, which gives the (4, 1) vs (1, 4) mismatch in the loss.
# Fix (from the accepted answer): batch the zipped dataset so the outer
# dimension is restored, e.g. tf.data.Dataset.zip((X_ds, y_ds)).batch(32).
model.fit(ds)
ValueError:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
30
31 # Fails with ValueError
---> 32 model.fit(ds)
33
34
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
846 batch_size=batch_size):
847 callbacks.on_train_batch_begin(step)
--> 848 tmp_logs = train_function(iterator)
849 # Catch OutOfRangeError for Datasets of unknown size.
850 # This blocks until the batch has finished executing.
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args, **kwds)
581
582 if tracing_count == self._get_tracing_count():
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
609 # In this case we have created variables on the first call, so we run the
610 # defunned version which is guaranteed to never create variables.
--> 611 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
612 elif self._stateful_fn is not None:
613 # Release the lock early so that multiple threads can perform the call
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
2417 """Calls a graph function specialized to the inputs."""
2418 with self._lock:
-> 2419 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2420 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2421
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
2772 and self.input_signature is None
2773 and call_context_key in self._function_cache.missed):
-> 2774 return self._define_function_with_shape_relaxation(args, kwargs)
2775
2776 self._function_cache.missed.add(call_context_key)
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _define_function_with_shape_relaxation(self, args, kwargs)
2704 relaxed_arg_shapes)
2705 graph_function = self._create_graph_function(
-> 2706 args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
2707 self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
2708
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2665 arg_names=arg_names,
2666 override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667 capture_by_value=self._capture_by_value),
2668 self._function_attributes,
2669 # Tell the ConcreteFunction to clean up its graph once it goes out of
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
979 _, original_func = tf_decorator.unwrap(python_func)
980
--> 981 func_outputs = python_func(*func_args, **func_kwargs)
982
983 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
439 # __wrapped__ allows AutoGraph to swap in a converted function. We give
440 # the function a weak reference to itself to avoid a reference cycle.
--> 441 return weak_wrapped_fn().__wrapped__(*args, **kwds)
442 weak_wrapped_fn = weakref.ref(wrapped_fn)
443
~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
ValueError: in user code:
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:571 train_function *
outputs = self.distribute_strategy.run(
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:533 train_step **
y, y_pred, sample_weight, regularization_losses=self.losses)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/compile_utils.py:205 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:143 __call__
losses = self.call(y_true, y_pred)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:246 call
return self.fn(y_true, y_pred, **self._fn_kwargs)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/losses.py:1527 categorical_crossentropy
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:4561 categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
/home/karl/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (4, 1) and (1, 4) are incompatible
If, instead of `.from_tensor_slices`, we use `.from_tensors` to create X_ds and y_ds, then after zipping everything works. However, the docs give me the impression that `.from_tensors` is memory-heavy and not desirable. Also, I believe the single-element `.from_tensors` dataset simply provides the model with two 2D tensors, whereas the `.from_tensor_slices` version is a sequence of 1D elements.
Upvotes: 3
Views: 2108
Reputation: 41
The solution to the specific problem in the question was to call .batch() on the dataset:
ds = tf.data.Dataset.zip((X_ds, y_ds)).batch(32) # eg, batch size 32
My understanding (from the docs) is that batching effectively restores the outer (batch) dimension of the data that tf.data.Dataset.from_tensor_slices removed. That is, each batch presented to the model has the same shape as the original NumPy arrays that worked.
Upvotes: 1
Reputation: 643
According to the tf.data.Dataset documentation, `from_tensors` combines the input into a dataset containing a single element:
dataset = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])
list(dataset.as_numpy_iterator())
[array([[1, 2], [3, 4]], dtype=int32)]
`from_tensor_slices` slices the input tensor along its first dimension and creates a dataset with a separate element for each row of the input tensor:
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
list(dataset.as_numpy_iterator())
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
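You are getting a ValueError because the elements produced by `from_tensors` have a different shape from those produced by `from_tensor_slices`. `from_tensors` yields each whole tensor, batch dimension included, as one element. `from_tensor_slices` yields individual rows without that leading dimension.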
Upvotes: 0