amin

Keras custom layer methods not reachable after loading the model

I made a custom layer named Channel:

class Channel(Layer):
  def __init__(self, snr):
    ...
  
  def call(self, inputs):
    ...
  
  def get_snr(self):
    return self.snr

  def set_snr(self, snr):
    self.snr = snr

I used it in a Keras functional model, trained the model, saved it, and then loaded it in a new notebook. The problem is that the loaded Channel layer no longer has the get_snr and set_snr methods (or even the snr attribute). How can I change the SNR of the channel layer? Replacing the old Channel layer with a new one that has the desired SNR is not easy either, because the model is functional and I would have to parse and reconstruct the graph.


Simple complete example

Notebook 1:

from tensorflow import keras
from tensorflow.keras.layers import *

class Channel(Layer):
  def __init__(self, snr):
    super(Channel, self).__init__()
    self.set_snr(snr)

  def call(self, inputs):
    return inputs + self.snr  # Assume just a simple add

  def get_snr(self):
    return self.snr

  def set_snr(self, snr):
    self.snr = snr

channel = Channel(10)

model_in = Input((32, 32, 3))
x = Conv2D(16, 3)(model_in)
x = channel(x)
x = Conv2D(16, 3)(x)
model = keras.models.Model(model_in, x)

model.save("/content/model")

Notebook 2:

from tensorflow import keras

model = keras.models.load_model("/content/model")
channel = model.layers[2]

channel.set_snr(20)  # Error: 'Channel' object has no attribute 'set_snr'
print(channel.get_snr())  # Error: 'Channel' object has no attribute 'get_snr'


Answers (1)

Innat

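Because Channel is a custom layer, Keras cannot rebuild it from the saved file on its own. Without a config and the class itself, it falls back to reviving the layer from the traced functions stored in the SavedModel, which keeps the computation but not your Python methods or attributes such as snr. To load it back as a real Channel, implement get_config (and let __init__ accept **kwargs), then pass the class through the custom_objects argument of load_model. Here is the full working code: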

from tensorflow import keras
from tensorflow.keras.layers import *

class Channel(Layer):
    def __init__(self, snr, **kwargs):
        # Forward name, trainable, dtype, ... so the layer can be rebuilt from its config
        super(Channel, self).__init__(**kwargs)
        self.set_snr(snr)

    def call(self, inputs):
        return inputs + self.snr

    def get_snr(self):
        return self.snr

    def set_snr(self, snr):
        self.snr = snr

    def get_config(self):
        # Add snr to the base config so it is stored with the model
        config = super(Channel, self).get_config()
        config.update({"snr": self.snr})
        return config

model_in = Input((32, 32, 3))
x = Conv2D(16, 3)(model_in)
x = Channel(10)(x)
x = Conv2D(16, 3)(x)
model = keras.models.Model(model_in, x)
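A quick way to check that get_config does its job, without saving anything to disk, is to round-trip the layer through its config. Keras rebuilds a layer with the from_config classmethod, whose default implementation calls the constructor with the config dict; that is why __init__ has to accept **kwargs, since the base config also carries keys such as name and dtype. This is a minimal sketch; the class is repeated so the snippet runs on its own, and the layer name "channel" is chosen only for illustration.

```python
from tensorflow.keras.layers import Layer


class Channel(Layer):
    def __init__(self, snr, **kwargs):
        # **kwargs receives name, trainable, dtype, ... from the base config
        super().__init__(**kwargs)
        self.set_snr(snr)

    def call(self, inputs):
        return inputs + self.snr

    def get_snr(self):
        return self.snr

    def set_snr(self, snr):
        self.snr = snr

    def get_config(self):
        config = super().get_config()
        config.update({"snr": self.snr})
        return config


layer = Channel(10, name="channel")
config = layer.get_config()          # plain dict that now includes "snr": 10
clone = Channel.from_config(config)  # default from_config calls Channel(**config)

print(clone.get_snr())  # 10
```

If get_config did not add "snr", the from_config call would fail with a missing-argument error, which is a cheap way to catch the problem before saving a trained model.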

Then save the model and reload it, passing the class through the custom_objects argument:

model.save("/content/model")
new_model = keras.models.load_model("/content/model",
                                    custom_objects={"Channel": Channel})
channel = new_model.layers[2]

channel.set_snr(20) 
print(channel.get_snr())  

channel.set_snr(200) 
print(channel.get_snr())

channel.set_snr(2000) 
print(channel.get_snr())

INFO:tensorflow:Assets written to: /content/model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
20
200
2000
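Two details in the snippet above depend on the environment. /content/model is a Colab path saved in the TensorFlow SavedModel format, and on Keras 3 (the default tf.keras since TensorFlow 2.16) model.save expects a filename ending in .keras or .h5 instead. Also, model.layers[2] relies on the layer's position in the model. A sketch that avoids both, assuming a temporary directory and a layer name "channel" picked for illustration, could look like this:

```python
import os
import tempfile

from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Input, Layer


class Channel(Layer):
    def __init__(self, snr, **kwargs):
        super().__init__(**kwargs)
        self.set_snr(snr)

    def call(self, inputs):
        return inputs + self.snr

    def get_snr(self):
        return self.snr

    def set_snr(self, snr):
        self.snr = snr

    def get_config(self):
        config = super().get_config()
        config.update({"snr": self.snr})
        return config


model_in = Input((32, 32, 3))
x = Conv2D(16, 3)(model_in)
x = Channel(10, name="channel")(x)  # explicit name so the layer can be found later
x = Conv2D(16, 3)(x)
model = keras.models.Model(model_in, x)

# A ".keras" filename is accepted by both older tf.keras and Keras 3
path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)

new_model = keras.models.load_model(path, custom_objects={"Channel": Channel})
channel = new_model.get_layer("channel")  # look the layer up by name, not by index
channel.set_snr(20)
print(channel.get_snr())  # 20
```

Looking the layer up by name keeps working if you later add or remove layers in front of Channel.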

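Another way is to save only the weights and rebuild the model from code before loading them. Because the rebuilt model uses your Channel class directly, the layer keeps its methods. Note that snr is a plain Python attribute, not a weight, so its value comes from the constructor call rather than from the weights file.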

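# The ".weights.h5" suffix is HDF5 for tf.keras 2.x and also what Keras 3 requires
model.save_weights('sw.weights.h5')

# Rebuild the architecture with fresh layers (e.g. in a new notebook, after
# re-running the imports and the Channel class definition), then load the weights
new_in = Input((32, 32, 3))
y = Conv2D(16, 3)(new_in)
y = Channel(10)(y)
y = Conv2D(16, 3)(y)
new_model = keras.models.Model(new_in, y)
new_model.load_weights('sw.weights.h5')
channel = new_model.layers[2]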

channel.set_snr(50) 
print(channel.get_snr())  

channel.set_snr(500) 
print(channel.get_snr())

channel.set_snr(5000) 
print(channel.get_snr())

50
500
5000
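One caveat that applies to both approaches: snr is a plain Python number, so once Keras has traced call into a graph (model.predict, for example, builds and caches a tf.function on its first call), that graph may keep the old value, and a later set_snr call will not change the predictions. A sketch of one way around this, assuming you are free to change the layer, keeps the SNR in a non-trainable weight created with add_weight and updated with assign; the attribute name snr_var is a hypothetical choice, not part of the code above.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Input, Layer


class Channel(Layer):
    """Sketch: keep snr in a non-trainable weight instead of a Python attribute."""

    def __init__(self, snr, **kwargs):
        super().__init__(**kwargs)
        # Scalar non-trainable weight; it is also stored by save_weights()
        self.snr_var = self.add_weight(
            name="snr",
            shape=(),
            initializer=keras.initializers.Constant(float(snr)),
            trainable=False,
        )

    def call(self, inputs):
        # The weight is read when the graph runs, so later assign() calls are seen
        return inputs + self.snr_var

    def get_snr(self):
        return float(self.snr_var.numpy())

    def set_snr(self, snr):
        self.snr_var.assign(float(snr))

    def get_config(self):
        config = super().get_config()
        config.update({"snr": self.get_snr()})
        return config


model_in = Input((4, 4, 1))
x = Conv2D(2, 3)(model_in)
channel = Channel(10)
x = channel(x)
model = keras.models.Model(model_in, x)

# Zero input and zero-initialized Conv2D bias: the output is just the SNR offset
data = np.zeros((1, 4, 4, 1), dtype="float32")
before = model.predict(data, verbose=0)  # traces and caches the predict function
channel.set_snr(20)
after = model.predict(data, verbose=0)   # same cached function, new SNR value
print(before.max(), after.max())
```

The trade-off is that the SNR now lives in the model's weights, so it is restored by load_weights as well, instead of coming only from the constructor argument.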
