Flo

Reputation: 47

sess.run() does not run?

I'm new here, studying TensorFlow and running into a problem.

import model_method
fittt(model_method.build(self,...),...parameters...)

The above is in main.py, which imports model_method.py. The function fittt() in main.py:

def fittt(model,...):
    model.fit(...)

build() in model_method.py:

def build(self,...):
    self.op_C,self.op_A = self.function_A(...)
    self.op_B = self.function_B(self.op_C,...)

fit() in model_method.py:

def fit(self,...):
    sess = tf.Session(graph=self.graph,config=config)
    BB,AA = sess.run([self.op_B,self.op_A],feed_dict)

To check the running process, I added pdb.set_trace() at the beginning of function_A() and function_B() in model_method.py as follows:

def function_A(self,...):
    pdb.set_trace()
    ......

def function_B(self,...):
    pdb.set_trace()
    ......

The two pdb.set_trace() calls only stopped execution when build() was called, and did nothing when sess.run([self.op_B,self.op_A],feed_dict) was called. So it seems that sess.run() did not actually run function_A() and function_B(). I wonder why, and I want to know how to make the two functions run.

Upvotes: 0

Views: 1389

Answers (1)

Aechlys

Reputation: 1306

By calling model_method.build() you create the computation graph. During this call every line of Python code is executed (which is why pdb stopped).

However, tf.Session.run(...) executes only those parts of the computational graph that are necessary to compute the fetched values (self.op_A and self.op_B in your example). It does not execute the build() function again.

Therefore the pdb.set_trace() calls did not execute when you ran sess.run(...): they are not TensorFlow operations and hence not part of the computational graph.
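If you really do need Python-side code (a print, a pdb.set_trace(), or similar) to run every time the session evaluates an op, one option is to wrap it in a graph op with tf.py_func. This is only a minimal sketch to illustrate the boundary between Python code and graph ops; debug_identity is a made-up helper, not something from your code:

import numpy as np
import tensorflow as tf

def debug_identity(x):
    # Ordinary Python code; because it is wrapped with tf.py_func below,
    # it runs each time the session evaluates the wrapping op.
    print('Python code executed inside sess.run()')  # or pdb.set_trace()
    return x

graph = tf.Graph()
with graph.as_default():
    inp = tf.placeholder(tf.float32, shape=[2, 2])
    # tf.py_func turns the Python function into a graph op, so it executes
    # at run time, not only at graph-construction time.
    wrapped = tf.py_func(debug_identity, [inp], tf.float32)

with tf.Session(graph=graph) as sess:
    sess.run(wrapped, feed_dict={inp: np.ones((2, 2), np.float32)})
    #> Python code executed inside sess.run()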

UPDATE

Consider the following:

import numpy as np
import tensorflow as tf

class My_Model:

    def __init__(self):
        self.np_input = np.random.normal(size=(10, 2))  # 10x2

    def build(self):
        self._in = tf.placeholder(dtype=tf.float32, shape=[10, None])  # matrix 10xN
        W_exception = tf.random_normal(dtype=tf.float32, shape=[3, 3])  # matrix 3x3
        W_success = tf.random_normal(dtype=tf.float32, shape=[2, 3])  # matrix 2x3
        self.op_exception = tf.matmul(self._in, W_exception)  # [10x2] x [3x3] = ERROR
        self.op_success = tf.matmul(self._in, W_success)      # [10x2] x [2x3] = [10x3]
        print('Computational Graph Built')

    def fit_success(self):
        with tf.Session() as sess:
            res = sess.run(self.op_success, feed_dict={self._in: self.np_input})
            print('Result shape: {}'.format(res.shape))

    def fit_exception(self):
        with tf.Session() as sess:
            res = sess.run(self.op_exception, feed_dict={self._in: self.np_input})
            print('Result shape: {}'.format(res.shape))

and then calling:

m = My_Model()
m.build()
#> Computational Graph Built

m.fit_success()
#> Result shape: (10, 3)

m.fit_exception()
#> InvalidArgumentError: Matrix size-incompatible: In[0]: [10,2], In[1]: [3,3]

To explain what you see there: we first define the computational graph in the build() function. _in is our input tensor; None means that dimension 1 is determined dynamically, that is, once we feed a tensor with concrete values.

Then we define two matrices, W_exception and W_success, whose dimensions are fully specified and whose values will be generated randomly.

Then we define two matrix-multiplication operations, each of which returns a tensor.

We then call the build() function and create the computational graph. The print() call is executed but NOT added to the graph. Nothing is computed here; in fact, nothing can be, because the values of _in are not yet specified.

Now, to show that only the parts needed for the requested computation are evaluated, we call fit_success(), which simply multiplies the input tensor _in with W_success (whose dimensions match). We receive a result with the correct shape, (10, 3). Note that we get no error about op_exception having mismatched dimensions, because it is not needed to evaluate op_success.

Lastly, fit_exception() shows that an exception is indeed thrown when we try to evaluate op_exception with the same input tensor.

Upvotes: 1
