fhllw

Reputation: 13

Repeat a Keras (TensorFlow) model over an additional dimension

Let's say I have a model that maps a tensor of shape [n, 10] to a tensor of shape [n, 2], where n is the batch size. How can I repeat the model so that the resulting model accepts an input tensor of shape [n, k, 10] and outputs a tensor of shape [n, k, 2]? The k copies of the model should share all weights.
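To make the requested behaviour concrete, here is a minimal NumPy sketch. It uses a hypothetical linear map (`W`, `b` and `model` are invented for illustration) as a stand-in for the real model. Applying one shared set of weights to each of the k slices gives the same result as folding k into the batch axis, applying the model once, and unfolding the output.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 4, 3  # hypothetical batch size and number of repeats

# Hypothetical stand-in for the model: one shared linear map [n, 10] -> [n, 2].
W = rng.normal(size=(10, 2))
b = rng.normal(size=(2,))

def model(x):
    """Maps a batch of shape [n, 10] to shape [n, 2] using the shared W and b."""
    return x @ W + b

x = rng.normal(size=(n, k, 10))

# Apply the same model to each of the k slices and stack them along axis 1.
looped = np.stack([model(x[:, i, :]) for i in range(k)], axis=1)

# Equivalent: fold k into the batch axis, apply once, then restore [n, k, 2].
folded = model(x.reshape(n * k, 10)).reshape(n, k, 2)

print(looped.shape)                 # (4, 3, 2)
print(np.allclose(looped, folded))  # True
```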

Upvotes: 1


Answers (1)

Anna Krogager

Reputation: 3588

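You can split the input along the k axis, call the same model on every slice (reusing one model object is what makes the slices share weights), and concatenate the results back along that axis: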

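```python
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Concatenate, Input, Lambda
from tensorflow.keras.models import Model

# Assumes `model` maps [n, 10] -> [n, 2] and `k` (number of repeats) is already defined.
input_ = Input((k, model.input.shape[1]))
input_as_list = Lambda(lambda x: tf.unstack(x, axis=1))(input_)  # k tensors of shape [n, 10]
model_outputs = [model(x) for x in input_as_list]  # same model object, so weights are shared
model_outputs = [Lambda(lambda x: K.expand_dims(x, axis=1))(y) for y in model_outputs]  # each [n, 1, 2]
concat_output = Concatenate(axis=1)(model_outputs)  # [n, k, 2]
new_model = Model(input_, concat_output)
```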
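A commonly used alternative, swapped in here rather than taken from the answer above, is `tf.keras.layers.TimeDistributed`. It accepts any layer, including a whole model, and applies it to every slice along axis 1 with the same weights. This is a sketch assuming TensorFlow 2.x with `tf.keras`. The small `Sequential` model and the values of `n` and `k` are hypothetical placeholders for your own model and dimensions.

```python
import numpy as np
import tensorflow as tf

n, k = 4, 3  # hypothetical batch size and number of repeats

# Hypothetical base model mapping [n, 10] -> [n, 2].
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2),
])

# TimeDistributed calls the wrapped model on each of the k slices along axis 1,
# so every slice uses the same weights.
input_ = tf.keras.Input(shape=(k, 10))
output = tf.keras.layers.TimeDistributed(model)(input_)
new_model = tf.keras.Model(input_, output)

x = np.random.rand(n, k, 10).astype("float32")
y = new_model.predict(x, verbose=0)
print(y.shape)  # (4, 3, 2)

# No new weights are created: the wrapper reuses the base model's variables.
print(new_model.count_params() == model.count_params())  # True
```

Because the wrapper takes care of applying the model per slice, there is no need to write the unstack, expand and concatenate Lambda layers by hand.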

Upvotes: 1
