In TensorFlow 2.0, I am trying to create a keras.layers.Layer which outputs the Kullback-Leibler (KL) divergence between two tensorflow_probability.distributions. I would like to calculate the gradient of the output (i.e. the KL divergence) with respect to the mean value of one of the tensorflow_probability.distributions.
In all my attempts so far, the resulting gradients are 0, unfortunately.
I tried implementing a minimal example, shown below. I was wondering whether the problem might have to do with the eager execution mode of TF 2, as I know of a similar approach that worked in TF 1, where eager execution is disabled by default.
This is the minimal example I tried:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input

# 1 Define Layer
class test_layer(Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mean_W = self.add_weight('mean_W', trainable=True)
        self.kernel_dist = tfp.distributions.MultivariateNormalDiag(
            loc=self.mean_W,
            scale_diag=(1.,)
        )
        super(test_layer, self).build(input_shape)

    def call(self, x):
        return tfp.distributions.kl_divergence(
            self.kernel_dist,
            tfp.distributions.MultivariateNormalDiag(
                loc=self.mean_W*0.,
                scale_diag=(1.,)
            )
        )

# 2 Create model
x = Input(shape=(3,))
fx = test_layer()(x)
test_model = Model(name='test_random', inputs=[x], outputs=[fx])

# 3 Calculate gradient
print('\n\n\nCalculating gradients: ')
# example data, only used as a dummy
x_data = np.random.rand(99, 3).astype(np.float32)
for x_now in np.split(x_data, 3):
    # print(x_now.shape)
    with tf.GradientTape() as tape:
        fx_now = test_model(x_now)
    grads = tape.gradient(
        fx_now,
        test_model.trainable_variables,
    )
    print('\nKL-Divergence: ', fx_now, '\nGradient: ', grads, '\n')

print(test_model.summary())
The output of the code above is
Calculating gradients:
KL-Divergence: tf.Tensor(0.0029436834, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=237, shape=(), dtype=float32, numpy=0.0>]
KL-Divergence: tf.Tensor(0.0029436834, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=358, shape=(), dtype=float32, numpy=0.0>]
KL-Divergence: tf.Tensor(0.0029436834, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=479, shape=(), dtype=float32, numpy=0.0>]
Model: "test_random"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 3)]               0
_________________________________________________________________
test_layer_3 (test_layer)    ()                        1
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
None
The KL divergence is calculated correctly, but the resulting gradient is 0. What would be a correct way to obtain the gradients?
We are working our way through distributions & bijectors, making them friendly to closing over variables in the constructor. (Have not yet done the MVNs.) In the meantime, you could use tfd.Independent(tfd.Normal(loc=self.mean_W, scale=1), reinterpreted_batch_ndims=1), which I think will work inside your build method because we've adapted Normal.
Also: have you seen the tfp.layers package? In particular https://www.tensorflow.org/probability/api_docs/python/tfp/layers/KLDivergenceAddLoss might be interesting to you.
If anybody should be interested, I found out how to solve this:
The line
self.kernel_dist = tfp.distributions.MultivariateNormalDiag(
    loc=self.mean_W,
    scale_diag=(1.,)
)
should not be inside the build() method of the layer class definition, but rather inside the call() method. Here is the modified example:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input

# 1 Define Layer
class test_layer(Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mean_W = self.add_weight('mean_W', trainable=True)
        super(test_layer, self).build(input_shape)

    def call(self, x):
        self.kernel_dist = tfp.distributions.MultivariateNormalDiag(
            loc=self.mean_W,
            scale_diag=(1.,)
        )
        return tfp.distributions.kl_divergence(
            self.kernel_dist,
            tfp.distributions.MultivariateNormalDiag(
                loc=self.mean_W*0.,
                scale_diag=(1.,)
            )
        )

# 2 Create model
x = Input(shape=(3,))
fx = test_layer()(x)
test_model = Model(name='test_random', inputs=[x], outputs=[fx])

# 3 Calculate gradient
print('\n\n\nCalculating gradients: ')
# example data, only used as a dummy
x_data = np.random.rand(99, 3).astype(np.float32)
for x_now in np.split(x_data, 3):
    # print(x_now.shape)
    with tf.GradientTape() as tape:
        fx_now = test_model(x_now)
    grads = tape.gradient(
        fx_now,
        test_model.trainable_variables,
    )
    print('\nKL-Divergence: ', fx_now, '\nGradient: ', grads, '\n')

print(test_model.summary())
The output now is
Calculating gradients:
KL-Divergence: tf.Tensor(0.024875917, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=742, shape=(), dtype=float32, numpy=0.22305119>]
KL-Divergence: tf.Tensor(0.024875917, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=901, shape=(), dtype=float32, numpy=0.22305119>]
KL-Divergence: tf.Tensor(0.024875917, shape=(), dtype=float32)
Gradient: [<tf.Tensor: id=1060, shape=(), dtype=float32, numpy=0.22305119>]
Model: "test_random"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 3)]               0
_________________________________________________________________
test_layer_1 (test_layer)    ()                        1
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
None
as expected.
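The printed numbers can be checked by hand. For one-dimensional Gaussians with unit scale, KL(N(a, 1) || N(b, 1)) = (a - b)^2 / 2. In the fixed layer both locations read `mean_W` (call it w), so KL = w^2 / 2 and dKL/dw = w. In the original layer the first location was a frozen value, so only the `0. * w` term depended on w, and its contribution is multiplied by zero. The pure-Python sketch below recovers w from the reported KL (taking the positive root, an assumption, since the sign is not visible from the KL alone) and compares finite-difference gradients for both cases; the helper names are hypothetical.

```python
import math


def kl_unit_normals(a, b):
    """Closed-form KL(N(a, 1) || N(b, 1)) for one-dimensional unit-scale Gaussians."""
    return 0.5 * (a - b) ** 2


def central_difference(f, w, eps=1e-6):
    """Numerical derivative of f at w (hypothetical helper for this check)."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)


# Values printed by the fixed model above.
kl_reported = 0.024875917
grad_reported = 0.22305119

# Invert KL = w**2 / 2; taking the positive root is an assumption.
w = math.sqrt(2 * kl_reported)

# Fixed layer: both locations follow mean_W, i.e. KL(v) = (v - 0*v)**2 / 2.
grad_fixed = central_difference(lambda v: kl_unit_normals(v, 0.0 * v), w)

# Original layer: first location frozen at w, only the 0*v term follows mean_W.
grad_stale = central_difference(lambda v: kl_unit_normals(w, 0.0 * v), w)

print(f"mean_W ~ {w:.8f}")
print(f"gradient (fixed layer): {grad_fixed:.8f}")
print(f"gradient (frozen loc):  {grad_stale:.8f}")
```

The recovered w agrees with the printed gradient 0.22305119, which is what dKL/dw = w predicts, and the frozen-location case gives exactly 0, as in the original output.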
Is this something that was changed from TensorFlow 1 to TensorFlow 2?
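A likely explanation, sketched with plain TensorFlow (no Keras or TFP): under eager execution, converting a variable to a tensor reads its value once and returns an ordinary tensor, which a later `tf.GradientTape` treats as a constant. In TF 1 graph mode the same conversion creates a read op that is re-run on every `Session.run`, so gradients still reach the variable. The sketch assumes that the `MultivariateNormalDiag` used here converted `loc` to a tensor in its constructor (consistent with the other answer's remark that the MVNs had not yet been adapted), and it replaces the TFP KL with the one-dimensional closed form (a - b)^2 / 2 for unit scales.

```python
import tensorflow as tf


def kl_unit_normals(a, b):
    # Closed-form KL(N(a, 1) || N(b, 1)) in one dimension.
    return 0.5 * tf.square(a - b)


mean_W = tf.Variable(0.3)  # hypothetical stand-in for the layer weight

# Like building the distribution in build(): the value is read eagerly,
# outside any tape, and stored as an ordinary tensor.
loc_snapshot = tf.convert_to_tensor(mean_W)

with tf.GradientTape(persistent=True) as tape:
    # Original layer: frozen first location, live second location (times zero).
    kl_stale = kl_unit_normals(loc_snapshot, mean_W * 0.)
    # Fixed layer: both locations read mean_W while the tape is recording.
    kl_live = kl_unit_normals(mean_W, mean_W * 0.)

grad_stale = tape.gradient(kl_stale, mean_W)
grad_live = tape.gradient(kl_live, mean_W)
del tape  # release the persistent tape's resources

print("gradient with frozen loc:", grad_stale.numpy())
print("gradient with live loc:  ", grad_live.numpy())
```

`grad_stale` is zero even though `kl_stale` itself is nonzero, matching the original output, while `grad_live` equals the current value of `mean_W`, matching the fixed output.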