Miller Zhu

Reputation: 727

TensorFlow: print a tensor inside a function

In a file called neural_network.py I define a loss function:

def loss(a, b):
    ...
    debug = tf.Print(a, [a], message = 'debug: ')
    debug.eval(session = ???)
    return tf.add(a, b)

Somewhere in this function I want to print a tensor. However, no session is declared in this function; my sessions are created in another file, forecaster.py. So when I try to use tf.Print() in loss(), I don't know which session to evaluate it with. Is there a way to solve this, either with tf.Print() or with another debugging method? Thanks!

Upvotes: 3

Views: 1142

Answers (1)

standy

Reputation: 1065

tf.Print works as an identity op: it returns the same tensor you pass as its first argument and, as a side effect, prints the list of tensors given as its second argument (to standard error) whenever that returned tensor is evaluated.
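A minimal sketch of the identity behaviour, assuming TensorFlow 2.x running TF1-style graph code through `tf.compat.v1` (on TensorFlow 1.x, `tf.Print` and `tf.Session` are the same functions and graph mode is already the default). The constant `x` is purely illustrative.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x; tf.compat.v1.Print is a graph-mode op, so eager is disabled.
tf.compat.v1.disable_eager_execution()

x = tf.constant([1.0, 2.0, 3.0])  # illustrative input

# Print is an identity op: x_printed carries the same value as x.
# Writing "debug: ..." to stderr is only a side effect of evaluating x_printed.
x_printed = tf.compat.v1.Print(x, [x], message='debug: ')

with tf.compat.v1.Session() as sess:
    original, printed = sess.run([x, x_printed])

print(np.array_equal(original, printed))  # True
```

The returned tensor is interchangeable with the input, which is why it can be dropped into the middle of a computation without changing the result.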

So in loss() you can use it like this:

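import tensorflow as tf

def loss(a, b):
    ...
    # tf.Print returns a; reassigning a puts the print op on the path to the loss
    a = tf.Print(a, [a], message='debug: ')
    return tf.add(a, b)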

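The value of a will be printed each time the tensor returned by tf.add(a, b) is evaluated, i.e. whenever the session in forecaster.py runs the loss or anything that depends on it, so no session is needed inside loss(). Make sure you use the tensor that tf.Print returns: if you discard it, the print op is not on the path to the loss and never runs.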
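An end-to-end sketch of the two-file setup collapsed into one script, assuming TensorFlow 2.x with the `tf.compat.v1` API (`tf.compat.v1.Print` is deprecated in 2.x, where `tf.print` replaces it). The `run_forecast` helper and the input constants are hypothetical stand-ins for whatever forecaster.py actually does.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x running TF1-style graph code via tf.compat.v1.
tf.compat.v1.disable_eager_execution()


# --- would live in neural_network.py: only builds ops, needs no session ---
def loss(a, b):
    # Reassign a so the Print op sits between the inputs and the returned loss.
    a = tf.compat.v1.Print(a, [a], message='debug: ')
    return tf.add(a, b)


# --- would live in forecaster.py (hypothetical helper): owns the session ---
def run_forecast():
    a = tf.constant([1.0, 2.0])  # hypothetical inputs
    b = tf.constant([3.0, 4.0])
    total = loss(a, b)  # graph construction only; nothing is printed yet
    with tf.compat.v1.Session() as sess:
        # Evaluating total also runs the Print op, so "debug: ..." goes to stderr here.
        return sess.run(total)


result = run_forecast()
print(result)  # [4. 6.]
```

The design point is that loss() never touches a session: it just returns a tensor, and whichever session later evaluates that tensor also triggers the print.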

Upvotes: 2
