Ulton Prinsphield

Reputation: 161

How to store trained weights and biases from model A for use in model B in TensorFlow

Suppose I have constructed a neural network A with a dictionary of weight and bias variables, and I have obtained specific values for the weights and biases after training network A. I want to use these values to replace the tf.Variable objects in neural network B. How can I achieve this? I have tried tf.train.Saver(), but I don't know how to restore the weights and biases in another network B (in another file). I have also tried storing them with pickle.dump, but when restoring the weights and biases with pickle.load, I got an error saying that the dictionary type is not hashable. Could anyone help me solve this problem?

Upvotes: 1

Views: 2251

Answers (1)

mrry

Reputation: 126154

The tf.train.Saver class should help with this, although you might need to use some of the optional arguments to get it to work.

Let's say your model A looks like this, and you've trained it and saved it to a file called "/tmp/model_a_ckpt":

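import tensorflow as tf

# Model A's variables; their names become the keys stored in the checkpoint.
weights_a = tf.Variable(..., name="weights_a")
biases_a = tf.Variable(..., name="biases_a")
# ...

# With no arguments, the saver covers all variables defined so far.
saver_a = tf.train.Saver()
# ...
saver_a.save(sess, "/tmp/model_a_ckpt")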

...and then let's say your model B looks like this:

weights_b = tf.Variable(..., name="weights_b")
biases_b = tf.Variable(..., name="biases_b")

To load the checkpoint into model B you have to create a saver that maps the names of variables in the checkpoint (i.e. "weights_a" and "biases_a" because they default to the name property of the corresponding tf.Variable objects in model A) to the variables in model B:

saver_b = tf.train.Saver({"weights_a": weights_b, "biases_a": biases_b})
# ...
saver_b.restore(sess, "/tmp/model_a_ckpt")

After running saver_b.restore(), your variables in model B will have the values trained in model A.
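
For reference, here is a minimal end-to-end sketch of the same idea. It assumes TensorFlow 1.x-style graph mode (reached through tensorflow.compat.v1, which newer installs also provide). The shapes, the constant "trained" values, the temporary checkpoint directory, and the use of two separate graphs to stand in for two separate files are all illustrative assumptions, not details of your actual models. It also lists the names stored in the checkpoint, which is a handy way to check what keys the saver for model B needs.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph/session semantics, as in the snippets above

# Hypothetical checkpoint location (a temp dir instead of /tmp/model_a_ckpt).
ckpt_path = os.path.join(tempfile.mkdtemp(), "model_a_ckpt")

# --- Model A: build, "train" (here: just initialise to known values), save ---
graph_a = tf.Graph()
with graph_a.as_default():
    weights_a = tf.Variable(tf.ones([2, 3]) * 0.5, name="weights_a")
    biases_a = tf.Variable(tf.ones([3]) * 0.1, name="biases_a")
    saver_a = tf.train.Saver()
    with tf.Session(graph=graph_a) as sess:
        sess.run(tf.global_variables_initializer())
        saver_a.save(sess, ckpt_path)

# Inspect which variable names the checkpoint actually contains.
saved_names = sorted(name for name, _ in tf.train.list_variables(ckpt_path))
print(saved_names)

# --- Model B: different variable names, same shapes ---
graph_b = tf.Graph()
with graph_b.as_default():
    weights_b = tf.Variable(tf.zeros([2, 3]), name="weights_b")
    biases_b = tf.Variable(tf.zeros([3]), name="biases_b")
    # Map checkpoint names (from model A) to model B's variables.
    saver_b = tf.train.Saver({"weights_a": weights_b, "biases_a": biases_b})
    with tf.Session(graph=graph_b) as sess:
        # restore() assigns every mapped variable, so no initializer is needed.
        saver_b.restore(sess, ckpt_path)
        restored_weights, restored_biases = sess.run([weights_b, biases_b])

print(np.allclose(restored_weights, 0.5), np.allclose(restored_biases, 0.1))
```

Note that the variables in model B must have the same shapes as the ones saved from model A, otherwise the restore will fail.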

Upvotes: 1
