GOAL
I'm trying to create a custom model in TensorFlow using model subclassing. My goal is to create a model with some custom attributes, train it, save it, and, after loading it, read the values of those custom attributes back from the model.
I've been looking for a solution online, but I found nothing about this problem.
ISSUE
I've created a test custom model class with a self.custom_att attribute, which is a list. I trained the model on random data, saved it and loaded it again. After loading, the attribute still exists on the model object, but it has been turned into a ListWrapper object and it's empty.
QUESTION
How can I store this attribute so that it keeps the values it had before saving after the model is loaded again?
CODE
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
import numpy as np
from tensorflow.keras.models import load_model


class CustomModel(Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.in_dense = Dense(10, activation='relu')
        self.dense = Dense(30, activation='relu')
        self.out = Dense(3, activation='softmax')
        self.custom_att = ['custom_att1', 'custom_att2']  # <- this attribute I want to store

    def call(self, inputs, training=None, mask=None):
        x = self.in_dense(inputs)
        x = self.dense(x)
        x = self.out(x)
        return x

    def get_config(self):
        base_config = super(CustomModel, self).get_config()
        return {**base_config, 'custom_att': self.custom_att}


X = np.random.random((1000, 5))
y = np.random.random((1000, 3))

model = CustomModel()
model.build((1, 5))
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
model.summary()
history = model.fit(X, y, epochs=1, validation_split=0.1)
model.save('models/testModel.model')
del model

model = load_model('models/testModel.model', custom_objects={'CustomModel': CustomModel})  # <- here the attribute becomes ListWrapper([])
print(model.custom_att)
ENVIRONMENT
ANSWER
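I do not think a plain Python list will survive saving and loading your model. The model wraps the list in a ListWrapper so it can track it, but only trackable objects such as layers and variables have their state written to the SavedModel, so the strings themselves are most likely dropped and the reloaded model ends up with an empty ListWrapper. Replace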
self.custom_att = ['custom_att1', 'custom_att2']
with a string tf.Variable (this also requires adding import tensorflow as tf to the imports):
self.custom_att = tf.Variable(['custom_att1', 'custom_att2'])
After reloading the model, you should see something like this:
print(model.custom_att.numpy())
# [b'custom_att1' b'custom_att2']
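The save/load round trip of a string variable can be checked in isolation. The sketch below is a simplified stand-in, not the question's full model: it assumes TensorFlow 2.x, uses a hypothetical AttributeHolder class built on tf.Module (the base class of tf.keras layers and models in Keras 2), and calls tf.saved_model.save and tf.saved_model.load instead of model.save and load_model. The trainable=False argument is an extra assumption, not part of the suggestion above; it keeps the string variable out of the trainable variables, since a string cannot receive gradients.

```python
import os
import tempfile

import tensorflow as tf


class AttributeHolder(tf.Module):
    """Hypothetical stand-in for CustomModel that keeps only the attribute."""

    def __init__(self):
        super().__init__()
        # A tf.Variable is trackable, so SavedModel stores its value
        # and restores it on load. trainable=False (an assumption here)
        # keeps it out of trainable_variables.
        self.custom_att = tf.Variable(['custom_att1', 'custom_att2'], trainable=False)


holder = AttributeHolder()
export_dir = os.path.join(tempfile.mkdtemp(), 'attribute_holder')
tf.saved_model.save(holder, export_dir)

restored = tf.saved_model.load(export_dir)
# .numpy() on a string variable returns an array of bytes objects,
# so each element is decoded back into a str.
restored_names = [s.decode('utf-8') for s in restored.custom_att.numpy()]
print(restored_names)
```

With a full Keras model the same idea applies: the variable is stored because it is tracked as part of the model, while the plain list's contents are not.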
The b prefix means each element is a bytes object. You can decode it to get a regular string like this:
print(model.custom_att.numpy()[0].decode("utf-8"))
# custom_att1
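To decode every element at once rather than only the first one, a list comprehension over the array is enough. This minimal sketch assumes, as the output above suggests, that model.custom_att.numpy() returns an array of Python bytes objects; the hard-coded array below is only a stand-in for that call.

```python
import numpy as np

# Stand-in for model.custom_att.numpy(): an object array of bytes values
raw = np.array([b'custom_att1', b'custom_att2'], dtype=object)

# Decode each bytes element into a regular str
custom_att = [item.decode('utf-8') for item in raw]
print(custom_att)
```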
See @Ivan's comment for further improvements.