I have this function:
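import tensorflow as tf
from tensorflow.keras import layers

def sampling(x):
    zeros = x*0  # zeros with the same shape and dtype as x
    # draw one class index per row, with probability proportional to x
    samples = tf.random.categorical(tf.math.log(x), 1)
    # one-hot encode: [batch, 1] -> [batch, 1, 2] -> squeeze -> [batch, 2]
    samples = tf.squeeze(tf.one_hot(samples, depth=2), axis=1)
    return zeros+samples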
I call it from this layer:
x = layers.Lambda(sampling, name="lambda")(x)
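To make the shapes concrete, here is a pure-Python stand-in for sampling that runs without TensorFlow. It is only a sketch under stated assumptions: sampling_reference is a hypothetical name, rows are plain lists of probabilities, and random.choices with weights is swapped in for tf.random.categorical on log-probabilities (both pick an index with probability proportional to the row). It also shows a constraint on depth: x*0 has one column per class, so zeros + samples only lines up when depth equals the number of columns of x.

```python
import random

def sampling_reference(x, depth, rng=random):
    """Hypothetical pure-Python stand-in for the TF sampling function.

    x: list of rows, each a list of non-negative class probabilities.
    For each row, draw one class index with probability proportional to the
    row, then one-hot encode that index with length `depth`.
    """
    out = []
    for row in x:
        if len(row) != depth:
            # In the TF version, zeros (one column per class in x) + samples
            # (depth columns) would fail to broadcast at this point.
            raise ValueError(f"depth={depth} but row has {len(row)} classes")
        # random.choices stands in for tf.random.categorical(tf.math.log(x), 1)
        idx = rng.choices(range(len(row)), weights=row, k=1)[0]
        # one-hot encoding, like tf.one_hot(..., depth=depth)
        out.append([1.0 if i == idx else 0.0 for i in range(depth)])
    return out

# Rows with all their mass on one class always give the same one-hot vector.
print(sampling_reference([[0.0, 1.0], [1.0, 0.0]], depth=2, rng=random.Random(0)))
```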
But I need to be able to change the depth passed to one_hot inside the sampling function, so I would need something like this:
def sampling(x, depth):
How can I make this work with the Lambda layer?
Thanks a lot
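Wrap the call in a Python lambda inside the Lambda layer. The Lambda layer calls that lambda with just the input tensor, and the lambda forwards it to sampling together with a fixed depth: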
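import tensorflow as tf

def sampling(x, depth):
    zeros = x*0
    samples = tf.random.categorical(tf.math.log(x), 1)
    # depth should equal the number of classes (last dimension) of x,
    # otherwise zeros + samples cannot be broadcast
    samples = tf.squeeze(tf.one_hot(samples, depth=depth), axis=1)
    return zeros+samples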
Usage:
x = layers.Lambda(lambda t: sampling(t, depth=3), name="lambda")(x)
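One caveat with the lambda approach: if several Lambda layers with different depths are created in a loop, a plain lambda captures the loop variable itself, not its current value, so every layer ends up with the last depth. The pure-Python sketch below compares three ways of binding depth; the sampling here is a hypothetical stand-in that only records its depth, so the example runs without TensorFlow.

```python
import functools

def sampling(x, depth):
    # Hypothetical stand-in for the TF sampling(x, depth) above: it just
    # returns its inputs so we can check which depth each function received.
    return (x, depth)

depths = [2, 3, 4]

# Pitfall: each lambda closes over the variable d, which is 4 once the loop ends.
late_bound = [lambda t: sampling(t, depth=d) for d in depths]

# Fix 1: a default argument freezes the value of d at creation time.
default_bound = [lambda t, d=d: sampling(t, depth=d) for d in depths]

# Fix 2: functools.partial stores the keyword argument at creation time.
partial_bound = [functools.partial(sampling, depth=d) for d in depths]

print([f("x")[1] for f in late_bound])     # [4, 4, 4]
print([f("x")[1] for f in default_bound])  # [2, 3, 4]
print([f("x")[1] for f in partial_bound])  # [2, 3, 4]
```

Either bound form can be passed to layers.Lambda in place of the plain lambda. Keras's Lambda layer also takes an arguments parameter, a dict of extra keyword arguments forwarded to the function, so layers.Lambda(sampling, arguments={"depth": 3}, name="lambda")(x) should behave the same; check the documentation for your Keras version.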