herrtim

Reputation: 2755

How to update model (Variable) in custom python operator (tf.py_func)?

I need to write a custom Op in Python that generates an output based on a model, and another Op that updates the model. In the sample code below, the model is just a scalar, w (in reality it will be an n×m matrix). I figured out how to "read" the model, as shown in the custom_model_read_op function (in reality much more complicated). But how can I create something similar that updates w in some custom, complicated way (using custom_model_update_op)? I assume this is possible, since optimizer ops like SGD are able to do it. Thanks in advance!

import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
    y = i*float(w)
    return y
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64], name="read_func")

def custom_model_update_op(i, w):
    # QUESTION: how to update w (the model stored in the Variable above)
    # based on the value of i and some crazy logic?
    return numpy.int64(0)  # dummy output, matches the tf.int64 declared below
crazy_update = tf.py_func(custom_model_update_op, [X, w], [tf.int64], name="update_func")



with tf.Session() as sess:

    tf.global_variables_initializer().run()

    for i in range(10):
        y_out, __ = sess.run([y, crazy_update], feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out[0]))

Upvotes: 2

Views: 216

Answers (1)

herrtim

Reputation: 2755

Well, I'm not sure this is the best way, but it accomplishes what I need. I don't have a py_func in which the update of w occurs; instead, I compute the updated value inside the read op, pass it back as a second return value, and then use the assign function to write it to w outside of the custom op. If any TensorFlow experts can confirm that this is a good, legitimate way to do it, I'd appreciate it.

import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_update(w):
    # Update w (the model stored in the Variable above) with some crazy logic.
    # Cast so the result matches the tf.float64 declared in py_func below.
    return numpy.float64(w + 1)

def custom_model_read_op(i, w):
    y = i*float(w)
    w = custom_model_update(w)
    return y, w
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64, tf.float64], name="read_func")

with tf.Session() as sess:

    tf.global_variables_initializer().run()

    for i in range(10):
        y_out, w_modified = sess.run(y, feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out))
        assign_op = w.assign(w_modified)
        sess.run(assign_op)
        print("w=", "{:.4f}".format(sess.run(w)))
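
One thing to watch in the loop above: calling w.assign(...) inside the loop adds a new op to the graph on every iteration. A possible variation, offered only as a sketch and not as a confirmed best practice, is to wire the second py_func output into a single assign op built once, so that one sess.run both reads the model and updates it. The names read_and_update and run_demo are hypothetical, the "+ 1.0" update is a stand-in for the real logic, and the use of tensorflow.compat.v1 (to get the 1.x graph API on newer installs) is an assumption about the environment.

```python
def read_and_update(i, w):
    """Pure function for tf.py_func: return the output y and the next value of w."""
    i, w = float(i), float(w)  # py_func passes the inputs as 0-d NumPy arrays
    y = i * w
    new_w = w + 1.0  # stand-in for the real (complicated) update logic
    return y, new_w


def run_demo(steps=3, w_init=0.5):
    """Run a few steps and return a list of (y, w_after_update) pairs."""
    # Imported lazily so read_and_update can be exercised without TensorFlow.
    # Assumption: tensorflow.compat.v1 is available (TF 1.14+ or 2.x).
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    graph = tf.Graph()
    with graph.as_default():
        # Explicit float64 so the variable matches the py_func output dtype.
        w = tf.Variable(w_init, dtype=tf.float64, name="weight")
        X = tf.placeholder(tf.int32, shape=(), name="X")
        y, new_w = tf.py_func(read_and_update, [X, w],
                              [tf.float64, tf.float64], name="read_and_update")
        new_w.set_shape(())  # py_func outputs have unknown static shape
        # Built once; it depends on new_w, so it runs after the py_func has read w.
        update_w = tf.assign(w, new_w)
        init = tf.global_variables_initializer()

    history = []
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        for i in range(steps):
            # y is computed from the value of w before this step's update.
            y_out, _ = sess.run([y, update_w], feed_dict={X: i})
            history.append((float(y_out), float(sess.run(w))))
    return history
```

Calling run_demo() builds the graph once and then only runs existing ops inside the loop, so the graph does not grow with the number of iterations. The update logic itself stays in ordinary Python, as in the answer above.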

Upvotes: 1
