ItsKalvik

Register Keras neural network weights with GPflow kernel

I am trying to implement a neural kernel function (an attentive kernel) in GPflow. The kernel uses a neural network to predict the mixture component weights. Here is my `__init__`:

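```python
import tensorflow as tf
import gpflow
import keras
from keras import layers

float_type = tf.float32  # matches the float32 dtype in the printed variables below


class AttentiveKernel(gpflow.kernels.Kernel):
    def __init__(self,
                 lengthscales,
                 dim_hidden=10,
                 amplitude=1.0,
                 ndim=2):
        super().__init__()
        with self.name_scope:
            self.num_lengthscales = len(lengthscales)
            # Kernel amplitude (trainable).
            self._free_amplitude = tf.Variable(amplitude,
                                               shape=[],
                                               trainable=True,
                                               dtype=float_type)
            # Fixed candidate lengthscales, one per mixture component.
            self.lengthscales = tf.Variable(lengthscales,
                                            shape=[self.num_lengthscales],
                                            trainable=False,
                                            dtype=float_type)

            # MLP mapping an input point to softmax mixture weights over the lengthscales.
            self.nn = keras.Sequential([layers.InputLayer(shape=[ndim], batch_size=None),
                                        layers.Dense(dim_hidden, activation='tanh'),
                                        layers.Dense(dim_hidden, activation='tanh'),
                                        layers.Dense(self.num_lengthscales, activation='softmax')])
            self.nn.build()

    # K() and K_diag() are not shown in the question.
```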
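To make "a network that predicts mixture weights" concrete, here is a minimal NumPy sketch of a simplified input-dependent mixture of RBF kernels, `k(x, x') = a * sum_m w_m(x) w_m(x') * exp(-||x - x'||^2 / (2 l_m^2))`, where `w(x)` is the softmax output of an MLP with the same layer sizes as the Keras model above. This formula is an illustrative assumption, not necessarily the exact attentive-kernel definition; `mlp_mixture_weights`, `mixture_kernel` and the random parameter initialisation are hypothetical. Each term is an elementwise (Schur) product of two positive semi-definite matrices, so the sum is still a valid kernel.

```python
import numpy as np


def mlp_mixture_weights(x, params):
    """Hypothetical tanh -> tanh -> softmax MLP, mirroring the Dense stack above."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h = np.tanh(x @ w1 + b1)
    h = np.tanh(h @ w2 + b2)
    logits = h @ w3 + b3
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)  # (n, num_lengthscales), rows sum to 1


def mixture_kernel(x1, x2, lengthscales, amplitude, params):
    """Simplified input-dependent mixture of RBF kernels (illustration only)."""
    w1 = mlp_mixture_weights(x1, params)
    w2 = mlp_mixture_weights(x2, params)
    sq = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)  # (n1, n2) squared distances
    K = np.zeros_like(sq)
    for m, l in enumerate(lengthscales):
        # Weight each RBF component by w_m(x) * w_m(x').
        K += np.outer(w1[:, m], w2[:, m]) * np.exp(-0.5 * sq / l**2)
    return amplitude * K


rng = np.random.default_rng(0)
ndim, dim_hidden, lengthscales = 2, 10, [0.05, 0.5, 1.0, 2.0]
sizes = [ndim, dim_hidden, dim_hidden, len(lengthscales)]
params = [(rng.normal(size=(a, b)), np.zeros(b)) for a, b in zip(sizes[:-1], sizes[1:])]
x = rng.uniform(size=(6, ndim))
K = mixture_kernel(x, x, lengthscales, 1.0, params)
```

In a GPflow kernel the same computation would live in `K` and `K_diag`, written with TensorFlow ops so that gradients can reach the network weights.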

I am not sure why, but when I print the kernel's variables (`kernel_function.variables`), I only get the two explicitly defined variables:

```
(<tf.Variable 'attentive_kernel/Variable:0' shape=() dtype=float32, numpy=1.0>,
 <tf.Variable 'attentive_kernel/Variable:0' shape=(4,) dtype=float32, numpy=array([0.05, 0.5 , 1.  , 2.  ], dtype=float32)>)
```
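GPflow kernels are `tf.Module` subclasses (through `gpflow.base.Module`), so `kernel.variables` comes from `tf.Module`'s walk over the object's attributes. The sketch below (all class names hypothetical) shows the rule: `tf.Variable` attributes, variables inside lists, and nested `tf.Module`s are collected, while a variable held by an object that is neither a `tf.Module` nor a `tf.Variable` is not. One plausible explanation for the output above, not verified here: if the installed Keras is Keras 3 (the default from TensorFlow 2.16), its layers are not `tf.Module` instances and its weights are Keras variables rather than `tf.Variable`s, so the network falls into that last category.

```python
import tensorflow as tf


class PlainHolder:
    """Hypothetical opaque object: neither a tf.Module nor a tf.Variable."""
    def __init__(self):
        self.w = tf.Variable(1.0)


class Inner(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)


class Outer(tf.Module):
    def __init__(self):
        super().__init__()
        self.direct = tf.Variable(0.0)    # collected: tf.Variable attribute
        self.listed = [tf.Variable(3.0)]  # collected: lists/tuples/dicts are traversed
        self.inner = Inner()              # collected: nested tf.Module is recursed into
        self.hidden = PlainHolder()       # NOT collected: opaque object is not inspected


outer = Outer()
print(len(outer.variables))  # 3
```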

The network's variables are not listed, so when I plug the kernel into a GP, the network weights are not trained.

How can I get the network weights to show up in the kernel's `variables` / `trainable_variables`?

I would appreciate any help with this :)
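A minimal workaround sketch, assuming it is acceptable to drop Keras for this small network: build the same tanh/tanh/softmax MLP from plain `tf.Variable`s inside a `tf.Module`, so the kernel's attribute walk finds its weights. `TinyMLP` and `KernelLike` are hypothetical names, and `KernelLike` is a plain `tf.Module` standing in for `AttentiveKernel` so the snippet runs without GPflow; in the real kernel, `self.nn = TinyMLP(...)` would go inside `AttentiveKernel.__init__`.

```python
import tensorflow as tf


class TinyMLP(tf.Module):
    """Hypothetical replacement for the keras.Sequential model, built from tf.Variables."""
    def __init__(self, ndim, dim_hidden, num_out, dtype=tf.float32, name=None):
        super().__init__(name=name)
        sizes = [ndim, dim_hidden, dim_hidden, num_out]
        self.layer_weights = []  # lists of tf.Variables are tracked by tf.Module
        self.layer_biases = []
        for d_in, d_out in zip(sizes[:-1], sizes[1:]):
            self.layer_weights.append(
                tf.Variable(tf.random.normal([d_in, d_out], stddev=d_in ** -0.5, dtype=dtype)))
            self.layer_biases.append(tf.Variable(tf.zeros([d_out], dtype=dtype)))

    def __call__(self, x):
        h = x
        for w, b in zip(self.layer_weights[:-1], self.layer_biases[:-1]):
            h = tf.tanh(h @ w + b)  # hidden layers: tanh
        # Output layer: softmax mixture weights, one per lengthscale.
        return tf.nn.softmax(h @ self.layer_weights[-1] + self.layer_biases[-1], axis=-1)


class KernelLike(tf.Module):
    """Stand-in for AttentiveKernel (gpflow.kernels.Kernel is also a tf.Module)."""
    def __init__(self, lengthscales, dim_hidden=10, amplitude=1.0, ndim=2):
        super().__init__()
        self._free_amplitude = tf.Variable(amplitude, dtype=tf.float32)
        self.lengthscales = tf.Variable(lengthscales, trainable=False, dtype=tf.float32)
        self.nn = TinyMLP(ndim, dim_hidden, len(lengthscales))


k = KernelLike([0.05, 0.5, 1.0, 2.0])
x = tf.random.uniform([5, 2])
with tf.GradientTape() as tape:
    # Arbitrary scalar that depends on the network output.
    loss = tf.reduce_sum(k.nn(x) * k.lengthscales)
grads = tape.gradient(loss, k.nn.trainable_variables)
```

The gradient step at the end mirrors what an optimizer does with `trainable_variables`. Keeping Keras may also be possible, for example with the legacy Keras 2 package (`tf_keras`), whose layers subclass `tf.Module`, but that route is not tested here.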
