David H. J.

Reputation: 412

How to create a custom conditional activation function

I want to create a custom activation function in TF2. The math looks like this:

def sqrt_activation(x):
    if x >= 0:
        return tf.math.sqrt(x)
    else:
        return -tf.math.sqrt(-x)

The problem is that I can't compare x with 0 in a Python if statement, since x is a tensor. How can I achieve this?

Upvotes: 0

Views: 70

Answers (2)

Vijay Mariappan

Reputation: 17201

You can skip the comparison entirely by writing:

def sqrt_activation(x):
    return tf.math.sign(x)*tf.math.sqrt(tf.abs(x))
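A minimal sketch for checking this version, assuming TF 2.x with the bundled tf.keras. It compares the output against the piecewise definition, inspects the gradient with tf.GradientTape (for x != 0 the derivative is 1/(2*sqrt(|x|)) on both sides of zero), and passes the function to a Dense layer as its activation. The input values, the 3 units and the (2, 5) input shape are arbitrary choices for illustration.

```python
import tensorflow as tf

def sqrt_activation(x):
    # sign(x) is -1, 0 or +1 element-wise, so this equals sqrt(x) for x >= 0
    # and -sqrt(-x) for x < 0 without any Python-level comparison.
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))

# Forward pass: expected [-2.0, 0.0, 3.0].
values = sqrt_activation(tf.constant([-4.0, 0.0, 9.0]))

# Gradient for nonzero inputs is 1 / (2 * sqrt(|x|)):
# 0.25 at x = -4.0 and about 0.1667 at x = 9.0.
x = tf.constant([-4.0, 9.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    y = sqrt_activation(x)
grads = tape.gradient(y, x)

# Any tensor-to-tensor callable can be passed as a Keras activation.
layer = tf.keras.layers.Dense(3, activation=sqrt_activation)
out = layer(tf.ones((2, 5)))

print(values.numpy(), grads.numpy(), out.shape)
```

At exactly x = 0 the derivative is unbounded, so inputs that are exactly zero can produce inf or NaN gradients.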

Upvotes: 3

ahmet hamza emra

Reputation: 640

You need to use TensorFlow ops such as tf.where instead of a Python if, and convert your code as follows:

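import tensorflow as tf

@tf.function
def sqrt_activation(x):
    zeros = tf.zeros_like(x)
    ones = tf.ones_like(x)
    # tf.where evaluates both branches, so replace the entries each branch
    # does not use with 1 before calling sqrt; otherwise sqrt of a negative
    # number yields NaN, which leaks into the gradient during training.
    pos = tf.where(x >= 0, tf.math.sqrt(tf.where(x >= 0, x, ones)), zeros)
    neg = tf.where(x < 0, -tf.math.sqrt(tf.where(x < 0, -x, ones)), zeros)
    return pos + neg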

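Note that tf.where evaluates the condition element-wise over the whole tensor: pos holds the square roots of the non-negative entries (zeros elsewhere), neg holds the negated square roots of the negative entries (zeros elsewhere), so pos + neg combines them into the final result.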

Upvotes: 1
