Reputation: 412
I want to create a custom activation function in TF2. The math is like this:
def sqrt_activation(x):
    if x >= 0:
        return tf.math.sqrt(x)
    else:
        return -tf.math.sqrt(-x)
The problem is that I can't compare x with 0, since x is a tensor. How can I achieve this?
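A quick way to see why the Python if cannot work here: comparing a tensor with 0 is element-wise, so it returns a boolean tensor rather than a single True/False for the if to branch on. tf.where instead chooses a value per element. A minimal illustration (the sample values are arbitrary):

```python
import tensorflow as tf

x = tf.constant([-4.0, 0.0, 9.0])

# The comparison yields one boolean per element, not a single Python bool.
mask = x >= 0
print(mask.numpy())  # [False  True  True]

# tf.where picks, element by element, from its second or third argument.
picked = tf.where(mask, x, -x)
print(picked.numpy())  # [4. 0. 9.]
```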
Upvotes: 0
Views: 70
Reputation: 17201
You can skip the comparison altogether: sign(x) * sqrt(|x|) equals sqrt(x) when x >= 0 and -sqrt(-x) when x < 0, so you can write:
import tensorflow as tf

def sqrt_activation(x):
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))
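As a usage sketch (the input values and layer sizes are arbitrary), the function can be called on a tensor directly or passed as the activation of a Keras layer, since Keras accepts any callable there:

```python
import numpy as np
import tensorflow as tf


def sqrt_activation(x):
    # sign(x) * sqrt(|x|): sqrt(x) for x >= 0, -sqrt(-x) for x < 0
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))


# Direct call: both signs are handled without any comparison.
y = sqrt_activation(tf.constant([-4.0, 0.0, 9.0]))
print(y.numpy())  # [-2.  0.  3.]

# As a layer activation: `activation` accepts a plain Python callable.
layer = tf.keras.layers.Dense(3, activation=sqrt_activation)
out = layer(np.ones((2, 4), dtype="float32"))
print(out.shape)  # (2, 3)
```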
Upvotes: 3
Reputation: 640
You need to use TensorFlow functions instead of Python if/else and convert your code as follows:
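import tensorflow as tf

@tf.function
def sqrt_activation(x):
    zeros = tf.zeros_like(x)
    # Keep sqrt(x) where x >= 0, zero elsewhere (NaNs from negative x are discarded)
    pos = tf.where(x >= 0, tf.math.sqrt(x), zeros)
    # Keep -sqrt(-x) where x < 0, zero elsewhere
    neg = tf.where(x < 0, -tf.math.sqrt(-x), zeros)
    # At each position at most one of pos/neg is non-zero
    return pos + neg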
Note that this function evaluates both conditions for every element of the tensor, which is why the two partial results are combined in the return pos + neg line.
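One caveat with the version above: tf.math.sqrt is evaluated on the whole tensor, so negative entries produce NaN intermediates that tf.where then discards. The forward result is correct, but TensorFlow's tf.where documentation notes that a NaN in the gradient of either branch can make the whole gradient NaN, and suggests an inner tf.where that swaps unsafe inputs for safe ones. A sketch of that pattern applied here, assuming 1.0 as the safe filler value and accepting a gradient of 0 at exactly x == 0 (the true derivative is unbounded there):

```python
import tensorflow as tf


@tf.function
def sqrt_activation(x):
    ones = tf.ones_like(x)
    # Inner tf.where: sqrt only ever sees positive inputs.
    # Entries on the "wrong" side get a filler of 1.0 (an assumption of this sketch).
    safe_pos = tf.where(x > 0, x, ones)
    safe_neg = tf.where(x < 0, -x, ones)
    # Outer tf.where: sqrt(x) for x > 0, -sqrt(-x) for x < 0, 0 at x == 0.
    return tf.where(
        x > 0,
        tf.math.sqrt(safe_pos),
        tf.where(x < 0, -tf.math.sqrt(safe_neg), tf.zeros_like(x)),
    )


x = tf.constant([-9.0, 0.0, 4.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = sqrt_activation(x)
grad = tape.gradient(y, x)
print(y.numpy())  # [-3.  0.  2.]
print(grad.numpy())
```

Away from zero the gradient matches the analytic derivative 1 / (2 * sqrt(|x|)), e.g. 0.25 at x = 4.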
Upvotes: 1