Reputation: 6377
I have a custom model with a dynamic input shape (the second dimension is flexible).
I need to save it in the SavedModel format, but only one signature is saved (the one for the first input shape I used).
When I call the loaded model with an input of a different shape, I get an error:
Python inputs incompatible with input_signature
My code is as follows:
import tensorflow as tf

seq_len = 2
batch_size = 3

class CustomModule(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.v = tf.Variable(1.)

    #@tf.function
    def call(self, x):
        return x * self.v

module_output = CustomModule()
input1 = tf.random.uniform([batch_size, seq_len], dtype=tf.float32)
input2 = tf.random.uniform([batch_size, seq_len + 1], dtype=tf.float32)
output1 = tf.random.uniform([batch_size, seq_len], dtype=tf.float32)
output2 = tf.random.uniform([batch_size, seq_len + 1], dtype=tf.float32)
optimizer = tf.keras.optimizers.SGD()
training_loss = tf.keras.losses.MeanSquaredError()
module_output.compile(optimizer=optimizer, loss=training_loss)
#hist = module_output.fit(input1, output1, epochs=1, steps_per_epoch=1, verbose=0)
#hist = module_output.fit(input2, output2, epochs=1, steps_per_epoch=1, verbose=0)
a = module_output(input1)  # the first signature
a = module_output(input2)  # the second signature
module_output.save('savedModel/', overwrite=True, include_optimizer=False)
module_output = tf.keras.models.load_model('savedModel/')
a = module_output(input1)  # <= it works
a = module_output(input2)  # <= the error is here
How can I make it work?
Edit: This is a toy example. I cannot build the model with the functional API because the real model is too complicated.
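To see where this kind of error comes from, here is a minimal sketch that swaps the Keras model for a plain tf.Module (an assumption made so the example does not depend on how a given Keras version saves subclassed models). A tf.function without an input_signature traces a separate concrete function per input shape, and tf.saved_model.save only keeps the traces that exist at save time, so the reloaded object rejects a shape that was never traced. The class name and shapes are illustrative only.

```python
import tempfile

import tensorflow as tf


class CustomModule(tf.Module):
    """Hypothetical tf.Module stand-in for the question's Keras model."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(1.0)

    # No input_signature: tf.function traces a separate concrete function
    # for every new input shape it is called with.
    @tf.function
    def __call__(self, x):
        return x * self.v


module = CustomModule()
module(tf.random.uniform([3, 2]))  # only the [3, 2] shape is ever traced

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)  # saves the traced concrete function(s)
loaded = tf.saved_model.load(export_dir)

loaded(tf.random.uniform([3, 2]))  # matches the saved trace, so this works

try:
    loaded(tf.random.uniform([3, 3]))  # no saved trace accepts [3, 3]
    failed = False
except ValueError as err:
    failed = True
    print("Call with the new shape failed:", err)
```

The Keras error text in the question differs, but the underlying cause is the same: the saved function only accepts the input spec it was saved with.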
Upvotes: 1
Views: 1917
Reputation: 92
You can manually specify the input shape and dtype by decorating the call function with @tf.function(input_signature=...). Use None for every dimension that may vary.
In the example you show, you can decorate the call function as follows:
@tf.function(input_signature=[tf.TensorSpec(shape=(batch_size, None), dtype=tf.float32)])
def call(self, x):
    return x * self.v
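A self-contained way to check this idea end to end is sketched below. It again uses a plain tf.Module instead of the Keras model (an assumption, to sidestep differences between Keras versions in how subclassed models are saved), and it also leaves the batch dimension as None. Because the input_signature has None in the second dimension, a single concrete function is traced, and the reloaded SavedModel accepts any sequence length.

```python
import tempfile

import tensorflow as tf


class FlexibleModule(tf.Module):
    """Hypothetical tf.Module stand-in for the question's Keras model."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(1.0)

    # None marks a dimension as variable, so one traced function
    # (and therefore one saved signature) covers every sequence length.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.v


module = FlexibleModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)  # traces using the input_signature

loaded = tf.saved_model.load(export_dir)
out_a = loaded(tf.random.uniform([3, 2]))  # seq_len = 2
out_b = loaded(tf.random.uniform([3, 3]))  # seq_len = 3, same saved function
print(out_a.shape, out_b.shape)
```

With an input_signature, no call is needed before saving: tf.saved_model.save traces the function from the declared TensorSpec.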
Upvotes: 0
Reputation: 2751
Try building the model with the functional API instead, declaring the variable dimension as None in the Input shape:
def create_model(batch_size, seq_len):
    # shape excludes the batch dimension; pass seq_len=None to keep the sequence length flexible
    inputs = tf.keras.Input(shape=(seq_len,), batch_size=batch_size)  # input layer
    x = tf.keras.layers...(inputs)  # next layer
    x = tf.keras.layers...(x)
    ...
    outputs = tf.keras.layers...(x)  # output layer
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(...)
    return model
Since your class already inherits from tf.keras.Model, this should work if you replace the model construction line (module_output = CustomModule()) with a call to create_model.
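A minimal runnable version of the functional approach is sketched below. Scale is a hypothetical custom layer standing in for the toy model's multiply-by-a-variable logic. The key detail is that tf.keras.Input takes the shape without the batch dimension, so shape=(None,) leaves the sequence length open. The sketch only checks that the model accepts both lengths; it does not save the model, because the saving API differs between Keras 2 and Keras 3.

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical layer: multiplies its input by one trainable scalar."""

    def build(self, input_shape):
        self.v = self.add_weight(name="v", shape=(), initializer="ones")

    def call(self, x):
        return x * self.v


# Batch dimension is implicit; None leaves the sequence length flexible.
inputs = tf.keras.Input(shape=(None,))
outputs = Scale()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

out_a = model(tf.random.uniform([3, 2]))  # seq_len = 2
out_b = model(tf.random.uniform([3, 3]))  # seq_len = 3
print(out_a.shape, out_b.shape)
```

Wrapping the real model's internals in custom layers like this is one way to keep most of the subclassing code while still getting a flexible Input declaration.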
Upvotes: 1