Reputation: 9869
I want to create a utility function where I give it a tensor t and a variable name n, and the function returns (if it exists) the variable whose name contains n and is part of the graph of t.
def get_variable(t, n):
    # code
    return variable
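For illustration, here is an untested sketch of the kind of thing I have in mind, assuming that temporarily making the graph of t the default one lets tf.all_variables() list its variables, and that returning the first match is acceptable:
import tensorflow as tf

def get_variable(t, n):
    # Untested sketch: search only the graph that t belongs to,
    # by temporarily making it the default graph.
    with t.graph.as_default():
        for v in tf.all_variables():
            if n in v.name:
                return v  # first variable whose name contains n
    return None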
The reason I want to do this is that in Jupyter I often create a graph multiple times while deciding on the best structure, but the tensor names keep changing to something like {name}_{repetition}:0, so finding them through tf.all_variables() becomes increasingly hard. It would be easier if I could limit the search to the variables related to a specific tensor, since that tensor's identifier belongs to the latest repetition.
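A minimal example of the renaming I mean, assuming a fresh default graph on the first run:
import tensorflow as tf

tensor = tf.constant(2., name='tensor')
print(tensor.name)  # 'tensor:0' on the first run

tensor = tf.constant(2., name='tensor')
print(tensor.name)  # 'tensor_1:0' after re-running the cell, and so on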
Upvotes: 1
Views: 701
Reputation: 28218
(Warning: I am not going to answer the main question about creating a utility function.)
The reason the tensors have names like {name}_{repetition}:0 is that you add these tensors to the default graph again and again. A solution is to always specify the graph you are working in with graph.as_default():
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('foo'):
        var = tf.get_variable('var', [])
        print(var.name)  # should be 'foo/var:0'
        tensor = tf.constant(2., name='tensor')
        print(tensor.name)  # should be 'foo/tensor:0'
If you run it again, you will see exactly the same names, because the graph = tf.Graph() line creates a new graph each time.
graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('foo'):
        var = tf.get_variable('var', [])
        print(var.name)  # should be 'foo/var:0'
        tensor = tf.constant(2., name='tensor')
        print(tensor.name)  # should be 'foo/tensor:0'
The drawback is that you can no longer rely on the default graph: to compute, for example, the list of all variables, you have to set the graph explicitly:
print(tf.all_variables())  # returns []
with graph.as_default():
    print(tf.all_variables())  # returns [...] with all the variables
Upvotes: 2