MASTER OF CODE

Reputation: 302

How to save custom attributes with a custom model in TensorFlow?

GOAL

I'm trying to create a custom model in TensorFlow using the subclassing approach. My goal is to create a model with some custom attributes, train it, save it, and, after loading it, get the values of the custom attributes back from the model.

I've searched online for a solution, but I found nothing about this problem.

ISSUE

I've created a test custom model class with a self.custom_att attribute, which is a list. I trained it on random data, saved it and loaded it. After loading, the attribute still exists on the model object, but it has been converted to a ListWrapper object and is empty.

QUESTION

How can I store this attribute so that the values it had before saving are still there after loading?

CODE

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
import numpy as np
from tensorflow.keras.models import load_model


class CustomModel(Model):

    def __init__(self):
        super(CustomModel, self).__init__()
        self.in_dense = Dense(10, activation='relu')
        self.dense = Dense(30, activation='relu')
        self.out = Dense(3, activation='softmax')
        self.custom_att = ['custom_att1', 'custom_att2'] # <- this attribute I want to store

    def call(self, inputs, training=None, mask=None):
        x = self.in_dense(inputs)
        x = self.dense(x)
        x = self.out(x)
        return x

    def get_config(self):
        base_config = super(CustomModel, self).get_config()
        return {**base_config, 'custom_att': self.custom_att}


X = np.random.random((1000, 5))
y = np.random.random((1000, 3))

model = CustomModel()
model.build((1, 5))
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
model.summary()
history = model.fit(X, y, epochs=1, validation_split=0.1)
model.save('models/testModel.model')

del model

model = load_model('models/testModel.model', custom_objects={'CustomModel': CustomModel}) # <- here attribute becomes ListWrapper([])
print(model.custom_att)

ENVIRONMENT

Upvotes: 2

Views: 1049

Answers (1)

AloneTogether

Reputation: 26718

I do not think a plain Python list in that attribute will survive saving and loading your model. Replace

self.custom_att = ['custom_att1', 'custom_att2']

with

self.custom_att = tf.Variable(['custom_att1', 'custom_att2'])

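After adding import tensorflow as tf (the question's code does not import it) and then saving and loading the model, you should see something like this: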

print(model.custom_att.numpy())
# [b'custom_att1' b'custom_att2']

The values come back as byte strings. You can remove the b prefix by decoding them, like this:

print(model.custom_att.numpy()[0].decode("utf-8"))
# custom_att1
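The line above decodes only the first element. To get every value back as a regular string, you can decode each element in turn. The sketch below assumes that model.custom_att.numpy() returns an array of byte strings, as shown in the output above. The NumPy elements are numpy.bytes_, a subclass of bytes, so .decode works on them. A plain list of bytes stands in for that array here so the example runs without TensorFlow, and decode_strings is a hypothetical helper name.

```python
from typing import Iterable, List


def decode_strings(values: Iterable[bytes], encoding: str = "utf-8") -> List[str]:
    """Decode every byte string in `values` into a regular Python str.

    Intended for the result of model.custom_att.numpy(); any iterable of
    bytes objects (including NumPy's bytes_ elements) works the same way.
    """
    return [v.decode(encoding) for v in values]


# Stand-in for model.custom_att.numpy() (assumption: same byte-string contents).
raw = [b"custom_att1", b"custom_att2"]
print(decode_strings(raw))
# ['custom_att1', 'custom_att2']
```

With the real model, this would be called as decode_strings(model.custom_att.numpy()).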

See @Ivan's comment for further improvements.

Upvotes: 4
