Jane Sully

Reputation: 3347

How to use Keras Multiply() with tf.Variable?

How do I multiply the output of a tf.keras layer with a tf.Variable?

Context: I am creating a sample-dependent convolutional filter, which consists of a generic filter W that is transformed through sample-dependent shifting and scaling. That is, the original convolutional filter W is transformed into aW + b, where a is a sample-dependent scaling and b is a sample-dependent shift. One application of this is training an autoencoder where the sample dependency is the label, so each label shifts and scales the convolutional filter. Because the convolutions are sample/label dependent, I am using tf.nn.conv2d, which takes the actual filters as input (as opposed to just the number and size of filters), together with a Lambda layer and tf.map_fn to apply a different transformed filter (based on the label) to each sample. Although the details differ, this kind of sample-dependent convolution is discussed in this post: Tensorflow: Convolutions with different filter for each sample in the mini-batch.
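A minimal sketch of the shapes involved in the aW + b transform, assuming the shared filter W has shape (filter_h, filter_w, in_channels, out_filters) and that a and b come out of the Dense layers as (batch, out_filters) arrays. NumPy is used here only as a stand-in; TensorFlow follows the same broadcasting rules.

```python
import numpy as np

batch, fh, fw, in_ch, n_filters = 4, 3, 3, 1, 32

rng = np.random.default_rng(0)
W = rng.standard_normal((fh, fw, in_ch, n_filters))  # shared filter
a = rng.standard_normal((batch, n_filters))          # per-sample scaling
b = rng.standard_normal((batch, n_filters))          # per-sample shifting

# Insert singleton axes so each sample's a/b line up with the
# output-filter (last) axis of W: (batch, 1, 1, 1, n_filters).
a5 = a[:, None, None, None, :]
b5 = b[:, None, None, None, :]

# One transformed filter per sample: (batch, fh, fw, in_ch, n_filters).
W_per_sample = a5 * W + b5
print(W_per_sample.shape)  # (4, 3, 3, 1, 32)
```

Because tf.nn.conv2d takes a single 4-D filter for the whole batch, these per-sample filters still have to be applied one sample at a time, for example with tf.map_fn as described above.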

Here is what I am thinking:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

input_img = keras.Input(shape=(28, 28, 1))
label = keras.Input(shape=(10,))  # number of classes

num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label)  # (batch, 32)
scale = layers.Dense(num_filters, activation=None, name='scale')(label)  # (batch, 32)

# filter is of shape (filter_h, filter_w, input channels, output filters)
filter = tf.Variable(tf.ones((3, 3, input_img.shape[-1], num_filters)))
# TODO: need to scale and shift -> scale*(filter) + shift along each output filter dimension (32 filter dimensions)

I am not sure how to implement the TODO part. I was thinking of using tf.keras.layers.Multiply() for the scaling and tf.keras.layers.Add() for the shifting, but as far as I can tell they do not work with a tf.Variable. How do I get around this? Assuming the shape broadcasting works out, I would like to do something like the following (note: the output should have the same shape as filter and just be scaled along each of the 32 output filter dimensions):

output = tf.keras.layers.Multiply()([filter, scale])
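A quick check of why the scale cannot simply be multiplied into the filter as-is, again using NumPy as a stand-in for TensorFlow's identical broadcasting rules. An un-batched (32,) scale broadcasts cleanly along the last axis, but a batched (batch, 32) Dense output is aligned from the right, so its batch axis lands on the filter's input-channel axis.

```python
import numpy as np

W = np.ones((3, 3, 1, 32))    # shared filter, in_channels = 1
single = np.arange(32.0)      # one scale per output filter
batched = np.ones((4, 32))    # Dense output for a batch of 4

# (3, 3, 1, 32) * (32,) -> (3, 3, 1, 32): scaled per output filter
print((W * single).shape)

# (3, 3, 1, 32) * (4, 32) -> (3, 3, 4, 32): the batch axis silently
# replaces the size-1 input-channel axis, which is not what we want
print((W * batched).shape)
```

With more than one input channel the same expression raises a broadcasting error instead (unless the batch size happens to equal the channel count). Reshaping the scale to (batch, 1, 1, 1, 32) avoids both problems and yields one filter per sample.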

Upvotes: 0

Views: 598

Answers (1)

thushv89

Reputation: 11333

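It requires some work and a custom layer. A tf.Variable does not fit well with tf.keras.layers.Lambda: a variable created inside the Lambda function is not tracked as a trainable weight of the layer, and the Lambda documentation discourages this pattern. A custom layer can instead create the filter with add_weight and apply the per-sample scale and shift in its call method.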

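import tensorflow as tf
import tensorflow.keras.layers as layers

class ConvNorm(layers.Layer):
    """Convolution whose shared filter W is scaled and shifted per sample (a*W + b)."""

    def __init__(self, height, width, n_filters, **kwargs):
        super(ConvNorm, self).__init__(**kwargs)
        self.height = height
        self.width = width
        self.n_filters = n_filters

    def build(self, input_shape):
        # inputs are [x, scale, shift]; the filter depends on x's channel count
        in_channels = input_shape[0][-1]
        self.filter = self.add_weight(name='filter',
                                      shape=(self.height, self.width, in_channels, self.n_filters),
                                      initializer='glorot_uniform',
                                      trainable=True)
        # TODO: Add bias too

    def call(self, inputs):
        x, scale, shift = inputs
        # (batch, n_filters) -> (batch, 1, 1, 1, n_filters) so each sample's
        # scale/shift broadcasts along the output-filter axis of the filter
        scale_reshaped = scale[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
        shift_reshaped = shift[:, tf.newaxis, tf.newaxis, tf.newaxis, :]
        # one transformed filter per sample: (batch, height, width, in_channels, n_filters)
        filters = self.filter * scale_reshaped + shift_reshaped

        def conv_single(elems):
            xi, fi = elems
            # tf.nn.conv2d expects a batch axis, so add one and drop it again
            return tf.nn.conv2d(xi[tf.newaxis], fi, strides=1, padding='SAME')[0]

        # apply each sample's own filter to that sample
        norm_conv_out = tf.map_fn(conv_single, (x, filters), fn_output_signature=x.dtype)
        return norm_conv_out

Using the layer:

input_img = layers.Input(shape=(28, 28, 1))
label = layers.Input(shape=(10,))  # number of classes

num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label)  # (batch, 32)
scale = layers.Dense(num_filters, activation=None, name='scale')(label)  # (batch, 32)

conv_norm_out = ConvNorm(3, 3, num_filters)([input_img, scale, shift])
print(conv_norm_out.shape)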
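To sanity-check a per-sample convolution layer like ConvNorm, it can help to compare against a slow reference. Below is a minimal NumPy sketch, assuming stride 1, 'SAME' padding with odd filter sizes, and the NHWC input / (h, w, in, out) filter layouts used above; per_sample_conv_reference is a hypothetical helper name, not part of TensorFlow. Like tf.nn.conv2d, it computes a cross-correlation (the filter is not flipped).

```python
import numpy as np

def per_sample_conv_reference(x, W, scale, shift):
    """x: (batch, H, W_in, C_in); W: (fh, fw, C_in, F); scale, shift: (batch, F)."""
    batch, H, Wd, _ = x.shape
    fh, fw, _, F = W.shape
    ph, pw = fh // 2, fw // 2  # 'SAME' padding for odd filter sizes, stride 1
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))  # zero padding
    out = np.zeros((batch, H, Wd, F))
    for n in range(batch):
        # this sample's filter: W scaled and shifted along the output-filter axis
        k = W * scale[n] + shift[n]  # (fh, fw, C_in, F)
        for i in range(H):
            for j in range(Wd):
                patch = xp[n, i:i + fh, j:j + fw, :]  # (fh, fw, C_in)
                # sum over (fh, fw, C_in) -> one value per output filter
                out[n, i, j] = np.tensordot(patch, k, axes=3)
    return out
```

Feeding the same input, the layer's filter weight, and the same scale and shift values through the layer and through this function should give matching results up to float32 tolerance.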

Note: I haven't added a bias. You will need one for the convolution layer as well, but that's straightforward.
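A minimal sketch of the bias, assuming one learned value per output filter added after the convolution, as a standard Conv2D layer does. In the custom layer this would mean something like self.bias = self.add_weight(name='bias', shape=(self.n_filters,), initializer='zeros', trainable=True) in build and adding self.bias to the result at the end of call. The shapes are shown below with NumPy as a stand-in.

```python
import numpy as np

batch, height, width, n_filters = 2, 28, 28, 32
conv_out = np.zeros((batch, height, width, n_filters))  # stand-in for the conv output
bias = np.arange(n_filters, dtype=float)                # one bias per output filter

# (batch, height, width, n_filters) + (n_filters,) broadcasts along the last
# axis, so every spatial position of output filter f gets bias[f]
out = conv_out + bias
print(out.shape)        # (2, 28, 28, 32)
print(out[1, 5, 7, 3])  # 3.0
```

This bias is shared across samples and is separate from the per-sample shift b applied to the filter itself.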

Upvotes: 1
