Reputation: 2755
I need to write a custom Op in Python that generates an output based on a model, and another op that updates the model. In the following sample code, I have a very simple model consisting of just a scalar, w (in reality it will be an n×m matrix). I figured out how to "read" the model, as demonstrated in the custom_model_read_op function (in reality much more complicated). However, how can I create something similar that will update w in some custom, complicated way (using custom_model_update_op)? I assume this is possible, given that Optimizer ops like SGD are able to do this. Thanks in advance!
import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
    y = i*float(w)
    return y

y = tf.py_func(custom_model_read_op, [X, w], [tf.float64], name="read_func")

def custom_model_update_op(i, w):
    # ==> How to update w (the model stored in a Variable above) based on the value of i and some crazy logic?
    return numpy.int64(0)  # dummy return value, matching the tf.int64 output declared below

crazy_update = tf.py_func(custom_model_update_op, [X, w], [tf.int64], name="update_func")

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(10):
        y_out, __ = sess.run([y, crazy_update], feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out[0]))
Upvotes: 2
Views: 216
Reputation: 2755
Well, I'm not sure this is the best way, but it accomplishes what I need. I don't have a py_func in which the update of w occurs. Instead, I compute the updated value inside the read op, pass it back as a second return value, and finally use the assign function to write it to w outside of the custom op. If any TensorFlow experts can confirm that this is a legitimate way to do it, I'd appreciate it.
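import tensorflow as tf
import numpy

# Create a model (float64, so that it matches the dtypes declared for py_func below)
w = tf.Variable(numpy.random.randn(), dtype=tf.float64, name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_update(w):
    # update w (the model stored in a Variable above) using some crazy logic;
    # pass i in as well if the update depends on it
    return w + 1

def custom_model_read_op(i, w):
    y = i*float(w)
    w = custom_model_update(w)
    # return NumPy float64 values so they match the output dtypes given to py_func
    return numpy.float64(y), numpy.float64(w)

y = tf.py_func(custom_model_read_op, [X, w], [tf.float64, tf.float64], name="read_func")

# Build the assign op once; calling w.assign() inside the loop would add a new op to the graph on every iteration
w_in = tf.placeholder(tf.float64, shape=(), name="w_in")
assign_op = w.assign(w_in)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(10):
        y_out, w_modified = sess.run(y, feed_dict={X: i})
        print("y=", "{:.4f}".format(y_out))
        sess.run(assign_op, feed_dict={w_in: w_modified})
        print("w=", "{:.4f}".format(sess.run(w)))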
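For what it's worth, the same idea can probably be kept entirely inside the graph by feeding the second py_func output directly into an assign op, so that one sess.run call both reads and updates the model. The following is only a sketch under a few assumptions: it uses the tensorflow.compat.v1 module (present in late TF 1.x and in TF 2.x), the names read_and_update and run_steps are made up for illustration, and w + 1 is just a stand-in for the real update logic.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API


def read_and_update(i, w):
    """Hypothetical stand-in for the custom logic: returns (y, new_w)."""
    y = np.float64(i * w)        # "read" the model
    new_w = np.float64(w + 1.0)  # placeholder for the real update rule
    return y, new_w


def run_steps(n_steps, w_init=0.5):
    """Run n_steps iterations; return a list of (y, w_after_update) pairs."""
    results = []
    with tf.Graph().as_default():
        w = tf.Variable(w_init, dtype=tf.float64, name="weight")
        X = tf.placeholder(tf.int32, shape=(), name="X")

        y, w_new = tf.py_func(read_and_update, [X, w],
                              [tf.float64, tf.float64], name="read_func")
        # py_func outputs carry no static shape, so restore it by hand
        y.set_shape(())
        w_new.set_shape(w.get_shape())

        # The assign op is built once and consumes the py_func output directly
        update = tf.assign(w, w_new)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(n_steps):
                # One call computes y from the old w and then stores the new w
                y_out, w_out = sess.run([y, update], feed_dict={X: i})
                results.append((float(y_out), float(w_out)))
    return results


if __name__ == "__main__":
    for y_val, w_val in run_steps(3):
        print("y=", "{:.4f}".format(y_val), "w=", "{:.4f}".format(w_val))
```

Since y and w_new come out of the same py_func call, y is always computed from the value of w before the update, and the assign runs afterwards because it depends on w_new. Whether this is preferable to a separate assign step likely depends on how the rest of the graph reads w.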
Upvotes: 1