ailauli69

Keras Lambda layer: how to pass multiple arguments

I have this function:

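import tensorflow as tf

def sampling(x):
    # x: (batch, n) class probabilities; draw one class index per row
    zeros = x*0
    samples = tf.random.categorical(tf.math.log(x), 1)
    # one-hot encode the drawn index and drop the sample axis -> (batch, 2)
    samples = tf.squeeze(tf.one_hot(samples, depth=2), axis=1)
    return zeros+samples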

I call it from this layer:

x = layers.Lambda(sampling, name="lambda")(x)
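To make the role of depth concrete, here is a plain-Python analogue of what sampling computes per row. It is only an illustration, not TensorFlow: tf.random.categorical takes logits, and passing tf.math.log(x) makes the draw proportional to the probabilities in x, which is what random.choices with weights=row does here. The helper name sample_one_hot is hypothetical.

```python
import random

def sample_one_hot(probs, depth, rng):
    """Hypothetical plain-Python analogue of `sampling` for a batch of rows.

    Each row holds class probabilities; one class index is drawn per row in
    proportion to those probabilities and returned one-hot encoded.
    """
    out = []
    for row in probs:
        if len(row) != depth:
            # Mirrors the shape mismatch that `zeros + samples` typically hits
            # in the TF version when depth differs from the number of columns.
            raise ValueError(f"row has {len(row)} classes but depth={depth}")
        idx = rng.choices(range(len(row)), weights=row, k=1)[0]
        out.append([1.0 if i == idx else 0.0 for i in range(depth)])
    return out

rng = random.Random(0)
batch = [[0.2, 0.8], [1.0, 0.0]]
print(sample_one_hot(batch, depth=2, rng=rng))
```

The sketch suggests why depth is worth parameterizing: with depth hard-coded to 2, the function only lines up with inputs that have exactly two columns.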

However, I need to change the depth passed to tf.one_hot inside sampling, so I would need a signature like this:

def sampling(x, depth):

But how can I make it work with the Lambda layer?

Thanks a lot


Answers (1)

Marco Cerliani

Use a Python lambda inside the Lambda layer to fix the extra argument, so the layer still receives a function of a single input:

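import tensorflow as tf

def sampling(x, depth):
    # depth should equal x.shape[-1] so that zeros + samples has matching shapes
    zeros = x*0
    samples = tf.random.categorical(tf.math.log(x), 1)
    samples = tf.squeeze(tf.one_hot(samples, depth=depth), axis=1)
    return zeros+samples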

Usage:

layers.Lambda(lambda t: sampling(t, depth=3), name="lambda")(x)
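The lambda works because the layer just needs a callable it can invoke with the layer input. functools.partial binds depth the same way, and Keras's Lambda layer also accepts an arguments dict that it forwards to the function as keyword arguments. The sketch below compares the three options without TensorFlow, assuming a simplified, hypothetical stand-in (call_like_lambda) for how the layer invokes its function, and a hypothetical sampling_stub that one-hot encodes the argmax so the result is deterministic.

```python
import functools

def sampling_stub(x, depth):
    """Hypothetical deterministic stand-in for `sampling`: one-hot of each row's argmax."""
    out = []
    for row in x:
        idx = row.index(max(row))
        out.append([1.0 if i == idx else 0.0 for i in range(depth)])
    return out

def call_like_lambda(function, inputs, arguments=None):
    """Simplified, hypothetical stand-in for a Lambda layer's call: the input is
    passed positionally and `arguments` is forwarded as keyword arguments."""
    return function(inputs, **(arguments or {}))

x = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]

# Three ways to turn a two-argument function into a one-input callable.
via_lambda = call_like_lambda(lambda t: sampling_stub(t, depth=3), x)
via_partial = call_like_lambda(functools.partial(sampling_stub, depth=3), x)
via_arguments = call_like_lambda(sampling_stub, x, arguments={"depth": 3})

# With the real Keras layer (assuming `from tensorflow.keras import layers`,
# `import functools`, and the TF `sampling(x, depth)` above):
#   layers.Lambda(lambda t: sampling(t, depth=3), name="lambda")(x)
#   layers.Lambda(functools.partial(sampling, depth=3), name="lambda")(x)
#   layers.Lambda(sampling, arguments={"depth": 3}, name="lambda")(x)
```

All three should behave the same at call time. The arguments option keeps depth as a plain value in the layer's config rather than inside a closure, which may matter if you save the model, since Keras documents that Lambda layers are saved by serializing Python bytecode, which is not portable.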
