Baba Yara

Reputation: 119

Tensorflow: Differentiable Primitives

I was under the impression that all TensorFlow primitives are differentiable. Under this "illusion" I wrote the function below, hoping that TensorFlow would automatically differentiate it so I could backprop errors through it.

Rank-weight function:

def ranked(a):
    lens    = tf.convert_to_tensor(tf.range(1, (tf.size(a) + 1)))
    rankw01 = tf.cast(tf.convert_to_tensor(tf.contrib.framework.argsort(tf.contrib.framework.argsort(a)) + 1),
                      tf.float64)
    rankw02 = tf.convert_to_tensor(rankw01 - ((tf.size(a) + 1) / 2))
    rankw03 = tf.divide(rankw02, tf.reduce_sum(tf.gather(rankw02, tf.where(tf.greater(rankw02, 0)))))
    rankw04 = tf.cast(rankw03, tf.float32)

    return rankw04

Unfortunately, the function works as expected in the forward pass but fails in the backward pass: the error I keep getting says the gradient does not exist.

The function is explained in the attached image (not reproduced here).

I have the following questions:

1: Why can't I take the derivative of the function above?

2: If it is an implementation issue, can you suggest how I can rewrite it so I can take its derivative and backprop errors through it?

3: Are all tensorflow ops differentiable?
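For context, the ranking step (the double argsort) is piecewise constant in its input: a small perturbation of the values leaves the ranks unchanged, so any derivative through it would be zero almost everywhere. A quick check with NumPy standing in for the TensorFlow ops (the example array is made up):

```python
import numpy as np

a = np.array([0.3, -1.2, 2.5, 0.7])

# double argsort gives the rank of each element (0 = smallest)
ranks = np.argsort(np.argsort(a))

# perturb the input slightly; the ranks do not move
ranks_eps = np.argsort(np.argsort(a + 1e-6))

print(ranks)                             # [1 0 3 2]
print(np.array_equal(ranks, ranks_eps))  # True
```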

Upvotes: 1

Views: 780

Answers (1)

Baba Yara

Reputation: 119

So I followed @DomJack's advice, removed the tf.convert_to_tensor calls, and did a little housecleaning all the way through. Now the function is differentiable.

def ranked(a):
    rankw01 = tf.cast(tf.contrib.framework.argsort(tf.contrib.framework.argsort(a)) + 1, tf.float32)
    rankw02 = rankw01 - tf.cast((tf.shape(a)[-1] + 1) / 2, tf.float32)
    rankw03 = tf.div(rankw02, tf.reduce_sum(tf.nn.relu(rankw02), axis=-1, keepdims=True))

    return rankw03
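As a sanity check on the forward pass, here is the same computation in plain NumPy (used only to illustrate the output; the example array is made up). The centered ranks sum to zero, and after dividing by the sum of the positive part, the positive weights sum to one and the negative weights to minus one:

```python
import numpy as np

def ranked_np(a):
    # rank of each element: 1 = smallest ... n = largest
    rankw01 = np.argsort(np.argsort(a, axis=-1), axis=-1) + 1.0
    # center the ranks around zero
    rankw02 = rankw01 - (a.shape[-1] + 1) / 2
    # scale so the positive weights sum to 1 (and the negative to -1)
    return rankw02 / np.maximum(rankw02, 0).sum(axis=-1, keepdims=True)

w = ranked_np(np.array([0.3, -1.2, 2.5, 0.7]))
print(w)  # [-0.25 -0.75  0.75  0.25]
```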

Upvotes: 1
