I use TensorFlow 2.0 and want to extract all weights and biases from a trained model. Here is what I have done so far.
I created a model class:
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Flatten, Dense

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()  # initialize the Keras base class Model
        self.conv1 = Conv2D(filters=32, kernel_size=3, strides=[2,2], activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=64, activation="relu")
        self.d2 = Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.d2(x)
        return x
During and after the training, I save my model:
checkpoint_path = "./logs/model.ckpt"
checkpoint_dir = "./logs/"
self.model.save_weights(checkpoint_path)
self.model.save(checkpoint_dir)
At this point I'm already unsure how to save the model correctly: should I use save_weights or just save? I want to be able to …
Currently I load my trained model (in a new file) by doing:
model = MyModel()
model.load_weights(checkpoint_path)
But how can I access the network's weights? I already tried tf.compat.v1.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), which did not work.
I highly appreciate any help!
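Firstly, the difference between the two saving methods: save_weights() stores only the values of the layer variables (the weights and biases), while save() stores the whole model, i.e. its architecture together with the weights and the optimizer state. If you only save the weights, you restore them by first creating the model with
model = MyModel()
which gives it initial weights, and then replacing those weights with .load_weights().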
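A minimal sketch of that save_weights()/load_weights() round trip, using the question's architecture and a dummy batch of 28x28 single-channel inputs (an assumption; any valid input shape works). The helper save_and_restore and the file name model.weights.h5 are hypothetical. The .weights.h5 suffix is chosen because tf.keras in TF 2.x then writes HDF5 and Keras 3 requires it; with a path like the question's ./logs/model.ckpt, TF 2.x writes the TensorFlow checkpoint format instead.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MyModel(tf.keras.Model):
    """Same architecture as in the question."""

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation="relu")
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(units=64, activation="relu")
        self.d2 = tf.keras.layers.Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


def save_and_restore(input_shape=(1, 28, 28, 1)):
    """Save the weights of one instance and load them into a fresh one (hypothetical helper)."""
    dummy = tf.zeros(input_shape)

    # Stand-in for the trained model; calling it once creates its variables.
    trained = MyModel()
    trained(dummy)

    # Hypothetical file name in a throwaway directory.
    path = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
    trained.save_weights(path)

    # A fresh instance starts with its own random initial weights...
    restored = MyModel()
    restored(dummy)  # HDF5 weights can only be loaded into a model whose variables exist
    # ...which load_weights() then replaces with the saved values.
    restored.load_weights(path)
    return trained, restored


trained, restored = save_and_restore()
for saved, loaded in zip(trained.get_weights(), restored.get_weights()):
    assert np.array_equal(saved, loaded)  # restored weights match the saved ones
```

Because only the variable values are stored, the class definition has to be available in the loading script, which is exactly the model = MyModel() step above.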
Next, you can inspect the weights like this:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Flatten, Dense

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()  # initialize the tf.keras.Model base class
        self.conv1 = Conv2D(filters=32, kernel_size=3, strides=[2,2], activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=64, activation="relu")
        self.d2 = Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.d2(x)
        return x
m = MyModel()
input_shape = tf.TensorShape([None, 64, 64, 1])  # for example, 64x64 images with arbitrary batch size
m.build(input_shape)
# Train
# Save weights
# Load weights
# Analyze weights: each layer's .weights is [kernel, bias] as tf.Variable objects; call .numpy() on one to get a NumPy array
conv1_weights, conv1_bias = m.conv1.weights
d1_weights, d1_bias = m.d1.weights
d2_weights, d2_bias = m.d2.weights
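To pull out every weight and bias at once instead of naming each layer, a sketch along these lines should work. collect_weights is a hypothetical helper that relies only on model.layers, layer.weights and layer.get_weights(). In TF 2, model.trainable_variables plays the role of the TRAINABLE_VARIABLES collection tried in the question, but a freshly constructed MyModel() has no variables until it has been built or called once, so the sketch calls it on a dummy 64x64 batch first (an assumed input shape, matching the build() call above).

```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    """Same architecture as in the question."""

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation="relu")
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(units=64, activation="relu")
        self.d2 = tf.keras.layers.Dense(units=10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


def collect_weights(model):
    """Return {layer name: [kernel, bias]} as NumPy arrays, skipping layers without weights (e.g. Flatten)."""
    return {layer.name: layer.get_weights() for layer in model.layers if layer.weights}


m = MyModel()
# No variables exist yet: each layer creates its kernel and bias on its first call.
print(len(m.trainable_variables))

m(tf.zeros((1, 64, 64, 1)))  # dummy batch: one 64x64 single-channel image
for name, arrays in collect_weights(m).items():
    print(name, [a.shape for a in arrays])

# TF 2 counterpart of the TRAINABLE_VARIABLES collection: a list of tf.Variable objects.
for v in m.trainable_variables:
    print(v.name, v.shape)
```

The same calls work on a model restored with load_weights(), as long as it has been built or called before the weights are read.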