Reputation: 616
The variational autoencoder loss function is `Loss = Loss_reconstruction + Beta * Loss_kld`. I am trying to efficiently implement Kullback-Leibler divergence cyclic annealing, that is, changing the weight beta dynamically during training. As a start I subclass the `tf.keras.callbacks.Callback` class, but I don't know how to update a `tf.keras.Model` variable from a custom Keras callback.

Furthermore, I would like to track how beta changes at the end of each training step (`on_train_batch_end`). Right now I keep a Python list in the callback class, but I know Python lists don't play well with TensorFlow. When I fit the model, I get a warning that my `on_train_batch_end` function is slower than the processing of the batch itself. I think I should use a `tf.TensorArray` instead of Python lists, but the `tf.TensorArray` method `write` cannot use a `tf.Variable` for the index (the index to which the new beta for a step should be written changes as the number of steps grows). Is there a better way to store value changes?

It looks like this GitHub repository shows a solution that doesn't involve a custom `tf.keras.Model` and that uses a different kind of KL annealing. Below are a callback function and a dummy VAE.
```python
import tensorflow as tf


class CyclicAnnealing(tf.keras.callbacks.Callback):
    """Cyclic annealing from https://arxiv.org/abs/1903.10145
    Requires that model tracks training iterations and
    total number of training iterations. It also requires
    that model has hyperparameters `M` and `R`.
    """

    def __init__(self, schedule_fxn='sigmoid', **kwargs):
        super().__init__(**kwargs)
        # INEFFICIENT WAY OF LOGGING `betas` AND THE TRAIN STEPS...
        # The `train_iterations` list could be removed because in principle
        # if I have a list of betas, I know that the list of betas is of length
        # (number of samples // batch size) * number of epochs,
        # which is the total number of training steps for the model.
        self.betas = []
        self.train_iterations = []
        if schedule_fxn == 'sigmoid':
            self.schedule_fxn = self.sigmoid
        elif schedule_fxn == 'linear':
            self.schedule_fxn = self.linear
        else:
            raise ValueError('Invalid arg: `schedule_fxn`')

    def on_epoch_end(self, epoch, logs=None):
        print('\nCurrent anneal weight B =', self.beta)

    def on_train_batch_end(self, batch, logs=None):
        """Computes betas and updates list"""
        # Compute beta
        self.beta = self.beta_tau_cyclic_annealing(self.compute_tau())
        ###################################
        # HOW TO UPDATE BETA IN THE MODEL???
        ###################################
        # Update the lists for logging; .numpy() stores the current step value
        # instead of a reference to the counter variable itself
        self.betas.append(self.beta)
        self.train_iterations.append(self.model._train_counter.numpy())

    def get_annealing_data(self):
        return {'betas': self.betas, 'training_iterations': self.train_iterations}

    def sigmoid(self, x):
        """Monotonic increasing function
        :return: tf.constant float32
        """
        return 1 / (1 + tf.math.exp(-x))

    def linear(self, x):
        return x / self.model._R

    def compute_tau(self):
        """Used to determine kld_beta.
        :return: tf.constant float32
        """
        t = tf.identity(self.model._train_counter)  # int64 step counter
        T = self.model._total_training_iterations
        M = self.model._M
        cycle_length = T // M  # Python int, converted to t's int64 dtype below
        numerator = tf.cast(tf.math.floormod(tf.subtract(t, 1), cycle_length), dtype=tf.float32)
        denominator = tf.cast(cycle_length, dtype=tf.float32)
        return tf.math.divide(numerator, denominator)

    def beta_tau_cyclic_annealing(self, tau):
        """Compute change for kld_beta.
        :param tau: Increases beta_tau
        :param R: Proportion used to increase Beta w/i cycle.
        :return: tf.constant float32
        """
        R = self.model._R
        if tau <= R:
            return self.schedule_fxn(tau)
        else:
            return tf.constant(1.0)
```
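Because the schedule depends only on the step counter, it can be checked outside TensorFlow. Below is a plain-Python sketch that mirrors `compute_tau` and `beta_tau_cyclic_annealing` with the `'linear'` schedule, assuming the same floor-division cycle length; the function name `cyclic_beta` is hypothetical.

```python
def cyclic_beta(t, total_steps, M, R):
    """Beta at 1-based training step t for M cycles over total_steps steps.

    Mirrors the callback: tau is the position inside the current cycle,
    beta ramps linearly (tau / R) for the first fraction R of each cycle
    and is held at 1.0 for the rest of it.
    """
    cycle_length = total_steps // M  # steps per cycle, floor division as in compute_tau
    tau = ((t - 1) % cycle_length) / cycle_length
    if tau <= R:
        return tau / R  # linear ramp
    return 1.0  # hold at the full KL weight


# 8 steps, 2 cycles, ramp over the first half of each cycle
print([cyclic_beta(t, total_steps=8, M=2, R=0.5) for t in range(1, 9)])
```

Since every beta is known in advance from `total_steps`, `M` and `R`, the full sequence could also be precomputed once instead of being logged from inside `on_train_batch_end`.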
Dummy VAE:

```python
class VAE(tf.keras.Model):
    def __init__(self, num_samples, batch_size, epochs, features, units, latent_size, kld_beta, M, R, **kwargs):
        """Defines state for model.
        :param num_samples: <class 'int'>
        :param batch_size: <class 'int'>
        :param epochs: <class 'int'>
        :param features: <class 'int'> if input is (n, m), then `features` is the `m` dimension. This param is used with the decoder.
        :param units: <class 'int'> Number of hidden units.
        :param latent_size: <class 'int'> Dimension of latent space z.
        :param kld_beta: <tf.Variable??> for dynamic weight.
        :param M: <class 'int'> Hyperparameter for cyclic annealing.
        :param R: <class 'float'> Hyperparameter for cyclic annealing.
        """
        super().__init__(**kwargs)
        # NEED TO UPDATE THIS SOMEHOW -- I think it should be a tf.Variable?
        self.kld_beta = kld_beta
        # Hyperparameters for CyclicAnnealing
        self._M = M
        self._R = R
        self._total_training_iterations = (num_samples // batch_size) * epochs
        # Encoder and Decoder not defined, but typically
        # encoder = inputs -> dense -> dense mu and dense log var -> z
        # while decoder = z -> dense -> reconstructions
        self.encoder = Encoder(units, latent_size)
        self.decoder = Decoder(features)

    def call(self, inputs):
        z, mus, log_vars = self.encoder(inputs)
        reconstructions = self.decoder(z)
        kl_loss = self.compute_kl_loss(mus, log_vars)
        # THE BETA WEIGHT NEEDS TO BE DYNAMIC
        weighted_kl_loss = self.kld_beta * kl_loss
        self.add_loss(weighted_kl_loss)
        return reconstructions

    def compute_kl_loss(self, mus, log_vars):
        return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.pow(mus, 2))
```
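For reference, `compute_kl_loss` is the closed-form KL divergence between the encoder's Gaussian N(mu, exp(log_var)) and a standard normal, averaged over every entry. A plain-Python sketch of the beta-weighted loss from the top of the question, assuming a single latent vector stands in for a batch (`kl_per_dim` and `beta_vae_loss` are hypothetical names):

```python
import math


def kl_per_dim(mu, log_var):
    """Closed-form KL( N(mu, exp(log_var)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1.0 + log_var - math.exp(log_var) - mu ** 2)


def beta_vae_loss(reconstruction_loss, mus, log_vars, beta):
    """Loss = Loss_reconstruction + beta * Loss_kld.

    The KL term is averaged over latent dimensions, like tf.reduce_mean
    in compute_kl_loss.
    """
    kls = [kl_per_dim(m, lv) for m, lv in zip(mus, log_vars)]
    return reconstruction_loss + beta * sum(kls) / len(kls)


print(beta_vae_loss(2.0, mus=[0.0, 1.0], log_vars=[0.0, 0.0], beta=0.5))
```

Many implementations sum the KL over latent dimensions instead of averaging, which changes the effective scale of beta, so the range of a beta schedule depends on which reduction is used.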
Upvotes: 2
Views: 932
Reputation: 26708
Concerning your first question: it depends on how you plan to update your gradients with your optimizer (e.g. Adam). When training a VAE with TensorFlow / Keras, I usually use the `@tf.function` decorator to calculate the loss of my model and, based on that, update my model's parameters:
```python
@tf.function
def train_step(self, model, batch, gamma, capacity):
    with tf.GradientTape() as tape:
        x, c = batch
        loss = compute_loss(model, x, c, gamma, capacity)
        tf.print('Total loss: ', loss)
    gradients = tape.gradient(loss, model.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```
Note the variables `gamma` and `capacity`: they are terms that weight the loss function. I update them after a certain number of epochs as follows:
```python
new_weight = min(tf.keras.backend.get_value(capacity) + (20. / capacity_annealtime), 20.)
tf.keras.backend.set_value(capacity, new_weight)
```
At this point you can easily save `new_weight` for logging purposes, or you can define a custom TensorFlow logger to log into a file. If you really want to use an array, you could simply define a TF array as:
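```python
this_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
```

and update it after a certain number of steps (`write` returns a new `TensorArray`, so keep the returned value):

```python
this_array = this_array.write(this_array.size(), new_beta_weight)
```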
You could also use a second array and update it simultaneously in order to record the epoch or batch at which your `new_beta_weight` was updated.
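A minimal sketch of that array-based logging, assuming it runs inside a `tf.function` (the name `log_betas` is hypothetical). The constructor keyword for a growable array is `dynamic_size`, and the write index can be a tensor such as `size()`, so no Python-side counter is needed:

```python
import tensorflow as tf


@tf.function
def log_betas(betas):
    """Append each value of `betas` to a growable TensorArray and return them stacked."""
    history = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    for i in tf.range(tf.shape(betas)[0]):  # AutoGraph turns this into a tf.while_loop
        # write() returns a new TensorArray; reassigning keeps the loop state correct.
        history = history.write(history.size(), betas[i])
    return history.stack()


print(log_betas(tf.constant([0.0, 0.5, 1.0])))
```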
Finally, the loss function itself looks like this:
```python
def compute_loss(model, x, c, gamma_weight, capacity_weight):
    mean, logvar = model.encode(x, c)
    z = model.reparameterize(mean, logvar)
    reconstruction = model.decode(z, c)
    total_reconstruction_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=x, logits=reconstruction)
    total_reconstruction_loss = tf.reduce_sum(total_reconstruction_loss, 1)
    kl_loss = 1 + logvar - tf.square(mean) - tf.exp(logvar)
    kl_loss = tf.reduce_mean(kl_loss)
    kl_loss *= -0.5
    total_loss = tf.reduce_mean(total_reconstruction_loss * 3 + (
        gamma_weight * tf.abs(kl_loss - capacity_weight)))
    return total_loss
```
Note that `model` is of type `tf.keras.Model`. This should hopefully give you some different insights into this specific topic.
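Applied to the callback in the question, the same idea (a loss weight held in a TensorFlow variable and changed from outside the training step) could look roughly like the sketch below. It is not the answerer's code: `BetaWeightedModel`, `AssignBeta` and the placeholder KL term are hypothetical, and it assumes `kld_beta` is created as a non-trainable `tf.Variable`. Keras traces the training step with `tf.function`, so a plain Python float stored on the model would be frozen into the traced graph, whereas a variable is read on every step.

```python
import tensorflow as tf


class BetaWeightedModel(tf.keras.Model):
    """Hypothetical minimal model showing only where a dynamic beta lives."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable variable: the compiled train step reads its current value.
        self.kld_beta = tf.Variable(0.0, trainable=False, dtype=tf.float32)
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        outputs = self.dense(inputs)
        kl_loss = tf.reduce_mean(tf.square(outputs))  # placeholder for a real VAE KL term
        self.add_loss(self.kld_beta * kl_loss)
        return outputs


class AssignBeta(tf.keras.callbacks.Callback):
    """Hypothetical callback: writes a new beta into the model after every batch."""

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule  # Python function: 1-based global step -> beta
        self.step = 0
        self.betas = []  # plain Python floats, cheap to append

    def on_train_batch_end(self, batch, logs=None):
        self.step += 1
        beta = float(self.schedule(self.step))
        self.model.kld_beta.assign(beta)  # update the variable in place
        self.betas.append(beta)
```

With a step-based schedule, `model.fit(..., callbacks=[AssignBeta(schedule)])` would update the weight after every batch. Keeping the logged betas as Python floats avoids the need for a `TensorArray` in the callback.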
Upvotes: 2