Reputation: 876
I have a model M1 whose data input is a placeholder M1.input and whose weights are already trained. My goal is to build a new model M2 which computes the output o of M1 (with its trained weights) from an input w in the form of a tf.Variable (instead of feeding actual values to M1.input). In other words, I want to use the trained model M1 as a black-box function to build a new model o = M1(w), where w is to be learned and the weights of M1 are fixed as constants. The problem is that M1 only accepts its input through M1.input, which must be fed actual values, not a tf.Variable like w.
As a naive solution to build M2, I could manually rebuild M1 inside M2, initialize M1's weights with the pre-trained values, and keep them non-trainable within M2. However, in practice M1 is complicated and I don't want to build it again by hand. I am looking for a more elegant solution: a workaround or a direct way to replace the input placeholder M1.input with the tf.Variable w.
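To make it concrete, here is a rough sketch of that naive solution for a toy M1 with a single weight (the pre-trained value is made up, and the real M1 is far more complicated than this):

import tensorflow as tf

pretrained_param = [3.0]   # hypothetical value extracted from the trained M1

# w is the input to be learned in M2.
w = tf.get_variable('w', [1])

# Rebuild M1's computation by hand, with its weight fixed as a constant
# (it could also be a non-trainable variable), so that only w is trainable.
param = tf.constant(pretrained_param, name='m1_param')
o = w + param              # M1's op applied directly to the variable w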
Thank you for your time.
Upvotes: 5
Views: 1027
Reputation: 4460
This is possible. What about:
import tensorflow as tf

def M1(input, reuse=False):
    # Stand-in for the pre-trained model: one weight added to the input.
    with tf.variable_scope('model_1', reuse=reuse):
        param = tf.get_variable('param', [1])
        o = input + param
    return o

w = tf.get_variable('some_w', [1])
# Placeholder that falls back to w when nothing is fed for it.
plhdr = tf.placeholder_with_default(w, [1])
output_m1 = M1(plhdr)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(w.assign([42]))
    print(sess.run(output_m1, {plhdr: [0]}))  # direct from placeholder
    print(sess.run(output_m1))                # direct from variable
So when the feed_dict provides a value for the placeholder, that value is used; otherwise, the fallback to the variable w is active.
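If M1's trained weights live in a checkpoint, you could continue from the snippet above and restore only the 'model_1' scope, so that w keeps its own value (the checkpoint path is hypothetical, and this assumes the checkpoint was saved with the same variable names):

# Restore only M1's weights from the checkpoint; w is initialized separately.
m1_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model_1')
saver = tf.train.Saver(var_list=m1_vars)

with tf.Session() as sess:
    sess.run(w.initializer)                     # initialize only w
    saver.restore(sess, '/tmp/m1_trained.ckpt')  # hypothetical checkpoint path
    print(sess.run(output_m1))                  # M1 applied to the variable w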
Upvotes: 2