Reputation: 65
In TensorFlow's Python API, tf.get_variable has a collections parameter that adds the created variable to the specified collections, but tf.variable_scope does not. What is the suggested way to add all variables under a variable scope to a certain collection?
Upvotes: 2
Views: 954
Reputation: 1804
You could simply get all variables within the scope instead of adding them to a separate collection:
```python
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')
```
https://stackoverflow.com/a/36536063/9095840
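A minimal runnable sketch of this approach, assuming TensorFlow 1.14+ or 2.x accessed through tensorflow.compat.v1 in graph mode (the scope names my_scope and my_scope_2 and the helper scoped_variable_names are hypothetical). Because get_collection matches the scope argument against variable names with re.match, 'my_scope' also picks up variables from a sibling scope such as 'my_scope_2'; appending a trailing slash restricts the match to the scope itself.

```python
import tensorflow.compat.v1 as tf


def scoped_variable_names():
    # Build a fresh graph so the example does not depend on global state.
    graph = tf.Graph()
    with graph.as_default():
        with tf.variable_scope("my_scope"):
            tf.get_variable("w", [2, 2])
        with tf.variable_scope("my_scope_2"):  # hypothetical sibling scope
            tf.get_variable("v", [2, 2])

        # scope is matched with re.match, i.e. as a prefix of the variable name.
        loose = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="my_scope")
        # A trailing "/" keeps only variables that live inside my_scope.
        strict = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="my_scope/")
    return [v.name for v in loose], [v.name for v in strict]


loose_names, strict_names = scoped_variable_names()
print(loose_names)   # ['my_scope/w:0', 'my_scope_2/v:0']
print(strict_names)  # ['my_scope/w:0']
```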
Upvotes: 0
Reputation: 131
```python
import tensorflow as tf

for var in tf.global_variables(scope='model'):
    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var)
```
Instead of using global_variables, you could also iterate over trainable_variables if that is what you're interested in. In both cases, you capture not only the variables you created manually using get_variable() but also the ones created by, e.g., any tf.layers call.
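A sketch of that loop, assuming tensorflow.compat.v1 in graph mode; the scope name model, the variables inside it and the helper collect_model_variables are hypothetical. It contrasts the two sources: a variable created with trainable=False is global but not trainable, so only the global_variables loop copies it into MODEL_VARIABLES.

```python
import tensorflow.compat.v1 as tf


def collect_model_variables(only_trainable):
    graph = tf.Graph()
    with graph.as_default():
        with tf.variable_scope("model"):
            tf.get_variable("weights", [3, 3])
            # A non-trainable variable, e.g. a step counter (hypothetical).
            tf.get_variable("step", [], trainable=False,
                            initializer=tf.zeros_initializer())
        tf.get_variable("outside", [1])  # not under "model", never collected

        source = tf.trainable_variables if only_trainable else tf.global_variables
        for var in source(scope="model"):
            tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var)
        return [v.name for v in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)]


print(collect_model_variables(only_trainable=False))  # ['model/weights:0', 'model/step:0']
print(collect_model_variables(only_trainable=True))   # ['model/weights:0']
```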
Upvotes: 0
Reputation: 7130
I have managed to do this:
```python
import tensorflow as tf


def var_1():
    with tf.variable_scope("foo") as foo_scope:
        assert foo_scope.name == "ll/foo"
        a = tf.get_variable("a", [2, 2])
    return foo_scope


def var_2(foo_scope):
    with tf.variable_scope("bar"):
        b = tf.get_variable("b", [2, 2])
        with tf.variable_scope("baz") as other_scope:
            c = tf.get_variable("c", [2, 2])
            assert other_scope.name == "ll/bar/baz"
            # Re-entering the captured scope object puts d under ll/foo again.
            with tf.variable_scope(foo_scope) as foo_scope2:
                d = tf.get_variable("d", [2, 2])
                assert foo_scope2.name == "ll/foo"  # Not changed.


def main():
    with tf.variable_scope("ll"):
        scp = var_1()
        var_2(scp)

    # tf.get_variable adds every variable to GLOBAL_VARIABLES by default.
    all_default_global_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)

    # Add each matching variable on its own; passing the whole list to
    # tf.add_to_collection would store the list as one single item.
    for variable in all_default_global_variables:
        if variable.name.startswith("ll/foo/"):
            tf.add_to_collection("my_collection", variable)

    variables_in_my_collection = tf.get_collection_ref("my_collection")
    print(variables_in_my_collection)


main()
```
In my code, of the variables a, b, c and d, only a and d are under the scope ll/foo.
All variables created with tf.get_variable are added to the tf.GraphKeys.GLOBAL_VARIABLES collection by default. I go through that collection and add only the variables whose names start with ll/foo/ to a collection named my_collection; the collection is created implicitly the first time something is added to it.
The result is what I expected:
[<tf.Variable 'll/foo/a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'll/foo/d:0' shape=(2, 2) dtype=float32_ref>]
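If you keep the VariableScope object around, as var_1 does by returning foo_scope, the name scan can possibly be skipped: in the TensorFlow 1.x and tensorflow.compat.v1 versions I am aware of, VariableScope has a global_variables() method that filters GLOBAL_VARIABLES by the scope's name plus a trailing slash. A sketch under that assumption, trimmed to the a, b and d variables (the helper build_and_collect is hypothetical):

```python
import tensorflow.compat.v1 as tf


def build_and_collect():
    graph = tf.Graph()
    with graph.as_default():
        with tf.variable_scope("ll"):
            with tf.variable_scope("foo") as foo_scope:
                tf.get_variable("a", [2, 2])
            with tf.variable_scope("bar"):
                tf.get_variable("b", [2, 2])
                # Re-entering the captured scope puts d under ll/foo again.
                with tf.variable_scope(foo_scope):
                    tf.get_variable("d", [2, 2])

        # Assumed API: returns the GLOBAL_VARIABLES entries whose names
        # start with "ll/foo/".
        for var in foo_scope.global_variables():
            tf.add_to_collection("my_collection", var)
        return [v.name for v in tf.get_collection("my_collection")]


print(build_and_collect())  # ['ll/foo/a:0', 'll/foo/d:0']
```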
Upvotes: 0
Reputation: 3211
I don't believe there is a way to do this directly. You could file a feature request on TensorFlow's GitHub issue tracker.
I can suggest two workarounds you might try, though:
1. Iterate over the result of tf.all_variables() (deprecated in favor of tf.global_variables() in later versions) and extract the variables whose names look like ".../scope_name/...". The scope names are encoded in the variable name, separated by / characters.
2. Write wrappers around tf.VariableScope and tf.get_variable() that store the variables created inside the scope in a data structure.
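One way to realize the second workaround without hand-written wrapper classes is the custom_getter argument of tf.variable_scope: TensorFlow calls it in place of the default getter for every tf.get_variable inside the scope, passing the default getter as the keyword argument getter. The sketch below assumes tensorflow.compat.v1 in graph mode; the helper collecting_getter and the collection name my_collection are hypothetical.

```python
import tensorflow.compat.v1 as tf


def collecting_getter(collection_name):
    """Return a custom getter that records every variable it hands out."""
    # TensorFlow calls custom getters with keyword arguments, so the first
    # parameter must be named "getter".
    def _collect(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        # Skip duplicates, e.g. when a variable is fetched again with reuse=True.
        if all(v.name != var.name for v in tf.get_collection(collection_name)):
            tf.add_to_collection(collection_name, var)
        return var
    return _collect


graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("foo", custom_getter=collecting_getter("my_collection")):
        tf.get_variable("a", [2, 2])
        tf.get_variable("b", [2, 2])
    tf.get_variable("outside", [2, 2])  # created without the custom getter
    names = [v.name for v in tf.get_collection("my_collection")]

print(names)  # ['foo/a:0', 'foo/b:0']
```

In the versions I have looked at, nested scopes inherit the parent's custom getter, so variables created in sub-scopes of foo should be recorded as well; it is worth checking this on your TensorFlow version.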
I hope that helps!
Upvotes: 1