I like to build Python classes around my TensorFlow models to make them portable and easier to work with (at least in my mind). The approach I take is to write something like this:
class MyAwesomeModel(object):
    def __init__(self, some_graph_params):
        # bunch of code that defines the tensors, optimizer, etc...
        # e.g. self.mytensor = tf.placeholder(tf.float32, [1])
        pass

    def Train(self, tfsession, input_val):
        # some code that calls the run() method on tfsession, etc.
        pass

    def other_methods(self):
        # other things like testing, plotting, etc. all managed nicely
        # by the state that MyAwesomeModel instances maintain
        pass
I have two models that are very similar. The only differences are in a few places in the computational graph architecture, so I would like to create a base class that has all the common functionality and child classes that override just a few things in the base class. Here's how it would work in my mind.
Say my base class looks like this:
import tensorflow as tf

class BaseClass(object):
    def __init__(self, multiplier):
        self.multiplier = multiplier
        # This is where I construct the graph
        self.inputnode = tf.placeholder(tf.float32, [1])
        self.tensor1 = tf.constant(self.multiplier, dtype=tf.float32) * self.inputnode
        self.tensor2 = self.tensor1  # this is where the two
                                     # child classes will differ
        self.tensoroutput = 10*self.tensor2

    def forward_pass(self, tfsession, input_val):
        return tfsession.run(self.tensoroutput,
                             feed_dict={self.inputnode: [input_val]})

    def other_methods(self):
        print("doing something here...")
        print(self.multiplier)
Then I pop in two child classes, which just redefine the relationship between self.tensor2 and self.tensor1:
class ChildClass1(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)

class ChildClass2(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)
My objective would be to then run the following:
cc1 = ChildClass1(2) # multiplier is 2
mysession = tf.Session()
mysession.run(tf.global_variables_initializer())
print(cc1.forward_pass(mysession, 5))
If this worked the way I would like, the result would be ((5*2)+5)*10 = 150. If cc1 were instead ChildClass2(2), I would like the result to be ((5*2)+4)*10 = 140.
However, when I run the above code the result is 100, which is consistent with the child class never overriding the definition of self.tensor2 that is first encountered in the base class. I thought I needed that wonky line self.tensor2 = self.tensor1, because otherwise the following line complains that self.tensor2 doesn't exist. What I really want is for the child classes to override the definition of self.tensor2 and nothing else. What is the proper way to do this?
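To see why the result is 100, here is a minimal plain-Python stand-in for the graph so it runs without TensorFlow (the lambdas are an illustrative assumption, not TensorFlow APIs). Each "tensor" is a function of the fed input that captures the node objects existing when it is built, the way a TF1 op captures its input tensors at construction time.

```python
# Plain-Python stand-in for a TF1 graph (illustrative, not TensorFlow code):
# each "node" is a function of the fed input and captures its inputs when built.
class BaseClass:
    def __init__(self, multiplier):
        self.multiplier = multiplier
        self.tensor1 = lambda x: self.multiplier * x
        self.tensor2 = self.tensor1
        # Like `10*self.tensor2` in TF: wired to the node tensor2 refers to NOW.
        t2 = self.tensor2
        self.tensoroutput = lambda x: 10 * t2(x)

    def forward_pass(self, input_val):
        return self.tensoroutput(input_val)


class ChildClass1(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        # Rebinds the attribute to a new node, but tensoroutput was already
        # built from the old one, so nothing downstream uses this node.
        t1 = self.tensor1
        self.tensor2 = lambda x: t1(x) + 5.0


print(ChildClass1(2).forward_pass(5))  # 100, not 150
```

Rebinding self.tensor2 changes the attribute, but tensoroutput keeps pointing at the node it was built from. TensorFlow behaves the same way: 10*self.tensor2 creates an op whose input is whatever tensor self.tensor2 referred to at that moment.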
Thanks a bunch!
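self.tensoroutput is never overridden, so its value does not depend on which child class you instantiate: BaseClass.__init__ builds it as 10*self.tensor2 while self.tensor2 is still self.tensor1, and rebinding self.tensor2 in the child's __init__ afterwards just creates a new tensor that nothing downstream uses. Make tensoroutput a method, so that it is built from self.tensor2 after the child has redefined it, and then it'll work.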
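A minimal sketch of that suggestion, using the same kind of plain-Python stand-in for the graph (lambdas instead of TensorFlow tensors, an assumption made so it runs without TensorFlow). tensoroutput becomes a method that reads self.tensor2 only when it is called, i.e. after the child's __init__ has run.

```python
# Plain-Python stand-in for a TF1 graph (illustrative, not TensorFlow code).
class BaseClass:
    def __init__(self, multiplier):
        self.multiplier = multiplier
        self.tensor1 = lambda x: self.multiplier * x
        self.tensor2 = self.tensor1  # default; child classes may rebind it

    def tensoroutput(self):
        # Built on demand, so it uses whatever self.tensor2 is at call time,
        # i.e. after any child __init__ has rebound it.
        t2 = self.tensor2
        return lambda x: 10 * t2(x)

    def forward_pass(self, input_val):
        return self.tensoroutput()(input_val)


class ChildClass1(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        t1 = self.tensor1
        self.tensor2 = lambda x: t1(x) + 5.0


class ChildClass2(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        t1 = self.tensor1
        self.tensor2 = lambda x: t1(x) + 4.0


print(ChildClass1(2).forward_pass(5))  # 150.0
print(ChildClass2(2).forward_pass(5))  # 140.0
```

In real TF1 code, building the output inside a method called on every forward_pass adds a new multiply op to the graph each time. If that matters, an alternative (not from the answer above) is for BaseClass.__init__ to call an overridable method, e.g. a hypothetical build_tensor2(self, tensor1), that each child class overrides to return its version of tensor2 before tensoroutput is built, so the graph is constructed only once.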