I'm trying to build an NN that can be broken apart into two pieces, where each piece can be run independently. Subclassing the Keras `Model` class allows this to be implemented nicely, as illustrated with this toy model:
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        # Run the two halves in sequence
        intermediate_value = self.model_part_1(inputs)
        final_output = self.model_part_2(intermediate_value)
        return final_output

    def model_part_1(self, inputs):
        # First half: can also be called on its own
        x = self.dense1(inputs)
        return x

    def model_part_2(self, inputs):
        # Second half: consumes the output of model_part_1
        x = self.dense2(inputs)
        return x
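A minimal sketch of running the two halves independently. It repeats the toy class so the block runs on its own, and assumes inputs with 5 features (an illustrative choice, not something the question specifies). It then checks that feeding the output of `model_part_1` into `model_part_2` gives the same result as calling the model directly.

```python
import numpy as np
import tensorflow as tf


class MyModel(tf.keras.Model):
    """Same toy model as above: two Dense layers exposed as two halves."""

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        return self.model_part_2(self.model_part_1(inputs))

    def model_part_1(self, inputs):
        return self.dense1(inputs)

    def model_part_2(self, inputs):
        return self.dense2(inputs)


model = MyModel()
x = tf.random.normal((3, 5))  # assumed batch: 3 samples, 5 features

full = model(x)                           # both halves via call()
intermediate = model.model_part_1(x)      # first half only
split = model.model_part_2(intermediate)  # second half, fed separately

print(np.allclose(full.numpy(), split.numpy()))
```

Because both paths use the same layer objects (and therefore the same weights), the intermediate tensor can be stored and fed to the second half later.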
All of this works nicely, except that the custom methods don't survive saving and loading. After saving with the standard `model.save("saved_model_path")` and loading with `tf.keras.models.load_model("saved_model_path")`, the loaded model object works as expected when running `predict`, but it no longer has the `model_part_1` or `model_part_2` methods (the attributes `dense1` and `dense2` are loaded properly).
Adding the keyword argument `custom_objects={"MyModel": MyModel}` when loading didn't solve the problem.
It should be possible to add the methods to the loaded instance, but that's pretty hacky.
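For reference, the "hacky" route could look roughly like this: plain functions bound onto the loaded instance with `types.MethodType`. This is a sketch that assumes, as the question states, that the loaded object still exposes `dense1` and `dense2`. `attach_part_methods` is a hypothetical helper, and the demo uses a `SimpleNamespace` stand-in with toy callables instead of a real loaded model so that it runs without TensorFlow.

```python
import types


# Hypothetical module-level versions of the two halves; they only rely on the
# `dense1` / `dense2` attributes, which survive loading according to the question.
def model_part_1(self, inputs):
    return self.dense1(inputs)


def model_part_2(self, inputs):
    return self.dense2(inputs)


def attach_part_methods(model):
    """Hypothetical helper: bind both halves onto an existing object."""
    model.model_part_1 = types.MethodType(model_part_1, model)
    model.model_part_2 = types.MethodType(model_part_2, model)
    return model


if __name__ == "__main__":
    # Stand-in for the loaded model: any object with callable dense1/dense2.
    stand_in = types.SimpleNamespace(dense1=lambda x: x + 1, dense2=lambda x: x * 2)
    attach_part_methods(stand_in)
    print(stand_in.model_part_2(stand_in.model_part_1(3)))  # 8
```

With a real model this would presumably be `attach_part_methods(tf.keras.models.load_model("saved_model_path"))`; that call has not been verified against every Keras version, and the methods have to be re-attached after every load.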
I was able to solve this by decorating the functions with tf.function:
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        # call() uses the plain (undecorated) implementations
        intermediate_value = self._model_part_1(inputs)
        final_output = self._model_part_2(intermediate_value)
        return final_output

    def _model_part_1(self, inputs):
        x = self.dense1(inputs)
        return x

    def _model_part_2(self, inputs):
        x = self.dense2(inputs)
        return x

    # Decorated wrappers with a fixed input signature, so they can be
    # saved along with the model and called after loading
    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_1(self, inputs):
        """tf.function-decorated version of _model_part_1"""
        return self._model_part_1(inputs)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_2(self, inputs):
        """tf.function-decorated version of _model_part_2"""
        return self._model_part_2(inputs)
After saving with the `.save()` method and loading with `tf.keras.models.load_model`, the decorated methods are available.
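The mechanism this relies on is that SavedModel serializes `tf.function` methods that have an `input_signature`, and the restored object exposes them as attributes. Here is a minimal sketch of that round trip. It swaps in a hypothetical `SplitModule` built on `tf.Module` with raw `tf.Variable` weights for the Keras model, and uses `tf.saved_model.save`/`tf.saved_model.load` instead of `model.save`/`load_model`, to sidestep differences between Keras versions in how `model.save` handles the SavedModel format.

```python
import tempfile

import numpy as np
import tensorflow as tf

SIGNATURE = [tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]


class SplitModule(tf.Module):
    """Hypothetical stand-in for MyModel: two weight matrices instead of Dense layers."""

    def __init__(self):
        super().__init__()
        self.w1 = tf.Variable(tf.random.normal([5, 5]), name="w1")
        self.w2 = tf.Variable(tf.random.normal([5, 5]), name="w2")

    @tf.function(input_signature=SIGNATURE)
    def model_part_1(self, inputs):
        return tf.nn.relu(tf.matmul(inputs, self.w1))

    @tf.function(input_signature=SIGNATURE)
    def model_part_2(self, inputs):
        return tf.nn.softmax(tf.matmul(inputs, self.w2))


module = SplitModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)    # traces both signatures and saves them
restored = tf.saved_model.load(export_dir)

x = tf.random.normal([3, 5])
hidden = restored.model_part_1(x)          # first half, after loading
output = restored.model_part_2(hidden)     # second half, after loading

print(np.allclose(hidden.numpy(), module.model_part_1(x).numpy()))
```

Since the signature is fixed at `(None, 5)`, the restored functions accept any batch size but only 5 input features; a different feature width would need a different `TensorSpec`.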
Note that I created new, decorated wrapper methods instead of decorating `_model_part_1` and `_model_part_2` directly, because calling a decorated function inside the `call` method caused errors.