nViz

Class inheritance in classes that encapsulate TensorFlow computational graphs

I like to build Python classes around my TensorFlow models to make them portable and easier to work with (at least in my mind). The approach I take is to write something like this:

class MyAwesomeModel(object):
    def __init__(self, some_graph_params):
        # bunch of code that defines the tensors, optimizer, etc.
        # e.g.  self.mytensor = tf.placeholder(tf.float32, [1])
        pass

    def Train(self, tfsession, input_val):
        # some code that calls the run() method on tfsession, etc.
        pass

    def other_methods(self):
        # other things like testing, plotting, etc., all managed nicely
        # by the state that MyAwesomeModel instances maintain
        pass

I have two models that are very similar. The only differences are in a few places in the computational graph architecture, so I would like to create a base class that has all the common functionality and child classes that override just a few things in the base class. Here's how it would work in my mind.

Say my base class looks like this:

import tensorflow as tf

class BaseClass(object):
    def __init__(self, multiplier):
        self.multiplier = multiplier
        # This is where I construct the graph
        self.inputnode = tf.placeholder(tf.float32, [1])
        self.tensor1 = tf.constant(self.multiplier, dtype=tf.float32) * self.inputnode
        # This is where the two child classes will differ
        self.tensor2 = self.tensor1
        self.tensoroutput = 10 * self.tensor2
    def forward_pass(self, tfsession, input_val):
        return tfsession.run(self.tensoroutput, 
                             feed_dict={self.inputnode: [input_val]})
    def other_methods(self):
        print("doing something here...")
        print(self.multiplier)

Then I pop in two child classes, which just redefine the relationship between self.tensor2 and self.tensor1:

class ChildClass1(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(5.0, dtype=tf.float32)

class ChildClass2(BaseClass):
    def __init__(self, multiplier):
        BaseClass.__init__(self, multiplier)
        self.tensor2 = self.tensor1 + tf.constant(4.0, dtype=tf.float32)

My objective would be to then run the following:

cc1 = ChildClass1(2)   # multiplier is 2
mysession = tf.Session()
mysession.run(tf.global_variables_initializer())
print(cc1.forward_pass(mysession, 5))

If this worked the way I intend, the result would be ((5*2)+5)*10 = 150. If cc1 were instead created as ChildClass2(2), I would expect ((5*2)+4)*10 = 140.

However, when I run the above code, the result is 100, which is consistent with the child class never overriding the definition of self.tensor2 from the base class. I thought I needed the wonky line self.tensor2 = self.tensor1 because otherwise the next line fails with an AttributeError, since self.tensor2 does not exist yet. What I really want is for the child classes to override the definition of self.tensor2 and nothing else. What is the proper way to do this?
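The same 100 can be reproduced without TensorFlow. In this minimal plain-Python analogue (integers stand in for tensors, and Base and Child1 are illustrative names, not the classes above), the output is computed in the base constructor from whatever self.tensor2 refers to at that moment, and the child's later reassignment never reaches it:

```python
# Plain-Python analogue of the classes above: integers stand in for
# tensors; Base and Child1 are illustrative names only.

class Base(object):
    def __init__(self, multiplier, input_val):
        self.tensor1 = multiplier * input_val  # stand-in for the tensor1 op
        self.tensor2 = self.tensor1            # the "wonky" default
        # Computed immediately, from the object tensor2 refers to right now
        self.tensoroutput = 10 * self.tensor2


class Child1(Base):
    def __init__(self, multiplier, input_val):
        Base.__init__(self, multiplier, input_val)
        # Rebinds the attribute name only; tensoroutput was already computed
        self.tensor2 = self.tensor1 + 5


c = Child1(2, 5)
print(c.tensor2)       # 15
print(c.tensoroutput)  # 100, not 150
```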

Thanks a bunch!

Answers (1)

Alexandre Passos

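self.tensoroutput is never overridden: it is built once, in BaseClass.__init__, from whatever tensor self.tensor2 refers to at that point, so reassigning self.tensor2 in a child class afterwards does not change it. Its value therefore does not depend on which child class you instantiate. Make it a method, so it is built from self.tensor2 only after the child has set it, and then it'll work.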
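One way to apply this, sketched here as the template-method pattern (a choice made for illustration, not necessarily the exact approach intended above): the base class calls an overridable hook, hypothetically named build_tensor2, before it builds tensoroutput, and each child class overrides only that hook. To keep the sketch runnable without TensorFlow, each "tensor" is a plain Python function of the fed input, and forward_pass takes no session; in the real classes, ChildClass1.build_tensor2 would return tensor1 + tf.constant(5.0, dtype=tf.float32).

```python
# Sketch only: plain Python functions stand in for TensorFlow ops so the
# class structure runs without TensorFlow. build_tensor2 is a hypothetical
# hook name, not part of the original code.

class BaseClass(object):
    def __init__(self, multiplier):
        self.multiplier = multiplier
        # Stand-in for: tf.constant(multiplier) * self.inputnode
        tensor1 = lambda x: multiplier * x
        # Call the (possibly overridden) hook BEFORE the output is built
        tensor2 = self.build_tensor2(tensor1)
        self.tensor1 = tensor1
        self.tensor2 = tensor2
        # The output closes over this specific tensor2, like a graph op does
        self.tensoroutput = lambda x: 10 * tensor2(x)

    def build_tensor2(self, tensor1):
        # Default behaviour: tensor2 is just tensor1, as in the question
        return tensor1

    def forward_pass(self, input_val):
        # Stand-in for tfsession.run(self.tensoroutput, feed_dict=...)
        return self.tensoroutput(input_val)


class ChildClass1(BaseClass):
    # In TensorFlow: return tensor1 + tf.constant(5.0, dtype=tf.float32)
    def build_tensor2(self, tensor1):
        return lambda x: tensor1(x) + 5.0


class ChildClass2(BaseClass):
    # In TensorFlow: return tensor1 + tf.constant(4.0, dtype=tf.float32)
    def build_tensor2(self, tensor1):
        return lambda x: tensor1(x) + 4.0


print(BaseClass(2).forward_pass(5))    # 100
print(ChildClass1(2).forward_pass(5))  # 150.0
print(ChildClass2(2).forward_pass(5))  # 140.0
```

A design note: the graph is still built exactly once, in __init__. If tensoroutput were instead rebuilt inside a method that forward_pass calls on every run, then in TF1 graph mode each call would add new ops to the graph.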
