user6952886

Reputation: 423

TensorFlow: Where is the actual implementation of RMSprop?

In "rmsprop.py" (in TensorFlow) there is a call to the method apply_rms_prop. This method is defined in "gen_training_ops.py". In the definition of this method there is a comment describing what it is supposed to do:

ms <- rho * ms_{t-1} + (1-rho) * grad * grad
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
var <- var - mom
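
For reference, here is how I read that pseudo code in plain NumPy (my own sketch, just to check my understanding, not TensorFlow's code):

import numpy as np

def rmsprop_step(var, grad, ms, mom, lr, rho, momentum, epsilon):
    # Decaying average of the squared gradients.
    ms = rho * ms + (1.0 - rho) * grad * grad
    # Momentum applied to the scaled gradient.
    mom = momentum * mom + lr * grad / np.sqrt(ms + epsilon)
    # Apply the update.
    var = var - mom
    return var, ms, mom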

But I can't seem to find the actual Python implementation of the pseudo code above. My guess is that it is implemented in CPython, since I was able to find the file "__pycache__/rmsprop.cpython-36.pyc". But again, where is the CPython implementation that performs the pseudo code above?

My goal is to implement my own gradient update methods, so I need to see some concrete implementation examples (e.g. RMSprop, Adam, etc.). Any help would be much appreciated!

Upvotes: 1

Views: 380

Answers (2)

javidcf

Reputation: 59711

You can find the implementations under tensorflow/core/kernels. The CPU version is in training_ops.cc and the GPU (CUDA) version in training_ops_gpu.cu.cc (look for the template struct ApplyRMSProp). Other optimizer update rule implementations can also be found in those files.

I think the Python code is automatically generated; the kernel registration macros at the end of the file group the different device implementations under a single op name (translated from CamelCase to snake_case in Python), which you then use independently of the device.
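
As an illustration of that naming scheme, the C++ op ApplyRMSProp surfaces in Python as apply_rms_prop. A small sketch of my own (internal TF 1.x module layout assumed, so details may differ between versions):

import tensorflow as tf
from tensorflow.python.training import training_ops  # thin wrapper over gen_training_ops

var = tf.Variable(tf.zeros([10]))
ms = tf.Variable(tf.zeros([10]))
mom = tf.Variable(tf.zeros([10]))
grad = tf.ones([10])

# Same argument order as the comment in gen_training_ops.py; TensorFlow
# dispatches to the registered CPU or GPU kernel at runtime.
update_op = training_ops.apply_rms_prop(
    var, ms, mom, lr=0.01, rho=0.9, momentum=0.9, epsilon=1e-10, grad=grad)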

Upvotes: 3

Ishant Mrinal

Reputation: 4918

You can implement your own optimizer by subclassing the Optimizer class. You have to implement at least one of the methods _apply_dense or _apply_sparse.

Here is the skeleton of an AdaMax optimizer built using only already available TensorFlow ops:

from tensorflow.python.training import optimizer

class AdamaxOptimizer(optimizer.Optimizer):

    def _create_slots(self, var_list):
        # You can create slot variables (per-variable optimizer state) here.
        ...

    def _apply_dense(self, grad, var):
        # Implement your logic for the gradient update here.
        ...
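
For concreteness, here is a rough sketch of how the AdaMax update rule (Kingma & Ba, 2015) could be filled in. This is my own untested code; it assumes the TF 1.x Optimizer API and simplifies the per-step bias correction:

import tensorflow as tf
from tensorflow.python.training import optimizer

class AdamaxOptimizer(optimizer.Optimizer):

    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999,
                 epsilon=1e-8, use_locking=False, name="Adamax"):
        super(AdamaxOptimizer, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon

    def _create_slots(self, var_list):
        # One first-moment slot "m" and one infinity-norm slot "u" per variable.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "u", self._name)

    def _apply_dense(self, grad, var):
        m = self.get_slot(var, "m")
        u = self.get_slot(var, "u")
        # m_t <- beta1 * m + (1 - beta1) * grad
        m_t = m.assign(self._beta1 * m + (1.0 - self._beta1) * grad)
        # u_t <- max(beta2 * u, |grad|)
        u_t = u.assign(tf.maximum(self._beta2 * u, tf.abs(grad)))
        # var <- var - lr / (1 - beta1) * m_t / u_t
        # NOTE: the paper uses lr / (1 - beta1**t) with a step counter;
        # the bias correction is simplified here to keep the sketch short.
        var_update = var.assign_sub(
            (self._lr / (1.0 - self._beta1)) * m_t / (u_t + self._epsilon))
        return tf.group(var_update, m_t, u_t)

    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradients are not handled in this sketch.")

You would then use it like any built-in optimizer, e.g. AdamaxOptimizer(learning_rate=0.002).minimize(loss).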

Upvotes: 1
