Reputation: 2240
I'd like to have a layer that takes two inputs of the same dimension, A and B, and generates an output C = A + w*B, where w is a trainable scalar. In my case, A and B come from a 1D CNN, so they have shape (batch, time, features). Is there a way to make an existing layer perform this function, or do I need to code up a custom layer?
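In NumPy terms, the operation I'm after would look like this (a sketch with made-up sizes; only w should be learned):

import numpy as np

A = np.random.randn(8, 100, 64)  # (batch, time, features)
B = np.random.randn(8, 100, 64)  # same shape as A
w = 0.5                          # placeholder: this should be a trainable scalar
C = A + w * B                    # output has the same shape as A and B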
Upvotes: 0
Views: 105
Reputation: 12939
Since your layer has state (the trainable scalar w), you want to use subclassing. Here's a possible implementation:
import numpy as np
import tensorflow as tf
from tensorflow import keras as K

class WeightedSum(K.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Trainable scalar w (shape [1] broadcasts over (batch, time, features))
        self.W = self.add_weight(name="W", shape=[1], dtype=tf.float32, trainable=True)

    def call(self, inputs, *args, **kwargs):
        A, B = inputs
        return A + self.W * B  # C = A + w*B
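As a quick sanity check, you can call the layer eagerly on two tensors shaped like your Conv1D outputs (a minimal sketch; the batch/time/feature sizes are arbitrary):

A = tf.random.normal((4, 100, 32))  # (batch, time, features)
B = tf.random.normal((4, 100, 32))
layer = WeightedSum()
C = layer([A, B])
print(C.shape)           # (4, 100, 32), same as A and B
print(layer.W.numpy())   # current value of w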
To verify that W is actually being trained, you can freeze every other layer and check that the loss still decreases:
inp = K.layers.Input(shape=(10,))
x1 = K.layers.Dense(10, trainable=False)(inp)
x2 = K.layers.Dense(10, trainable=False)(inp)
summed = WeightedSum()([x1, x2])  # the only trainable weight in the model is W
res = K.layers.Dense(1, trainable=False)(summed)
model = K.models.Model(inputs=[inp], outputs=[res])
model.compile(K.optimizers.Adam(), loss="binary_crossentropy")
model.fit(np.random.random((100, 10)), np.random.choice([0, 1], (100,)), epochs=10)
Output:
Epoch 1/10
4/4 [==============================] - 0s 9ms/step - loss: 6.5532
Epoch 2/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5522
Epoch 3/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5513
Epoch 4/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5504
Epoch 5/10
4/4 [==============================] - 0s 4ms/step - loss: 6.5500
Epoch 6/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5491
Epoch 7/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5484
Epoch 8/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5479
Epoch 9/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5472
Epoch 10/10
4/4 [==============================] - 0s 5ms/step - loss: 6.5466
Even though every other layer is frozen, the loss decreases, so W is indeed being optimized.
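If you want to read off the learned scalar afterwards, one way (a sketch, assuming the model built above) is to find the layer and inspect its weight:

ws_layer = next(l for l in model.layers if isinstance(l, WeightedSum))
print(ws_layer.W.numpy())  # learned value of w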
As for the fact that the two tensors come from a 1D conv: it shouldn't matter, as long as they have the same shape (or broadcast-compatible shapes).
Upvotes: 5