harryscholes

Reputation: 1667

How to pass epoch-dependent parameters to a custom layer in Keras

I am trying to implement a k-sparse autoencoder in Keras (with TensorFlow as the backend). The authors of the original paper suggest that the sparsity of the middle layer should be increased progressively during training:

Suppose we are aiming for a sparsity level of k = 15. Then, we start off with a large sparsity level (e.g. k = 100) for which the k-sparse autoencoder can train all the hidden units. We then linearly decrease the sparsity level from k = 100 to k = 15 over the first half of the epochs. This initializes the autoencoder in a good regime, for which all of the hidden units have a significant chance of being picked. Then, we keep k = 15 for the second half of the epochs. With this scheduling, we can train all of the filters, even for low sparsity levels.

My custom layer looks like this and works for a static sparsity level k for all epochs:

from keras.layers import Layer
from keras import backend as K

class KSparse(Layer):
    def __init__(self, k=15, **kwargs):
        self.k = k
        self.uses_learning_phase = True  # only sparsify during training
        super().__init__(**kwargs)

    def call(self, inputs):
        return K.in_train_phase(self.k_sparsify(inputs), inputs)

    def k_sparsify(self, inputs):
        k = self.k
        # Sort ascending along the last axis; the (k+1)-th largest activation
        # becomes the threshold, so strictly greater keeps the top k values.
        threshold = K.tf.contrib.framework.sort(inputs)[..., K.shape(inputs)[-1] - 1 - k]
        return inputs * K.cast(K.greater(inputs, threshold[:, None]), K.floatx())

    def compute_output_shape(self, input_shape):
        return input_shape
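
For context, here is a minimal sketch of how this layer might sit in a model (the 784/100 layer sizes, activations, and optimizer are illustrative assumptions, not part of the question):

from keras.layers import Input, Dense
from keras.models import Model

# Illustrative autoencoder: 784-dim inputs, 100 hidden units (both assumed)
inputs = Input(shape=(784,))
h = Dense(100, activation='sigmoid')(inputs)
h_sparse = KSparse(k=15)(h)  # keep only the 15 largest hidden activations
outputs = Dense(784, activation='sigmoid')(h_sparse)

autoencoder = Model(inputs, outputs)
autoencoder.compile(optimizer='adam', loss='mse')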

I have an array of the desired sparsity levels per epoch and a callback that accesses the current epoch:

import numpy as np
from keras.callbacks import LambdaCallback

def sparsity_level_per_epoch(n_epochs):
    # Linearly anneal k from 100 down to 15 over the first half of training,
    # then hold it at 15 (n_epochs - n_epochs // 2 handles odd epoch counts).
    return np.hstack((np.linspace(100, 15, n_epochs // 2, dtype=int),
                      np.repeat(15, n_epochs - n_epochs // 2)))


nth_epoch = LambdaCallback(on_epoch_begin=lambda epoch, logs: epoch)  # receives the current epoch index
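
As a quick sanity check, the schedule for a small even number of epochs looks like this:

>>> sparsity_level_per_epoch(6)
array([100,  57,  15,  15,  15,  15])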

During training, how do I pass the sparsity level of the current epoch to KSparse?

Upvotes: 2

Views: 1427

Answers (1)

nuric

Reputation: 11225

You can't dynamically change a layer's properties, because the computation graph is generated statically from k at graph-construction time. In other words, k_sparsify is applied with a fixed k; the call method actually runs only once, when the graph is built, not once per epoch. If you have an externally changing parameter, one option is to make it an input:

from keras.layers import Input

sparsity_level = Input(shape=(...))  # e.g. a single scalar level per sample
# ...
out = KSparse()([x, sparsity_level])

and pass the sparsity_level tensor to the KSparse layer, so that the value now depends on a tensor rather than on a class property. As you already do, you can compute these levels offline before training and then feed them as an input.

# inside KSparse, call now receives a list of inputs
def call(self, inputs):
    x, sparsity_level = inputs  # unpack the data tensor and the level tensor
    # ...
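
To make this concrete, here is a minimal end-to-end sketch of the approach. The KSparseDynamic class, the tf.gather-based thresholding, the layer sizes, and the one-fit-call-per-epoch loop are all illustrative assumptions, not code from the question or a library API:

import numpy as np
from keras.layers import Layer, Input, Dense
from keras.models import Model
from keras import backend as K

class KSparseDynamic(Layer):
    def __init__(self, **kwargs):
        self.uses_learning_phase = True  # only sparsify during training
        super().__init__(**kwargs)

    def call(self, inputs):
        x, k = inputs  # k is a tensor, fed alongside the data at fit() time
        return K.in_train_phase(self.k_sparsify(x, k), x)

    def k_sparsify(self, x, k):
        k = K.cast(K.flatten(k)[0], 'int32')        # one scalar k for the batch
        n = K.shape(x)[-1]
        sorted_x = K.tf.contrib.framework.sort(x)   # ascending, last axis
        # The (k+1)-th largest activation; values strictly above it are the top k.
        threshold = K.tf.gather(sorted_x, n - 1 - k, axis=-1)
        return x * K.cast(K.greater(x, threshold[:, None]), K.floatx())

    def compute_output_shape(self, input_shape):
        return input_shape[0]  # same shape as the data input

x_in = Input(shape=(784,))   # illustrative input size
k_in = Input(shape=(1,))     # one sparsity level per sample
h = Dense(256, activation='sigmoid')(x_in)  # wider than the max k=100
h_sparse = KSparseDynamic()([h, k_in])
out = Dense(784, activation='sigmoid')(h_sparse)

model = Model([x_in, k_in], out)
model.compile(optimizer='adam', loss='mse')

n_epochs = 30
x_train = np.random.rand(1000, 784).astype('float32')  # dummy data for the sketch
levels = sparsity_level_per_epoch(n_epochs)             # schedule from the question
for epoch in range(n_epochs):
    k_batch = np.full((len(x_train), 1), levels[epoch], dtype='float32')
    model.fit([x_train, k_batch], x_train, epochs=1, verbose=0)

Note that at prediction time the k input still has to be supplied, even though K.in_train_phase bypasses the sparsification there.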

Upvotes: 1
