zimmerrol

Reputation: 4951

Use tensorflow operation in tf.keras Model

I am trying to use a pre-trained model, add some new layers and operations, and run a training session in TensorFlow. While doing so, I stumbled upon the tf.keras.applications.* namespace and started to use some of the models implemented there.

After loading the base model, I add the new layers like this:

x = base_model.output
# this line seems to cause my error
x = tf.reshape(x, [-1, 1])
# using this line instead solves the issue
# x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(1024, activation="relu")(x)
x = tf.keras.layers.Dense(5, activation="softmax")(x)

When I then create a new tf.keras.models.Model(...) from the tensor x, I get this error message:

Output tensors to a Model must be the output of a TensorFlow `Layer`
(thus holding past layer metadata).
Found: Tensor("dense_27/Softmax:0", shape=(?, 3), dtype=float32)

I guess this exception is caused by using a tf.* operation inside the tf.keras model. In this situation I could easily use the Keras alternative instead, but now I am wondering whether there is a workaround that lets me use TensorFlow operations inside a Keras model at all. Or am I restricted to tf.keras.layers.* operations?

Upvotes: 5

Views: 2512

Answers (1)

Yu-Yang

Reputation: 14619

As mentioned in the comments, you need to wrap TF operations in a Lambda layer (or any custom layer) so that Keras can find the layer metadata it needs to build the Model object.

x = base_model.output
x = tf.keras.layers.Lambda(lambda x: tf.reshape(x, [-1, 1]))(x)
x = tf.keras.layers.Dense(1024, activation="relu")(x)
x = tf.keras.layers.Dense(5, activation="softmax")(x)
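The "any custom layer" route mentioned above could look roughly like the following. This is a minimal sketch assuming TensorFlow 2.x; the class name ReshapeToColumn, the Input of shape (4,) standing in for base_model.output, and the layer sizes are hypothetical choices for illustration, not part of the original answer.

```python
import numpy as np
import tensorflow as tf


class ReshapeToColumn(tf.keras.layers.Layer):
    """Hypothetical custom layer that wraps the raw tf.reshape call."""

    def call(self, inputs):
        # Same operation as in the question: every value becomes its own row,
        # so the batch dimension is folded in as well.
        return tf.reshape(inputs, [-1, 1])


# A plain Input stands in for base_model.output (assumption for this sketch).
inputs = tf.keras.Input(shape=(4,))
x = ReshapeToColumn()(inputs)
x = tf.keras.layers.Dense(16, activation="relu")(x)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)

# Building the Model works because every tensor on the path comes from a Layer.
model = tf.keras.Model(inputs, outputs)

out = model(np.zeros((2, 4), dtype="float32"))
print(out.shape)  # (8, 5): 2 samples x 4 features were reshaped into 8 rows
```

Compared with a Lambda, a subclassed layer keeps the reference to tf inside the module where the class is defined, which matters once the model is saved and reloaded.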

It's probably worth noting that saving and then loading this model raises an error complaining that the name tf is not defined.

model = tf.keras.Model(base_model.input, x)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
model.save('1.h5')
m = tf.keras.models.load_model('1.h5')

# NameError: name 'tf' is not defined

This happens because, when the model is loaded, tf is not imported in the scope where the Lambda layer is reconstructed. It can be solved by passing a custom_objects dictionary to load_model.

m = tf.keras.models.load_model('1.h5', custom_objects={'tf': tf})
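If you use a custom layer subclass instead of a Lambda, load_model still has to be told about the class through custom_objects, but tf itself no longer needs to be passed, since call looks up tf in the module that defines the class. A minimal sketch of that round trip, assuming TensorFlow 2.x; the class name ReshapeToColumn, the temporary file name, and the layer sizes are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ReshapeToColumn(tf.keras.layers.Layer):
    """Hypothetical custom layer wrapping tf.reshape(x, [-1, 1])."""

    def call(self, inputs):
        return tf.reshape(inputs, [-1, 1])


inputs = tf.keras.Input(shape=(4,))
x = ReshapeToColumn()(inputs)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Save to a temporary location (file name chosen for this sketch only).
path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

# The custom class is looked up by name in custom_objects while rebuilding.
loaded = tf.keras.models.load_model(
    path, custom_objects={"ReshapeToColumn": ReshapeToColumn}
)

sample = np.random.rand(2, 4).astype("float32")
original_out = model(sample).numpy()
loaded_out = loaded(sample).numpy()
print(np.allclose(original_out, loaded_out))  # True
```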

Upvotes: 7
