I made a custom layer named Channel:
class Channel(Layer):
    def __init__(self, snr):
        ...
    def call(self, inputs):
        ...
    def get_snr(self):
        return self.snr
    def set_snr(self, snr):
        self.snr = snr
I used it in a Keras functional model, trained the model, saved it, and then loaded it in a new notebook. The problem is that the loaded Channel layer has no get_snr or set_snr methods (or even an snr attribute). How can I change the snr of my Channel layer? Even replacing the old Channel layer with a new one (with the proper snr) is not easy, because the model is functional and I would have to parse and reconstruct the graph.
Notebook 1:
from tensorflow import keras
from tensorflow.keras.layers import *
class Channel(Layer):
    def __init__(self, snr):
        super(Channel, self).__init__()
        self.set_snr(snr)
    def call(self, inputs):
        return inputs + self.snr  # Assume just a simple add
    def get_snr(self):
        return self.snr
    def set_snr(self, snr):
        self.snr = snr
channel = Channel(10)
model_in = Input((32, 32, 3))
x = Conv2D(16, 3)(model_in)
x = channel(x)
x = Conv2D(16, 3)(x)
model = keras.models.Model(model_in, x)
model.save("/content/model")
Notebook 2:
from tensorflow import keras
model = keras.models.load_model("/content/model")
channel = model.layers[2]
channel.set_snr(20) # Error: 'Channel' object has no attribute 'set_snr'
print(channel.get_snr()) # Error: 'Channel' object has no attribute 'get_snr'
Upvotes: 0
Views: 180
Because Channel is a custom layer, you need to implement get_config on it and pass the custom_objects argument to load_model so that the model is restored with your class (see the docs). Here is the full working code:
from tensorflow import keras
from tensorflow.keras.layers import *
class Channel(Layer):
    def __init__(self, snr, **kwargs):
        super(Channel, self).__init__(**kwargs)
        self.snr = snr
        self.set_snr(self.snr)
    def call(self, inputs):
        return inputs + self.snr
    def get_snr(self):
        return self.snr
    def set_snr(self, snr):
        self.snr = snr
    def get_config(self):
        config = super(Channel, self).get_config()
        config.update(
            {
                "snr": self.snr
            }
        )
        return config
model_in = Input((32, 32, 3))
x = Conv2D(16, 3)(model_in)
x = Channel(10)(x)
x = Conv2D(16, 3)(x)
model = keras.models.Model(model_in, x)
Then save the model and reload it with the custom_objects argument (check out this doc):
model.save("/content/model")
new_model = keras.models.load_model("/content/model",
custom_objects={"Channel": Channel})
channel = new_model.layers[2]
channel.set_snr(20)
print(channel.get_snr())
channel.set_snr(200)
print(channel.get_snr())
channel.set_snr(2000)
print(channel.get_snr())
INFO:tensorflow:Assets written to: /content/model/assets
WARNING:tensorflow:No training configuration found in save file,
so the model was *not* compiled. Compile it manually.
20
200
2000
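The role of get_config can be checked without saving anything to disk. The sketch below is only an illustration: it assumes the default Layer.from_config behaviour of calling the constructor with the config dictionary, and the layer name "channel" is chosen here just for the example. If "snr" were missing from the config, the constructor could not be called with the saved arguments.

```python
from tensorflow import keras


class Channel(keras.layers.Layer):
    def __init__(self, snr, **kwargs):
        super().__init__(**kwargs)
        self.snr = snr

    def call(self, inputs):
        return inputs + self.snr  # simple add, as in the question

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our own argument.
        config = super().get_config()
        config.update({"snr": self.snr})
        return config


original = Channel(10, name="channel")  # the name is an arbitrary example value
config = original.get_config()

# The default from_config calls Channel(**config), so "snr" must be in the config.
rebuilt = Channel.from_config(config)
print(config["snr"], rebuilt.snr, rebuilt.name)
```

The rebuilt object is a real Channel instance, which is why set_snr and get_snr are available once the class is supplied through custom_objects at load time.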
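Another way is to save only the model weights and load them later into a model built from the same code. The layers of that model are created from your own Channel class, so set_snr and get_snr are available: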
model.save_weights('sw.h5')
new_model = keras.models.Model(model_in, x)
new_model.load_weights('sw.h5')
channel = new_model.layers[2]
channel.set_snr(50)
print(channel.get_snr())
channel.set_snr(500)
print(channel.get_snr())
channel.set_snr(5000)
print(channel.get_snr())
50
500
5000
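In a fresh notebook the original model_in and x tensors no longer exist, so the architecture has to be rebuilt before the weights are loaded. A minimal sketch of that, assuming a hypothetical build_model helper (not part of the original code) and a temporary file ending in .weights.h5, a suffix that newer Keras releases expect for save_weights:

```python
import os
import tempfile

import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Input, Layer


class Channel(Layer):
    def __init__(self, snr, **kwargs):
        super().__init__(**kwargs)
        self.snr = snr

    def call(self, inputs):
        return inputs + self.snr

    def get_snr(self):
        return self.snr

    def set_snr(self, snr):
        self.snr = snr


def build_model(snr):
    # Hypothetical helper: rebuilds the same architecture as in the question.
    model_in = Input((32, 32, 3))
    x = Conv2D(16, 3)(model_in)
    x = Channel(snr)(x)
    x = Conv2D(16, 3)(x)
    return keras.models.Model(model_in, x)


# "Notebook 1": build (and normally train) the model, then save only the weights.
weights_path = os.path.join(tempfile.mkdtemp(), "sw.weights.h5")
model = build_model(10)
model.save_weights(weights_path)

# "Notebook 2": rebuild with whatever snr is needed, then restore the weights.
new_model = build_model(20)
new_model.load_weights(weights_path)

channel = new_model.layers[2]
print(channel.get_snr())
```

Since snr is a plain Python attribute rather than a weight, it is not stored in the weights file; the value passed to build_model (or set later with set_snr) is the one the layer holds.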
Upvotes: 2