PasserbyD

Reputation: 65

How to add all variables under a scope into a certain collection

In the TensorFlow Python API, tf.get_variable has a collections parameter that adds the created variable to the specified collections, but tf.variable_scope has no such parameter. What's the suggested way to add all variables under a variable scope to a certain collection?

Upvotes: 2

Views: 954

Answers (4)

markemus

Reputation: 1804

Rather than maintaining a separate collection, you could fetch all variables within the scope directly:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')

https://stackoverflow.com/a/36536063/9095840
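One detail worth knowing: in TensorFlow 1.x the scope argument of tf.get_collection is applied with re.match against each item's name, so it acts as a prefix filter. Passing 'my_scope' therefore also picks up variables under a sibling scope such as 'my_scope_2'. The plain-Python sketch below mimics that filter on name strings; filter_by_scope is a hypothetical helper (not a TensorFlow function) and the variable names are made up. Appending a trailing '/' restricts the match to the scope itself.

```python
import re
from typing import List


def filter_by_scope(names: List[str], scope: str) -> List[str]:
    """Keep names for which re.match(scope, name) succeeds.

    Hypothetical helper mirroring how TF 1.x filters a collection by its
    `scope` argument: a regex anchored at the start of each name.
    """
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]


# Made-up variable names for illustration.
names = [
    "my_scope/w:0",
    "my_scope/b:0",
    "my_scope_2/w:0",
    "other/my_scope/w:0",
]

# Without a trailing "/", the sibling scope "my_scope_2" also matches.
print(filter_by_scope(names, "my_scope"))
# ['my_scope/w:0', 'my_scope/b:0', 'my_scope_2/w:0']

# With a trailing "/", only variables inside "my_scope" match.
print(filter_by_scope(names, "my_scope/"))
# ['my_scope/w:0', 'my_scope/b:0']
```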

Upvotes: 0

stecklin

Reputation: 131

import tensorflow as tf

for var in tf.global_variables(scope='model'):
    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, var)

Instead of global_variables, you could also iterate over trainable_variables if those are the ones you're interested in. In both cases, you capture not only the variables you created manually with get_variable() but also those created by, e.g., any tf.layers call.
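In TensorFlow 1.x, tf.add_to_collection appends the value it receives without checking whether it is already present, so running a loop like the one above twice on the same graph (for example, if the model-building code is called again) stores each variable twice. The plain-Python sketch below models a graph's collections as a dict of lists; FakeGraph and add_unique are hypothetical names used only for illustration, strings stand in for variables, and the membership check is just one possible way to make the step idempotent.

```python
from collections import defaultdict
from typing import Dict, List


class FakeGraph:
    """Hypothetical stand-in for a graph's collections: name -> list."""

    def __init__(self) -> None:
        self.collections: Dict[str, List[str]] = defaultdict(list)

    def add_to_collection(self, name: str, value: str) -> None:
        # Like TF 1.x, append unconditionally (no duplicate check).
        self.collections[name].append(value)


def add_unique(graph: FakeGraph, name: str, values: List[str]) -> None:
    """Add each value to the collection only if it is not already there."""
    existing = graph.collections[name]  # live list, updated as we append
    for value in values:
        if value not in existing:
            graph.add_to_collection(name, value)


# Strings standing in for variables (made-up names).
model_vars = ["model/dense/kernel:0", "model/dense/bias:0"]

g = FakeGraph()
for _ in range(2):  # simulate running the collection step twice
    for v in model_vars:
        g.add_to_collection("naive", v)
    add_unique(g, "deduped", model_vars)

print(len(g.collections["naive"]))    # 4
print(len(g.collections["deduped"]))  # 2
```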

Upvotes: 0

Lerner Zhang

Reputation: 7130

I have managed to do this:

import tensorflow as tf

def var_1():
    with tf.variable_scope("foo") as foo_scope:
        assert foo_scope.name == "ll/foo"
        a = tf.get_variable("a", [2, 2])
    return foo_scope

def var_2(foo_scope):
    with tf.variable_scope("bar"):
        b = tf.get_variable("b", [2, 2])
        with tf.variable_scope("baz") as other_scope:
            c = tf.get_variable("c", [2, 2])
            assert other_scope.name == "ll/bar/baz"
            # Re-entering the captured VariableScope object puts new
            # variables back under that scope's name.
            with tf.variable_scope(foo_scope) as foo_scope2:
                d = tf.get_variable("d", [2, 2])
                assert foo_scope2.name == "ll/foo"  # Not changed.

def main():
    with tf.variable_scope("ll"):
        scp = var_1()
        var_2(scp)
        # tf.get_variable adds every variable to GLOBAL_VARIABLES by default.
        all_default_global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        # Keep only variables under the "ll/foo" scope. The trailing "/"
        # avoids also matching a scope such as "ll/foobar".
        ll_foo_variables = []
        for variable in all_default_global_variables:
            if variable.name.startswith("ll/foo/"):
                ll_foo_variables.append(variable)
        # add_to_collection stores its argument as a single entry, so add the
        # variables one at a time; 'my_collection' is created on first use.
        for variable in ll_foo_variables:
            tf.add_to_collection('my_collection', variable)

        variables_in_my_collection = tf.get_collection("my_collection")
        print(variables_in_my_collection)

main()

Of the four variables a, b, c and d, only a and d end up under the scope name ll/foo.

The process:

First I fetch all variables from the tf.GraphKeys.GLOBAL_VARIABLES collection, to which tf.get_variable adds every variable by default. Then I keep only those whose names start with ll/foo/ and add them, one at a time, to a collection named my_collection, which TensorFlow creates the first time something is added to it.

And the result is what I expected:

[<tf.Variable 'll/foo/a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'll/foo/d:0' shape=(2, 2) dtype=float32_ref>]
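Two details matter in this approach. First, add_to_collection stores whatever single object it receives as one entry, so passing a whole Python list (such as ll_foo_variables) adds one list-valued entry, which is what produces a nested [[...]] result, whereas adding variables one at a time gives a flat collection. Second, a substring test like "ll/foo" in name also accepts names under a scope such as ll/foobar. The plain-Python sketch below illustrates both points with a dict of lists standing in for the graph's collections; collections_, add_to_collection and in_scope here are hypothetical helpers, not the TensorFlow functions, and the names are made up.

```python
from typing import Dict, List

# Hypothetical stand-in for a graph's collections: name -> list of entries.
collections_: Dict[str, list] = {}


def add_to_collection(name: str, value: object) -> None:
    # Store `value` as ONE entry, creating the collection on first use.
    collections_.setdefault(name, []).append(value)


def in_scope(var_name: str, scope: str) -> bool:
    # Require the scope followed by "/" so "ll/foo" does not match "ll/foobar".
    return var_name.startswith(scope + "/")


# Made-up variable names for illustration.
names: List[str] = [
    "ll/foo/a:0", "ll/bar/b:0", "ll/bar/baz/c:0", "ll/foo/d:0", "ll/foobar/e:0",
]

substring_hits = [n for n in names if "ll/foo" in n]
prefix_hits = [n for n in names if in_scope(n, "ll/foo")]
print(substring_hits)  # ['ll/foo/a:0', 'll/foo/d:0', 'll/foobar/e:0']
print(prefix_hits)     # ['ll/foo/a:0', 'll/foo/d:0']

add_to_collection("nested", prefix_hits)  # whole list -> a single entry
for n in prefix_hits:                     # one call per variable
    add_to_collection("flat", n)

print(collections_["nested"])  # [['ll/foo/a:0', 'll/foo/d:0']]
print(collections_["flat"])    # ['ll/foo/a:0', 'll/foo/d:0']
```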

Upvotes: 0

Peter Hawkins

Reputation: 3211

I don't believe there is a way to do this directly. You could file a feature request on TensorFlow's GitHub issue tracker.

I can suggest two workarounds you might try though:

  • iterate over the result of tf.all_variables() (deprecated in favor of tf.global_variables() in later releases), and extract the variables whose names look like ".../scope_name/...". The scope names are encoded in the variable name, separated by / characters.

  • write wrappers around tf.VariableScope and tf.get_variable() that store the variables created inside the scope in a data structure.
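The first workaround can be sketched in plain Python as a check on the '/'-separated components of each variable name. This assumes, as described above, that scope names appear as '/'-separated components before the variable's own name; name_has_scope and select are hypothetical helpers and the names are made up. In TensorFlow you would apply the check to the .name of each variable returned by tf.all_variables() or tf.global_variables().

```python
from typing import List


def name_has_scope(var_name: str, scope_name: str) -> bool:
    """True if scope_name (possibly nested, e.g. "bar/baz") appears as a
    contiguous run of '/'-separated components before the variable name."""
    # Drop the ":0" output suffix, split into components, and discard the
    # last component, which is the variable's own name.
    parts = var_name.split(":")[0].split("/")[:-1]
    target = scope_name.split("/")
    n = len(target)
    return any(parts[i:i + n] == target for i in range(len(parts) - n + 1))


def select(names: List[str], scope_name: str) -> List[str]:
    return [name for name in names if name_has_scope(name, scope_name)]


# Made-up variable names for illustration.
names = ["ll/foo/a:0", "ll/bar/baz/c:0", "model/foo/w:0", "foobar/v:0", "foo:0"]

print(select(names, "foo"))      # ['ll/foo/a:0', 'model/foo/w:0']
print(select(names, "bar/baz"))  # ['ll/bar/baz/c:0']
```

Matching whole components rather than substrings keeps "foobar" from counting as "foo", and a bare variable named foo is not mistaken for a scope.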

I hope that helps!

Upvotes: 1
