Suppose I have constructed a neural network A with a dictionary of weight and bias variables, and after training network A I have specific values for those weights and biases. I want to use these values for the tf.Variable objects in neural network B. How can I achieve this? I have tried tf.train.Saver(), but I don't know how to restore the weights and biases in another network B (in another file). I have also tried storing them with pickle.dump, but ran into another problem: when restoring the weights and biases with pickle.load, it said the dictionary type is not hashable. Could anyone help me solve this problem?
The tf.train.Saver class should help with this, although you might need to use some of the optional arguments to get it to work.
Let's say your model A looks like this, and you've trained it and saved it to a file called "/tmp/model_a_ckpt":
weights_a = tf.Variable(..., name="weights_a")
biases_a = tf.Variable(..., name="biases_a")
# ...
saver_a = tf.train.Saver()
# ...
saver_a.save(sess, "/tmp/model_a_ckpt")
...and then let's say your model B looks like this:
weights_b = tf.Variable(..., name="weights_b")
biases_b = tf.Variable(..., name="biases_b")
To load the checkpoint into model B you have to create a saver that maps the names of variables in the checkpoint (i.e. "weights_a" and "biases_a", because they default to the name property of the corresponding tf.Variable objects in model A) to the variables in model B:
saver_b = tf.train.Saver({"weights_a": weights_b, "biases_a": biases_b})
# ...
saver_b.restore(sess, "/tmp/model_a_ckpt")
After running saver_b.restore(), your variables in model B will have the values trained in model A.
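If you would rather stay with the pickle route from the question, the same name-mapping idea can be sketched in plain Python. This is only an illustration, not TensorFlow's checkpoint format: it assumes you have already pulled the trained values out of model A as plain lists (for example with sess.run), and the helper names, file name, and numbers below are made up.

```python
import os
import pickle
import tempfile


def save_values(values_by_name, path):
    """Pickle a dict mapping variable name -> trained values."""
    with open(path, "wb") as f:
        pickle.dump(values_by_name, f)


def load_values_for_b(path, name_map):
    """Load model A's values and re-key them by model B's variable names.

    name_map maps a name stored in the file (model A) to the name used in model B.
    """
    with open(path, "rb") as f:
        values_by_name = pickle.load(f)
    return {b_name: values_by_name[a_name] for a_name, b_name in name_map.items()}


# Hypothetical trained values from model A, keyed by variable name.
# The keys are plain strings, which are hashable and safe to use as dict keys.
trained_a = {
    "weights_a": [[0.1, 0.2], [0.3, 0.4]],
    "biases_a": [0.5, 0.6],
}

# Hypothetical file location, created in a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "model_a_values.pkl")
save_values(trained_a, path)

# Map model A's names to model B's names, mirroring the dict passed to the saver.
values_b = load_values_for_b(path, {"weights_a": "weights_b", "biases_a": "biases_b"})
print(values_b["weights_b"])
print(values_b["biases_b"])
```

In model B you would then copy each loaded value into the matching variable, for instance with sess.run(weights_b.assign(values_b["weights_b"])). Python reports a dict as unhashable when a dict is used as a dictionary key or set member, so keeping the keys as plain strings, as above, should avoid that particular error.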