Gilfoyle

Reputation: 3616

Extracting weights from model in Tensorflow 2.0

I use Tensorflow 2.0 and want to extract all weights and biases from a trained model. Here is what I did so far:

I create a model class:

from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Flatten, Dense

class MyModel(Model):

    def __init__(self):
        super(MyModel, self).__init__() # MyModel inherits from the Keras Model base class
        self.conv1 = Conv2D(filters=32, kernel_size=3, strides=[2,2], activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=64, activation="relu")
        self.d2 = Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.d2(x)
        return x

During and after the training, I save my model:

checkpoint_path = "./logs/model.ckpt"
checkpoint_dir = "./logs/"
self.model.save_weights(checkpoint_path)
self.model.save(checkpoint_dir)

At this point I already ask myself: how do I save the model correctly? Do I use save_weights or just save? I want to be able to

  1. retrain the model if necessary
  2. extract the model's weights for further analysis

Currently I load my trained model (in a new file) by doing:

model = MyModel()
model.load_weights(checkpoint_path)

But how can I access the network's weights? I already tried tf.compat.v1.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), which did not work.

I highly appreciate any help!

Upvotes: 2

Views: 1630

Answers (1)

EyesBear

Reputation: 1466

First, the difference between the two saving methods:

  • model.save_weights(): saves only the weights. You still need the model code to rebuild the architecture as model = MyModel() with initial weights, and then replace those weights via .load_weights().
  • model.save(): saves the whole model, including the architecture, optimizer state, and weights. So you can reproduce the entire model without the code that defines MyModel().
  • By the way, another option in TF2 is a checkpoint manager (tf.train.CheckpointManager). In your case, I would go with .save_weights() or the checkpoint manager; see the sketch after this list.
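
To make this concrete, here is a minimal sketch of all three options (the checkpoint_path and checkpoint_dir values are taken from your question, and m is assumed to be a built model instance):

import tensorflow as tf

checkpoint_path = "./logs/model.ckpt"
checkpoint_dir = "./logs/"

# Option 1: weights only -- MyModel() is needed again to restore
m.save_weights(checkpoint_path)
# restore later: m = MyModel(); m.build(input_shape); m.load_weights(checkpoint_path)

# Option 2: whole model (architecture + optimizer state + weights)
m.save(checkpoint_dir)
# restore later: m = tf.keras.models.load_model(checkpoint_dir)

# Option 3: checkpoint manager, keeping e.g. the 3 most recent checkpoints
ckpt = tf.train.Checkpoint(model=m)
manager = tf.train.CheckpointManager(ckpt, directory=checkpoint_dir, max_to_keep=3)
manager.save()                            # call this during/after training
ckpt.restore(manager.latest_checkpoint)   # restore the newest checkpoint later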

Next, you can analyze the weights by:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Dense

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__() # MyModel inherits from the Keras Model base class
        self.conv1 = Conv2D(filters=32, kernel_size=3, strides=[2,2], activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=64, activation="relu")
        self.d2 = Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.d2(x)
        return x

m = MyModel()
input_shape = tf.TensorShape([None,64,64,1]) # For example, 64x64 grayscale images with an arbitrary batch size
m.build(input_shape) 

# Train the model (omitted here)

# Save weights
m.save_weights("./logs/model.ckpt")

# Load weights
m.load_weights("./logs/model.ckpt")

# Analyze weights
conv1_weights, conv1_bias = m.conv1.weights
d1_weights, d1_bias = m.d1.weights
d2_weights, d2_bias = m.d2.weights
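
If you want all parameters at once instead of going layer by layer, the TF2 replacement for the old graph collections is the model's own attributes, e.g. m.weights, m.trainable_variables or m.get_weights(). A small usage sketch (assuming m is the built model from above):

# TF2 equivalent of tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in m.trainable_variables:
    print(var.name, var.shape)

# Concrete values as NumPy arrays for further analysis
conv1_kernel = m.conv1.weights[0].numpy()
all_values = m.get_weights()   # list of NumPy arrays, in layer order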

Upvotes: 1
