ttt

Tensorflow: Replacing/feeding a placeholder of a graph with tf.Variable?

I have a model M1 whose data input is a placeholder M1.input and whose weights have already been trained. My goal is to build a new model M2 that computes the output o of M1 (with its trained weights) from an input w given as a tf.Variable, instead of feeding actual values to M1.input. In other words, I want to use the trained model M1 as a black-box function to build a new model o = M1(w), in which w is learned and the weights of M1 are held fixed as constants. The problem is that M1 only accepts input through the placeholder M1.input, which must be fed concrete values, not a tf.Variable like w.
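
To make the goal concrete, here is a framework-free sketch of o = M1(w): a hypothetical stand-in for M1 (a one-line linear map with made-up "pre-trained" weights a and b) is treated as a frozen function, and only the input w is updated by gradient descent until M1's output reaches a target. Finite differences stand in for TensorFlow's automatic differentiation; the names m1, fit_input, FROZEN_WEIGHTS and the target value are illustrative assumptions, not part of the question.

```python
# Hypothetical stand-in for the trained model M1: o = a * w + b.
# The "pre-trained" weights are constants and are never updated.
FROZEN_WEIGHTS = {"a": 3.0, "b": 1.0}


def m1(w, weights=FROZEN_WEIGHTS):
    """Black-box forward pass of M1; its weights are treated as constants."""
    return weights["a"] * w + weights["b"]


def loss(w, target):
    """Squared error between M1's output and the desired output."""
    return (m1(w) - target) ** 2


def grad(f, x, eps=1e-6):
    """Central finite-difference derivative of f at x (stands in for autodiff)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def fit_input(target, w0=0.0, lr=0.02, steps=500):
    """Gradient descent on the input w only; M1's weights stay fixed."""
    w = w0
    for _ in range(steps):
        w -= lr * grad(lambda x: loss(x, target), w)
    return w


w = fit_input(10.0)
print(w, m1(w))  # w is close to 3.0, so M1(w) is close to 10.0
```

The only trainable quantity is w; in a TensorFlow version of M2 the same separation is usually expressed by passing only w to the optimizer rather than M1's variables.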

A naive way to build M2 is to rebuild M1 manually inside M2, initialize its weights with the pre-trained values, and mark them as non-trainable within M2. In practice, however, M1 is complicated, and I would rather not rebuild it by hand inside M2. I am looking for a more elegant solution: either a workaround or a direct way to replace M1's input placeholder M1.input with the tf.Variable w.

Thank you for your time.

Answers (1)

Patwie

This is possible. One option is to feed M1 through tf.placeholder_with_default, with the variable as the default value:

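import tensorflow.compat.v1 as tf  # TF1 graph-mode API (also available under TF 2.x)
tf.disable_v2_behavior()


def M1(inputs, reuse=False):
    # Stand-in for the trained model: adds a single learned parameter to its input.
    with tf.variable_scope('model_1', reuse=reuse):
        param = tf.get_variable('param', [1])
        o = inputs + param
        return o


# The variable w, wrapped in a placeholder that falls back to w when nothing is fed.
w = tf.get_variable('some_w', [1])
plhdr = tf.placeholder_with_default(w, [1])

output_m1 = M1(plhdr)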

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    sess.run(w.assign([42]))

    print(sess.run(output_m1, {plhdr: [0]}))  # direct from placeholder
    print(sess.run(output_m1))                # direct from variable

When feed_dict supplies a value for the placeholder, that value is used. Otherwise, the placeholder falls back to its default, the current value of the variable w.
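
The fallback rule can be modelled in a few lines of plain Python. This is a toy sketch of the semantics only, not TensorFlow's implementation: the classes Variable and PlaceholderWithDefault and the function run_m1 are hypothetical stand-ins for tf.Variable, tf.placeholder_with_default and sess.run(output_m1, ...), and param is fixed at 1.0 instead of being a randomly initialized variable.

```python
class Variable:
    """Toy stand-in for tf.Variable: holds a mutable value."""

    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


class PlaceholderWithDefault:
    """Toy stand-in for tf.placeholder_with_default: a fed value wins, else the default."""

    def __init__(self, default):
        self.default = default  # a Variable, read at evaluation time

    def evaluate(self, feed_dict):
        if self in feed_dict:
            return feed_dict[self]
        return self.default.value


def run_m1(plhdr, param, feed_dict=None):
    """Mimics sess.run(output_m1, feed_dict) for the toy model o = input + param."""
    feed_dict = feed_dict or {}
    return plhdr.evaluate(feed_dict) + param


w = Variable(0.0)
plhdr = PlaceholderWithDefault(w)
w.assign(42.0)
print(run_m1(plhdr, 1.0, {plhdr: 0.0}))  # fed value is used
print(run_m1(plhdr, 1.0))                # falls back to w
```

Because the default is read each time the graph is evaluated, a later w.assign(...) is visible on the next run, which is what lets the variable act as a trainable input when nothing is fed.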
