Mastiff

Reputation: 2240

How to do a weighted combine of two tensors with trainable weighting?

I'd like to have a layer that takes two inputs of the same dimension, A and B, and generates an output C = A + w*B, where w is a trainable scalar. In my case, A and B come from a 1D CNN, so they have shape (batch, time, features). Is there a way to make an existing layer perform this function, or do I need to write a custom layer?

Upvotes: 0

Views: 105

Answers (1)

Alberto

Reputation: 12939

Since the layer has trainable state, you want to subclass Layer. Here's a possible implementation:

import numpy as np
import tensorflow as tf
import tensorflow.keras as K

class WeightedSum(K.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Single trainable scalar w, initialized by the default initializer
        self.W = self.add_weight(name="W", shape=[1], dtype=tf.float32, trainable=True)

    def call(self, inputs):
        A, B = inputs
        return A + self.W * B
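As a quick sanity check (a minimal sketch, not part of the original answer), you can call the layer eagerly on tensors shaped (batch, time, features) like the question describes; the tensor shapes and values below are made up for illustration:

```python
import numpy as np
import tensorflow as tf
import tensorflow.keras as K

class WeightedSum(K.layers.Layer):
    """C = A + w*B with a single trainable scalar w."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.W = self.add_weight(name="W", shape=[1], dtype=tf.float32, trainable=True)

    def call(self, inputs):
        A, B = inputs
        return A + self.W * B

# Eager check on (batch, time, features) tensors, as in the question.
A = tf.random.normal((4, 20, 8))
B = tf.random.normal((4, 20, 8))
layer = WeightedSum()
C = layer([A, B])

print(C.shape)                        # (4, 20, 8) -- shape is preserved
print(len(layer.trainable_weights))   # 1 -- just the scalar w
```

The scalar of shape [1] broadcasts against the full 3D tensor, so no reshaping is needed.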

If you want to check that W is actually trained, freeze every other layer and watch the loss:

inp = K.layers.Input(shape=(10,))
x1 = K.layers.Dense(10, trainable=False)(inp)
x2 = K.layers.Dense(10, trainable=False)(inp)
combined = WeightedSum()([x1, x2])
res = K.layers.Dense(1, trainable=False)(combined)

model = K.models.Model(inputs=[inp], outputs=[res])
model.compile(K.optimizers.Adam(), loss="binary_crossentropy")
model.fit(np.random.random((100, 10)), np.random.choice([0, 1], (100,)), epochs=10)

Output:

Epoch 1/10
4/4 [==============================] - 0s 9ms/step - loss: 6.5532
Epoch 2/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5522
Epoch 3/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5513
Epoch 4/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5504
Epoch 5/10
4/4 [==============================] - 0s 4ms/step - loss: 6.5500
Epoch 6/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5491
Epoch 7/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5484
Epoch 8/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5479
Epoch 9/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5472
Epoch 10/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5466

and you can see that even though all the other layers are frozen, the loss still decreases, since W is being optimized.

The fact that the two tensors come from a 1D conv shouldn't matter, as long as they have the same shape (or broadcast-compatible ones).
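To illustrate this last point, here is a hypothetical sketch wiring the layer between two Conv1D branches; the input length, filter counts, and kernel sizes are arbitrary choices, not taken from the question:

```python
import tensorflow as tf
import tensorflow.keras as K

class WeightedSum(K.layers.Layer):
    """C = A + w*B with a single trainable scalar w."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.W = self.add_weight(name="W", shape=[1], dtype=tf.float32, trainable=True)

    def call(self, inputs):
        A, B = inputs
        return A + self.W * B

inp = K.layers.Input(shape=(100, 3))              # (time, features)
a = K.layers.Conv1D(16, 5, padding="same")(inp)   # (None, 100, 16)
b = K.layers.Conv1D(16, 3, padding="same")(inp)   # same output shape as a
c = WeightedSum()([a, b])                         # shapes match, so the sum is valid

model = K.models.Model(inputs=inp, outputs=c)
print(model.output_shape)  # (None, 100, 16)
```

Both branches use padding="same" and the same filter count, so their outputs have identical shapes and the weighted sum applies elementwise.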

Upvotes: 5
