Akiiino

Reputation: 1090

How can I reuse a Dense layer?

I have a network in TensorFlow, and I want to define a function that passes its input through a tf.layers.dense layer (the same one every time, obviously). I see the reuse argument, but to use it properly it seems I need to keep a global variable just to remember whether my function has already been called. Is there a cleaner way?

Upvotes: 13

Views: 5743

Answers (3)

Gobinath

Reputation: 904

I find tf.layers.Dense cleaner than the approaches in the other answers. All you need is a Dense object defined beforehand; you can then reuse it any number of times.

import tensorflow as tf

# Define Dense object which is reusable
my_dense = tf.layers.Dense(3, name="optional_name")

# Define some inputs
x1 = tf.constant([[1,2,3], [4,5,6]], dtype=tf.float32)
x2 = tf.constant([[4,5,6], [7,8,9]], dtype=tf.float32)

# Use the Dense layer
y1 = my_dense(x1)
y2 = my_dense(x2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y1 = sess.run(y1)
    y2 = sess.run(y2)
    print(y1)
    print(y2)

In fact, the tf.layers.dense function internally constructs a Dense object and passes your input to it. For more details, check the source code.

Upvotes: 12

Sorin

Reputation: 11968

You could construct the layer against a constant of the right shape and ignore the result.

This way the variables are created, but the dummy operation should be pruned from the graph.

For example

tf.layers.dense(tf.zeros([1, 128]), 3, name='my_layer')

... later
hidden = tf.layers.dense(input, 3, name='my_layer', reuse=True)

Upvotes: 4

AlexP

Reputation: 1466

As far as I know, there's no cleaner way. The best we can do is wrap tf.layers.dense in our own abstraction and use it as an object, hiding the variable-scope machinery:

def my_dense(*args, **kwargs):
  # Enter a fresh, uniquely named variable scope and keep it around
  scope = tf.variable_scope(None, default_name='dense').__enter__()
  def f(input):
    # The first call creates the variables; reuse_variables() makes
    # every subsequent call share them
    r = tf.layers.dense(input, *args, name=scope, **kwargs)
    scope.reuse_variables()
    return r
  return f

a = [[1,2,3], [4,5,6]]
a = tf.constant(a, dtype=tf.float32)
layer = my_dense(3)
a = layer(a)
a = layer(a)

print(*[[int(a) for a in v.get_shape()] for v in tf.trainable_variables()])
# Prints: "[3, 3] [3]" (a single pair of weights and biases)

Upvotes: 4
