pfm
pfm

Reputation: 6328

Why does tf.get_variable('test') return a variable with name test_1?

I created a TensorFlow variable with tf.Variable. Why, when I then call tf.get_variable with the same name, is no exception raised and a new variable created with an incremented name?

import tensorflow as tf

class QuestionTest(tf.test.TestCase):
    """Reproduce the surprising interaction between ``tf.Variable`` and
    ``tf.get_variable`` when both are given the same name."""

    def test_version(self):
        """The observations below were made with TensorFlow 1.10.1."""
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable(self):
        """``tf.get_variable`` does not see a variable made by ``tf.Variable``."""
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")

        # No exception here: get_variable creates a *new* variable, whose
        # name is suffixed because "test" is already used in the graph.
        b = tf.get_variable('test', shape=(), trainable=False)
        self.assertEqual(b.name, "test_1:0")

        # Check identity explicitly instead of relying on the overloaded
        # `==` of TensorFlow objects.
        self.assertIsNot(a, b, msg='`a` is not `b`')

        # A second get_variable call with the same name is rejected.
        with self.assertRaises(ValueError) as raised:
            tf.get_variable('test', shape=(), trainable=False)
        self.assertStartsWith(str(raised.exception),
                              "Variable test already exists, disallowed.")

Upvotes: 1

Views: 197

Answers (1)

pfm
pfm

Reputation: 6328

This is because tf.Variable is a low-level method that stores the variable it creates in the GLOBALS (or LOCALS) collection, while tf.get_variable keeps track of the variables it has created by storing them in a variable store.

When you first call tf.Variable, the variable it creates is not added to the variable store, so the store behaves as if no variable named "test" had been created.

So, when you later call tf.get_variable("test"), it looks in the variable store and sees that no variable named "test" is in it.
It therefore calls tf.Variable, which creates a variable with the incremented name "test_1" (because "test" is already taken in the graph) and stores it in the variable store under the key "test".

import tensorflow as tf

class AnswerTest(tf.test.TestCase):
    """Show where ``tf.Variable`` and ``tf.get_variable`` record the
    variables they create."""

    def test_version(self):
        """The observations below were made with TensorFlow 1.10.1."""
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable_answer(self):
        """Using the default variable scope"""
        # Start from a clean graph: no variable store, no global variables.
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [],
                             "No global variables")

        # tf.Variable registers `a` as a global variable but does not create
        # a variable store.
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")
        self.assertListEqual(tf.global_variables(), [a],
                             "but `a` is in global variables.")

        # get_variable finds no "test" in the (new) store, so it builds a
        # fresh variable; its graph name gets the "_1" suffix.
        b = tf.get_variable('test', shape=(), trainable=False)
        # Identity check: does not depend on TensorFlow's overloaded `==`.
        self.assertIsNot(a, b, msg='`a` is not `b`')
        self.assertEqual(b.name, "test_1:0", msg="`b`'s name is not 'test'.")
        self.assertGreater(len(tf.get_collection(("__variable_store",))), 0,
                           "There is now a variable store.")
        var_store = tf.get_collection(("__variable_store",))[0]
        # NOTE(review): `_vars` is a private attribute of the store and may
        # differ in other TensorFlow versions.
        self.assertDictEqual(var_store._vars, {"test": b},
                             "and variable `b` is in it.")
        self.assertListEqual(tf.global_variables(), [a, b],
                             "while `a` and `b` are in global variables.")

        # Now "test" is in the store, so a second request is rejected.
        with self.assertRaises(ValueError) as exception_context_manager:
            tf.get_variable('test', shape=(), trainable=False)
        exception = exception_context_manager.exception
        self.assertStartsWith(str(exception),
                              "Variable test already exists, disallowed.")

The same is true when using an explicit variable scope.

    def test_variable_answer_with_variable_scope(self):
        """Same experiment, inside an explicit ``tf.variable_scope``."""
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")

        # The scope object itself is not needed, so it is not bound.
        with tf.variable_scope("my_scope"):
            # Entering the variable scope is enough to create the store.
            self.assertGreater(len(tf.get_collection(("__variable_store",))), 0,
                               "There is now a variable store.")
            var_store = tf.get_collection(("__variable_store",))[0]
            # NOTE(review): `_vars` is a private attribute of the store.
            self.assertDictEqual(var_store._vars, {},
                                 "but with no variable in it.")

            # tf.Variable picks up the scope prefix but bypasses the store.
            a = tf.Variable(0., trainable=False, name='test')
            self.assertEqual(a.name, "my_scope/test:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(var_store._vars, {},
                                 "Still no variable in the store.")

            # get_variable stores `b` under the scoped key "my_scope/test",
            # while its graph name is uniquified to "my_scope/test_1".
            b = tf.get_variable('test', shape=(), trainable=False)
            self.assertEqual(b.name, "my_scope/test_1:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(
                var_store._vars, {"my_scope/test": b},
                "`b` is in the store, but notice the difference between its name and its key in the store.")

            with self.assertRaises(ValueError) as exception_context_manager:
                tf.get_variable('test', shape=(), trainable=False)
            exception = exception_context_manager.exception
            self.assertStartsWith(str(exception),
                                  "Variable my_scope/test already exists, disallowed.")

Upvotes: 2

Related Questions