TrsNium

Reputation: 3

How to share RNN variables in TensorFlow

I am implementing SeqGAN in TensorFlow.

However, I cannot get the variables to be shared.

I wrote the discriminator as follows:

import tensorflow as tf 
def discriminator(x, args, name, reuse=False): 
    with tf.variable_scope(name, reuse=reuse) as scope:
        print(tf.contrib.framework.get_name_scope())

        with tf.variable_scope(name+"RNN", reuse=reuse) as scope:
            cell_ = tf.contrib.rnn.GRUCell(args.dis_rnn_size, reuse=reuse)
            rnn_outputs, _= tf.nn.dynamic_rnn(cell_, x, initial_state=cell_.zero_state(batch_size=args.batch_size, dtype=tf.float32), dtype=tf.float32) 

        with tf.variable_scope(name+"Dense", reuse=reuse) as scope:
            logits = tf.layers.dense(rnn_outputs[:,-1,:], 1, activation=tf.nn.sigmoid, reuse=reuse)

    return logits

discriminator(fake, args, "D_", reuse=False) #printed D_
discriminator(real, args, "D_", reuse=True) #printed D_1

How can I reuse the variables?

Upvotes: 0

Views: 207

Answers (1)

suharshs

Reputation: 1088

variable_scope doesn't interact directly with name_scope. variable_scope determines whether get_variable creates new variables or looks up existing ones. To share variables, use variable_scope together with get_variable.

Here are some examples:

import tensorflow as tf  # TensorFlow 1.x graph-mode API

with tf.variable_scope("foo") as foo_scope:
    # A new variable will be created.
    v = tf.get_variable("v", [1])
with tf.variable_scope(foo_scope):
    # A new variable will be created.
    w = tf.get_variable("w", [1])
with tf.variable_scope(foo_scope, reuse=True):
    # Both variables will be reused.
    v1 = tf.get_variable("v", [1])
    w1 = tf.get_variable("w", [1])
with tf.variable_scope("foo", reuse=True):
    # Both variables will be reused.
    v2 = tf.get_variable("v", [1])
    w2 = tf.get_variable("w", [1])
with tf.variable_scope("foo", reuse=False):
    # Without reuse, asking for a name that already exists raises a
    # ValueError; it does not silently create a second variable.
    try:
        v3 = tf.get_variable("v", [1])
    except ValueError:
        v3 = None
assert v1 is v
assert w1 is w
assert v2 is v
assert w2 is w
assert v3 is None

https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/ has a lot of useful examples.

In your particular example, you don't need to name the inner variable_scopes name+"RNN" and name+"Dense". "RNN" and "Dense" will suffice, since the scopes are nested and already pick up the outer scope's prefix. Otherwise, it looks to me like you are using reuse correctly; you are just printing the name_scope, which is a different thing. The name scope is made unique each time you re-enter it, while the variable names stay the same. You can double-check by inspecting tf.global_variables() to see which variables were created and whether they are being reused the way you intended.
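
As a rough illustration of that check, here is a minimal sketch. It does not use your actual model: toy_discriminator is a hypothetical stand-in with a single dense layer built from get_variable, and the input width of 4 is made up. It also assumes the tensorflow.compat.v1 module (TF 1.14+ or 2.x); on older 1.x releases a plain `import tensorflow as tf` without the disable_v2_behavior() call should behave the same way.

```python
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or TF 2.x

tf.disable_v2_behavior()   # graph mode, as in the question
tf.reset_default_graph()


def toy_discriminator(x, name, reuse=False):
    """Hypothetical stand-in for the discriminator: one dense layer."""
    with tf.variable_scope(name, reuse=reuse):
        # The inner scope inherits `reuse` from the outer scope.
        with tf.variable_scope("Dense"):
            w = tf.get_variable("kernel", [4, 1])
            b = tf.get_variable("bias", [1], initializer=tf.zeros_initializer())
            return tf.sigmoid(tf.matmul(x, w) + b)


fake = tf.placeholder(tf.float32, [None, 4])
real = tf.placeholder(tf.float32, [None, 4])

d_fake = toy_discriminator(fake, "D_", reuse=False)
d_real = toy_discriminator(real, "D_", reuse=True)

# Only one set of variables should exist if sharing worked.
variable_names = [v.name for v in tf.global_variables()]
print(variable_names)

# The op names differ because the name scope is uniquified on re-entry.
print(d_fake.name, d_real.name)
```

If the list contains a single kernel and a single bias under D_/Dense, the second call reused the first call's variables, even though the two output ops live in differently named name scopes.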

I hope that helps!

Upvotes: 1
