Vineeth Reddy

Reputation: 13

Controlling information flow and gating factor in Keras layers

I have a CNN architecture (see the architecture image) in which the information flow from one layer to the next is controlled by a gating factor. A fraction 'g' of the information is sent to the immediately following layer, and the remaining '1-g' is sent to one of the later layers (like a skip connection).

How can I implement such an architecture in Keras? Thanks in advance.

Upvotes: 0

Views: 709

Answers (1)

Daniel Möller

Reputation: 86600

Use the functional API (the Model class).

For learned gates (the fraction g is computed automatically):

from keras.models import Model
from keras.layers import *

inputTensor = Input(someInputShape)

#the actual value
valueTensor = CreateSomeLayer(parameters)(inputTensor)

#the gate - this is the value of 'g', from zero to 1
gateTensor = AnotherLayer(matchingParameters, activation='sigmoid')(inputTensor)

#value * gate = the fraction g of the value
fractionG = Lambda(lambda x: x[0]*x[1])([valueTensor,gateTensor])

#value - fraction = value * (1 - g), the remaining part
complement = Lambda(lambda x: x[0] - x[1])([valueTensor,fractionG])

#each tensor may go into individual layers and follow individual paths:
immediateNextOutput = ImmediateNextLayer(params)(fractionG)
oneOfTheForwardOutputs = OneOfTheForwardLayers(params)(complement)

#keep going, make one or more outputs, and create your model:
model = Model(inputs=inputTensor, outputs=outputTensorOrListOfOutputTensors)    
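
As a concrete illustration, the placeholders above could be filled in roughly as follows. Everything specific here is an assumption and not part of the question: the 32x32x3 input, the Conv2D layers with 16 filters, the 1x1 convolution on the skip path, the Add merge and the 10-class head. Multiply and Subtract layers stand in for the two Lambda layers, and the imports use tensorflow.keras.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Assumed input shape, for illustration only
inputTensor = keras.Input(shape=(32, 32, 3))

# The actual value
valueTensor = layers.Conv2D(16, 3, padding="same", activation="relu")(inputTensor)

# The gate: same shape as the value, each element between 0 and 1
gateTensor = layers.Conv2D(16, 3, padding="same", activation="sigmoid")(inputTensor)

# value * g and value - value * g = value * (1 - g)
fractionG = layers.Multiply()([valueTensor, gateTensor])
complement = layers.Subtract()([valueTensor, fractionG])

# The g part goes to the immediate next layer
immediateNextOutput = layers.Conv2D(16, 3, padding="same", activation="relu")(fractionG)

# The (1 - g) part skips ahead and rejoins the main path later (assumed 1x1 conv)
skipOutput = layers.Conv2D(16, 1, padding="same")(complement)
merged = layers.Add()([immediateNextOutput, skipOutput])

# Assumed classification head
pooled = layers.GlobalAveragePooling2D()(merged)
outputTensor = layers.Dense(10, activation="softmax")(pooled)

model = keras.Model(inputs=inputTensor, outputs=outputTensor)

# Quick shape check on dummy data
predictions = model.predict(np.zeros((2, 32, 32, 3), dtype="float32"), verbose=0)
```

Because the gate comes from a sigmoid layer with the same output shape as the value, g is learned per element here; a gate with a single channel would broadcast one g across all channels instead.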

To feed two tensors into the same layer, first merge them into one by concatenating, adding, multiplying, etc.

#concat
joinedTensor = Concatenate(axis=optionalAxis)([input1,input2])

#add
joinedTensor = Add()([input1,input2])

#etc.....

nextLayerOut = TheLayer(parameters)(joinedTensor)
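
The choice of merge affects the shapes involved, which a small sketch can make visible. The tensor shapes and the Conv2D layer below are assumptions chosen for illustration: Concatenate stacks the inputs along the chosen axis, while Add needs inputs of identical shape and leaves the shape unchanged.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Assumed feature maps: same spatial size, different channel counts
input1 = keras.Input(shape=(8, 8, 16))
input2 = keras.Input(shape=(8, 8, 8))
input3 = keras.Input(shape=(8, 8, 16))

# Concatenate: channel counts add up (16 + 8 = 24)
concatenated = layers.Concatenate(axis=-1)([input1, input2])
concatShape = tuple(concatenated.shape)

# Add: shapes must match exactly, output shape is unchanged
added = layers.Add()([input1, input3])
addShape = tuple(added.shape)

# Either result is a single tensor that a following layer can consume
nextLayerOut = layers.Conv2D(4, 3, padding="same")(concatenated)
nextShape = tuple(nextLayerOut.shape)
```

If the skipped tensor and the main-path tensor do not have matching shapes, Concatenate (or a 1x1 convolution before Add) is one way to reconcile them.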

If you want to control 'g' manually:

In this case, all we have to do is replace gateTensor with a user-defined one:

import keras.backend as K

gateTensor = Input(tensor=K.variable([g]))

Pass this tensor as an input when creating the model. (Since it's a tensor input, it won't change the way you use the fit methods).

model = Model(inputs=[inputTensor,gateTensor], outputs=outputTensorOrListOfOutputTensors)    
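
If the tensor-input approach is not available in your Keras version, one possible alternative is to keep g inside a small custom layer as a non-trainable weight. This is a swapped-in technique, not the method above, and the FixedGate class, its attribute names and the toy input shape are all hypothetical, written against tensorflow.keras.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


class FixedGate(layers.Layer):
    """Hypothetical layer: splits its input into g * x and (1 - g) * x."""

    def __init__(self, g, **kwargs):
        super().__init__(**kwargs)
        self.initialG = float(g)

    def build(self, input_shape):
        # Scalar weight that training does not update
        self.g = self.add_weight(
            name="g",
            shape=(),
            initializer=keras.initializers.Constant(self.initialG),
            trainable=False,
        )

    def call(self, x):
        return x * self.g, x * (1.0 - self.g)


# Toy model (assumed shape) that only exposes the two gated parts
inputTensor = keras.Input(shape=(4,))
gateLayer = FixedGate(0.25)
fractionG, complement = gateLayer(inputTensor)
model = keras.Model(inputs=inputTensor, outputs=[fractionG, complement])

ones = np.ones((2, 4), dtype="float32")
mainPart, skipPart = model.predict(ones, verbose=0)

# g can be changed by hand later without rebuilding the model
gateLayer.g.assign(0.5)
mainPartAfter, skipPartAfter = model.predict(ones, verbose=0)
```

With this variant the model keeps a single input, so the fit calls are unchanged, and g is set by assigning to the layer's weight.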

Upvotes: 1
