Konstantin Solomatov

Reputation: 10352

Copy variables from one TensorFlow graph to another

I have two TensorFlow graphs, one for training and the other for evaluation. They share many variable names. When I evaluate a model, I want to copy all variable values from the training graph to the evaluation graph. I could do this with tf.train.Saver, but that doesn't seem like an appropriate solution to me, especially because it requires a round trip through the disk.

Upvotes: 4

Views: 2427

Answers (1)

Salvador Dali

Reputation: 222929

When you speak about multiple graphs, I assume you mean something like:

import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
  # build the training model here
  pass

g2 = tf.Graph()
with g2.as_default():
  # build the evaluation model here
  pass

If so, are you sure you really need two graphs? Couldn't you have one graph made up of two components, one for training and one for evaluation?
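A minimal sketch of the single-graph alternative, assuming TF 1.x-style graph mode through `tf.compat.v1`: the training and evaluation components live in one graph and share a variable via `tf.variable_scope(..., reuse=True)`, so the evaluation side sees updates without any copying. The `model` function and the "double the weight" training step are hypothetical stand-ins for a real model and optimizer.

```python
import numpy as np
import tensorflow.compat.v1 as tf


def model(x, reuse):
    # Hypothetical linear model; reuse=True makes a second call share "model/w".
    with tf.variable_scope("model", reuse=reuse):
        w = tf.get_variable("w", shape=[3, 1], initializer=tf.ones_initializer())
        return tf.matmul(x, w), w


graph = tf.Graph()
with graph.as_default():
    train_x = tf.placeholder(tf.float32, shape=[None, 3])
    eval_x = tf.placeholder(tf.float32, shape=[None, 3])
    train_out, w = model(train_x, reuse=False)  # training component creates "model/w"
    eval_out, _ = model(eval_x, reuse=True)     # evaluation component reuses it
    # Toy stand-in for an optimizer step: double the shared weight.
    train_step = w.assign(w * 2.0)
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(train_step)
    # The evaluation component sees the updated weight directly: 2 + 2 + 2.
    result = sess.run(eval_out, feed_dict={eval_x: np.ones((1, 3), dtype=np.float32)})

print(result)  # [[6.]]
```

Because there is only one session, there is also only one set of resources being claimed, which addresses the first point below as well.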

Using multiple graphs is discouraged (p 47) because:

  • Multiple graphs require multiple sessions, and by default each session tries to use all available resources
  • You can't pass data between graphs except by going through Python/NumPy, which doesn't work in a distributed setting
  • It’s better to have disconnected subgraphs within one graph
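To illustrate the second point, a small sketch (assuming `tf.compat.v1` graph mode; the tensor names are hypothetical): an op built in one graph cannot take a tensor from another graph as an input, and TensorFlow rejects the attempt with a `ValueError`. The exact error message may differ between versions, so the sketch only records whether it was raised.

```python
import tensorflow.compat.v1 as tf

g1 = tf.Graph()
with g1.as_default():
    a = tf.constant(1.0, name="a")  # lives in g1

g2 = tf.Graph()
with g2.as_default():
    b = tf.constant(2.0, name="b")  # lives in g2
    try:
        # Combining tensors from two different graphs in one op is not allowed.
        tf.add(a, b)
        mixed_graphs_error = None
    except ValueError as err:
        mixed_graphs_error = str(err)

print(mixed_graphs_error is not None)  # True
```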

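The second point also suggests a solution for a non-distributed setting: pass the variable values through Python/NumPy, that is, fetch them in the training session and assign them in the evaluation session.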
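A minimal sketch of that NumPy round trip, assuming TF 1.x-style graph mode via `tf.compat.v1` and that both graphs give their shared variables identical names. `build_model` and `copy_train_to_eval` are hypothetical names, not TensorFlow APIs. The evaluation graph gets one placeholder and one assign op per variable, built once up front, so repeated copies don't keep adding nodes to the graph.

```python
import numpy as np
import tensorflow.compat.v1 as tf


def build_model():
    # Hypothetical toy model; both graphs create variables named "w" and "b".
    tf.get_variable("w", shape=[2, 2], initializer=tf.random_normal_initializer())
    tf.get_variable("b", shape=[2], initializer=tf.zeros_initializer())


train_graph = tf.Graph()
with train_graph.as_default():
    build_model()
    train_vars = tf.global_variables()
    train_init = tf.global_variables_initializer()

eval_graph = tf.Graph()
with eval_graph.as_default():
    build_model()
    eval_vars = tf.global_variables()
    # One placeholder + assign op per variable, created once and reused on every copy.
    assign_by_name = {}
    for v in eval_vars:
        ph = tf.placeholder(v.dtype.base_dtype, shape=v.shape)
        assign_by_name[v.op.name] = (ph, v.assign(ph))

train_sess = tf.Session(graph=train_graph)
eval_sess = tf.Session(graph=eval_graph)
train_sess.run(train_init)


def copy_train_to_eval():
    # 1. Fetch current training values as NumPy arrays (in memory, no disk).
    values = train_sess.run({v.op.name: v for v in train_vars})
    # 2. Feed them into the matching assign ops; names present in only one graph are skipped.
    shared = [name for name in values if name in assign_by_name]
    eval_sess.run(
        [assign_by_name[n][1] for n in shared],
        feed_dict={assign_by_name[n][0]: values[n] for n in shared},
    )
```

Calling `copy_train_to_eval()` before each evaluation run keeps the evaluation graph in sync. Every value still passes through the Python process, which is exactly why this approach does not carry over to a distributed setup.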

Upvotes: 4
