Reputation: 58
I am currently trying to train a model (a hypernetwork) that can predict the weights for another model (the main network) such that the main network's cross-entropy loss decreases. However, when I use tf.assign to assign the new weights to the main network, it does not allow backpropagation into the hypernetwork, rendering the system non-differentiable. I have tested whether my weights are properly updated, and they seem to be, since subtracting the initial weights from the updated ones gives a non-zero sum.
This is a minimal sample of what I am trying to achieve.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import softmax
def random_addition(variables):
    addition_update_ops = []
    for variable in variables:
        update = tf.assign(variable, variable + tf.random_normal(shape=variable.get_shape()))
        addition_update_ops.append(update)
    return addition_update_ops
def network_predicted_addition(variables, network_preds):
    addition_update_ops = []
    for idx, variable in enumerate(variables):
        if idx == 0:
            print(variable)
            update = tf.assign(variable, variable + network_preds[idx])
            addition_update_ops.append(update)
    return addition_update_ops
def dense_weight_update_net(inputs, reuse):
    with tf.variable_scope("weight_net", reuse=reuse):
        output = tf.layers.conv2d(inputs=inputs, kernel_size=(3, 3), filters=16, strides=(1, 1),
                                  activation=tf.nn.leaky_relu, name="conv_layer_0", padding="SAME")
        output = tf.reduce_mean(output, axis=[0, 1, 2])
        output = tf.reshape(output, shape=(1, output.get_shape()[0]))
        output = tf.layers.dense(output, units=(16 * 3 * 3 * 3))
        output = tf.reshape(output, shape=(3, 3, 3, 16))
    return output
def conv_net(inputs, reuse):
    with tf.variable_scope("conv_net", reuse=reuse):
        output = tf.layers.conv2d(inputs=inputs, kernel_size=(3, 3), filters=16, strides=(1, 1),
                                  activation=tf.nn.leaky_relu, name="conv_layer_0", padding="SAME")
        output = tf.reduce_mean(output, axis=[1, 2])
        output = tf.layers.dense(output, units=2)
        output = softmax(output)
    return output
input_x_0 = tf.zeros(shape=(32, 32, 32, 3))
target_y_0 = tf.zeros(shape=(32), dtype=tf.int32)
input_x_1 = tf.ones(shape=(32, 32, 32, 3))
target_y_1 = tf.ones(shape=(32), dtype=tf.int32)
input_x = tf.concat([input_x_0, input_x_1], axis=0)
target_y = tf.concat([target_y_0, target_y_1], axis=0)
output_0 = conv_net(inputs=input_x, reuse=False)
target_y = tf.one_hot(target_y, 2)
crossentropy_loss_0 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_0)
conv_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="conv_net")
weight_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="weight_net")
print(conv_net_parameters)
weight_updates = dense_weight_update_net(inputs=input_x, reuse=False)
#updates_0 = random_addition(conv_net_parameters)
updates_1 = network_predicted_addition(conv_net_parameters, network_preds=[weight_updates])
with tf.control_dependencies(updates_1):
    output_1 = conv_net(inputs=input_x, reuse=True)
    crossentropy_loss_1 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_1)
    check_sum = tf.reduce_sum(tf.abs(output_0 - output_1))
c_opt = tf.train.AdamOptimizer(beta1=0.9, learning_rate=0.001)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Needed for correct batch norm usage
with tf.control_dependencies(update_ops):  # Needed for correct batch norm usage
    train_variables = weight_net_parameters  # + conv_net_parameters
    c_error_opt_op = c_opt.minimize(crossentropy_loss_1,
                                    var_list=train_variables,
                                    colocate_gradients_with_ops=True)
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    loss_list_0 = []
    loss_list_1 = []
    for i in range(1000):
        _, checksum, crossentropy_0, crossentropy_1 = sess.run([c_error_opt_op, check_sum,
                                                                crossentropy_loss_0, crossentropy_loss_1])
        loss_list_0.append(crossentropy_0)
        loss_list_1.append(crossentropy_1)
        print(checksum, np.mean(loss_list_0), np.mean(loss_list_1))
Does anyone know how I can get TensorFlow to compute the gradients for this? Thank you.
Upvotes: 0
Views: 480
Reputation: 32081
In this case your weights aren't variables; they are computed tensors based on the hypernetwork. All you really have is one network during training. If I understand you correctly, you are then proposing to discard the hypernetwork and use just the main network to perform predictions.
If this is the case, then you can either save the weight values manually and reload them as constants, or you can use tf.cond and tf.assign to assign them as you are doing during training, but use tf.cond to choose between the variable and the computed tensor depending on whether you're doing training or inference.
During training you will need to use the computed tensor from the hypernetwork in order to enable backprop.
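For instance, here is a minimal sketch of that idea, reusing the question's dense_weight_update_net and input_x (the single functional conv layer is an assumption about the main network's structure): the kernel enters tf.nn.conv2d as a plain tensor, so minimizing a loss on features propagates gradients into the hypernetwork's parameters.

# Sketch only: the kernel is a computed tensor, not a tf.Variable.
predicted_kernel = dense_weight_update_net(inputs=input_x, reuse=False)  # shape (3, 3, 3, 16)
features = tf.nn.conv2d(input_x, predicted_kernel,
                        strides=[1, 1, 1, 1], padding="SAME")  # functional conv with no variables of its own
features = tf.nn.leaky_relu(features)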
Example from the comments: w is the weight you'll use. You can assign a variable during training to keep track of it, but then use tf.cond to either use the variable (during inference) or the computed value from the hypernetwork (during training). In this example you need to pass in a boolean placeholder is_training_placeholder to indicate whether you're running training or inference.
assign_op = tf.assign(w_variable, w_from_hypernetwork)  # refresh the stored variable with the latest prediction
w = tf.cond(is_training_placeholder,
            true_fn=lambda: w_from_hypernetwork,  # differentiable path: gradients reach the hypernetwork
            false_fn=lambda: w_variable)          # stored weights: the hypernetwork is not needed
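A rough usage sketch around that snippet (train_op and outputs are hypothetical stand-ins for your own training op and model output; is_training_placeholder would be a tf.placeholder(tf.bool, shape=[])):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training step: w resolves to w_from_hypernetwork, and assign_op keeps w_variable in sync.
    sess.run([train_op, assign_op], feed_dict={is_training_placeholder: True})
    # Inference: w resolves to the stored w_variable, so the hypernetwork can be dropped.
    predictions = sess.run(outputs, feed_dict={is_training_placeholder: False})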
Upvotes: 0