Dmitrii

Reputation: 259

How can I duplicate a TensorFlow layer?

Suppose I have a layer (i.e. a collection of ops under the same name scope) in TensorFlow. How can I duplicate it together with its input connections?

More specifically, suppose I have the following graph:

A --> B --> C --> D

Now I want to duplicate C as C1, where C is a whole name scope:

A --> B --> C --> D
        \-> C1

How can I do that in TensorFlow?

Upvotes: 4

Views: 2838

Answers (2)

Alien Ambassador

Reputation: 31

The solution can be divided into two parts.

1. Replicate the graph of the layer

This is straightforward: just rerun the code you used to create that layer. I suggest using Keras instead of raw TensorFlow, which makes this step more flexible and easier.
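
Step 1 boils down to calling the same construction code again under a new scope, fed by the same input. A framework-free sketch: the graph is a plain dict mapping each op name to the names it reads from, and `build_layer_c` is a hypothetical stand-in for whatever code originally created C.

```python
from typing import Dict, List

# Toy graph: each op name maps to the list of op names it reads from.
Graph = Dict[str, List[str]]


def build_layer_c(graph: Graph, scope: str, input_op: str) -> str:
    """Hypothetical layer builder: adds a two-op layer under `scope`.

    Calling it again with a different scope replicates the layer's structure.
    """
    matmul = f"{scope}/matmul"
    relu = f"{scope}/relu"
    graph[matmul] = [input_op]  # first op reads from the layer input
    graph[relu] = [matmul]      # second op is internal to the layer
    return relu                 # name of the layer's output op


graph: Graph = {"A": [], "B": ["A"]}
c_out = build_layer_c(graph, "C", "B")
graph["D"] = [c_out]
# Same code, new scope, same input B: this yields the B -> C1 branch.
c1_out = build_layer_c(graph, "C1", "B")
```

In Keras the counterpart is creating a new layer object with the same code and calling it on B's output. The new layer starts with freshly initialized weights, which is why step 2 is needed.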

2. Copy the weights

The idea is that you only need to copy the tf.Variables, each of which is basically a group of the following ops: initializer, kernel, and assign. Here is a good explanation. The code will look as follows:

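import tensorflow as tf  # TensorFlow 1.x API; `sess`, `layer_name` and
                         # `new_layer_name` are assumed to already exist

# Weights of the original layer as numpy arrays. The trailing "/" keeps
# scope "C" from also matching variables under "C1/".
src_vars = tf.trainable_variables(scope=layer_name + "/")
src_vals = sess.run(src_vars)
# Variables of the duplicated layer; they must be listed in the same order
dst_vars = tf.trainable_variables(scope=new_layer_name + "/")
for var, val in zip(dst_vars, src_vals):
    var.load(val, sess)  # assign the numpy value to the new variable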
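
The snippet above relies on both variable lists coming out in the same order. A slightly safer variant, sketched here as an assumption rather than part of the original answer, pairs variables by their name relative to the layer scope instead. Plain dicts stand in for the variables; with TF1 you would key `src` by `v.name` for each fetched variable and call `var.load(value, sess)` in place of the dict assignment. `copy_weights_by_name` is a hypothetical helper name.

```python
from typing import Dict


def copy_weights_by_name(src: Dict[str, object], dst: Dict[str, object],
                         src_scope: str, dst_scope: str) -> None:
    """Copy values from variables under src_scope to their twins under dst_scope.

    Variables are paired by the part of their name after the scope prefix,
    so the order in which they were created does not matter.
    """
    src_prefix = src_scope + "/"
    dst_prefix = dst_scope + "/"
    for name, value in list(src.items()):
        if not name.startswith(src_prefix):
            continue  # variable belongs to some other layer
        target = dst_prefix + name[len(src_prefix):]
        if target not in dst:
            raise KeyError(f"no counterpart for {name!r}: expected {target!r}")
        dst[target] = value


variables = {"C/kernel:0": [[1.0, 2.0]], "C/bias:0": [0.5],
             "C1/kernel:0": [[0.0, 0.0]], "C1/bias:0": [0.0]}
copy_weights_by_name(variables, variables, "C", "C1")
```

Raising on a missing counterpart turns a silent mismatch (which `zip` would hide by truncating) into an explicit error.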

Upvotes: 3

Chan Kha Vu

Reputation: 10390

This can be done using tf.contrib.graph_editor. Here is how:

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x
import tensorflow.contrib.graph_editor as ge

layer_name = "C"       # name scope of the layer to duplicate
new_layer_name = "C1"  # name scope for the copy

# Get the SubgraphView of the given layer
layer_sgv = ge.make_view_from_scope(layer_name, tf.get_default_graph())

# Retrieve the tensors entering the layer from ops outside of it.
# Mapping each one to itself preserves the input hierarchy while duplicating.
replacement_ts = {t: t for t in layer_sgv.inputs}

# Duplicate the layer
duplicate_sgv, info = ge.copy_with_input_replacements(
    layer_sgv,
    replacement_ts=replacement_ts,
    src_scope=layer_name,
    dst_scope=new_layer_name)

You can read more on SubgraphView here.
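
Conceptually, the copy renames every op from `src_scope` to `dst_scope`, rewires edges inside the layer to the new ops, and leaves edges from outside the layer pointing at the original tensors (which is what mapping each input to itself in `replacement_ts` expresses). Below is a toy model of that behavior on a plain dict graph; it is not the graph_editor implementation, and `copy_scope` is a hypothetical name.

```python
from typing import Dict, List

# Toy graph: each op name maps to the list of op names it reads from.
Graph = Dict[str, List[str]]


def copy_scope(graph: Graph, src_scope: str, dst_scope: str) -> Dict[str, str]:
    """Duplicate every op under src_scope into dst_scope.

    Inputs from inside the scope are rewired to the copies; inputs from
    outside the scope are kept unchanged. Returns old-name -> new-name.
    """
    prefix = src_scope + "/"
    members = [name for name in graph if name.startswith(prefix)]
    renamed = {name: dst_scope + "/" + name[len(prefix):] for name in members}
    for name in members:
        # renamed.get(inp, inp): internal inputs map to copies, external stay.
        graph[renamed[name]] = [renamed.get(inp, inp) for inp in graph[name]]
    return renamed


graph: Graph = {"A": [], "B": ["A"], "C/matmul": ["B"],
                "C/relu": ["C/matmul"], "D": ["C/relu"]}
mapping = copy_scope(graph, "C", "C1")
```

After the call, `C1/matmul` reads from `B` and `C1/relu` reads from `C1/matmul`, while `D` still reads from `C/relu`, matching the target diagram in the question.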

Upvotes: 4
