fsamu
fsamu

Reputation: 33

Keras compute convex combination of two tensors

I have two tensors, h1 and h2, both with shape (?, H, T). What is the best way to merge them by computing their convex combination `lambda * h1 + (1 - lambda) * h2`, where lambda is a learnable one-dimensional vector with shape (H,)?

I am using keras with tensorflow backend.

Upvotes: 3

Views: 591

Answers (1)

rvinas
rvinas

Reputation: 11895

Define a custom keras layer:

from keras.engine.topology import Layer
from keras.models import Model
from keras.layers import Input
import numpy as np

# Example sizes for the demo: each input sample is declared with shape (H, T).
H = 2
T = 3


class ConvexCombination(Layer):
    """Merge two same-shaped tensors as ``lambda * h1 + (1 - lambda) * h2``.

    ``lambda`` is a trainable weight holding one coefficient per index of
    axis 1 of the inputs; it is broadcast over every following axis. The
    weight carries a constraint that clips it to ``[0, 1]`` so that the
    result remains a convex combination while training.

    Input: a list of two tensors with identical shape ``(batch_size, H, ...)``.
    Output: a single tensor with the same shape as the first input.
    """

    def __init__(self, **kwargs):
        super(ConvexCombination, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the ``lambda`` weight from the shape of the first input.

        Args:
            input_shape: list of two shapes, one per input tensor.

        Raises:
            ValueError: if the first input has fewer than two axes, i.e.
                there is no axis 1 to attach the coefficients to.
        """
        # Local import keeps this class self-contained: the surrounding file
        # only imports Layer, Model and Input from Keras.
        from keras import backend as K

        first_shape = tuple(input_shape[0])
        if len(first_shape) < 2:
            raise ValueError(
                'ConvexCombination expects inputs of rank >= 2 '
                '(batch_size, H, ...), got shape %s' % (first_shape,))

        # One coefficient per entry of axis 1, followed by singleton axes so
        # the weight broadcasts over all trailing axes. For rank-3 inputs
        # this is (H, 1).
        weight_shape = (first_shape[1],) + (1,) * (len(first_shape) - 2)

        def clip_to_unit_interval(w):
            # Without this an optimizer step can move lambda outside [0, 1],
            # turning the merge into an extrapolation instead of a convex
            # combination.
            return K.clip(w, 0.0, 1.0)

        self.lambd = self.add_weight(name='lambda',
                                     shape=weight_shape,
                                     initializer='zeros',  # Try also 'ones' and 'uniform'
                                     constraint=clip_to_unit_interval,
                                     trainable=True)
        super(ConvexCombination, self).build(input_shape)

    def call(self, x):
        """Return ``lambda * h1 + (1 - lambda) * h2`` for ``x = [h1, h2]``."""
        h1, h2 = x
        return self.lambd * h1 + (1 - self.lambd) * h2

    def compute_output_shape(self, input_shape):
        """The output has the shape of the first input."""
        return input_shape[0]


# Two symbolic inputs, each declared with per-sample shape (H, T).
h1 = Input(shape=(H, T))
h2 = Input(shape=(H, T))
inputs = [h1, h2]

# Wire both inputs through the merge layer and wrap the graph in a model.
cc = ConvexCombination()(inputs)
model = Model(inputs=inputs, outputs=cc)

# Smoke test with a batch of one: `a` is all zeros, `b` counts 0 .. H*T - 1.
a = np.zeros((1, H, T))
b = np.arange(H * T).reshape((1, H, T))
pred = model.predict([a, b])
print(pred)

Upvotes: 2

Related Questions