Reputation: 447
I am trying to read a big TensorFlow project. In a project where the nodes of the computation graph are scattered across the code base, is there a way to store a Tensor node of the computation graph and later add that node to the fetch list in sess.run?
For example, if I want to add probs at line 615 of https://github.com/allenai/document-qa/blob/master/docqa/nn/span_prediction.py to a global namespace, is there a method like tf.add_node(probs, "probs") so that I could later call tf.get_node("probs"), just to make it convenient to pass nodes around the project?
A more general question: what is a better way to structure TensorFlow code and make experimenting with different models more efficient?
Upvotes: 0
Views: 308
Reputation: 108
Yes, you can. To retrieve the tensor later, give it a name when you create it, and then look it up by that name. Take probs in your code as an example. It is created with the tf.nn.softmax() function, whose signature is shown below.
tf.nn.softmax(
logits,
axis=None,
name=None,
dim=None
)
Note the name parameter. You can pass it on line 615 like this:
probs = tf.nn.softmax(all_logits, name='my_tensor')
Later, when you need it, you can call tf.Graph.get_tensor_by_name(name) to retrieve this tensor.
graph = tf.get_default_graph()
retrieved_probs = graph.get_tensor_by_name('my_tensor:0')
Here 'my_tensor' is the name of the softmax operation. The ':0' suffix selects that operation's first output, so you get the tensor rather than the operation. When calling Graph.get_operation_by_name(), do not add ':0'.
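The difference between the operation name and the tensor name can be checked directly. This is a minimal sketch, not the project's code: it assumes TensorFlow is installed and uses the tensorflow.compat.v1 import so the graph-mode API also works under TensorFlow 2. The all_logits placeholder and its shape are stand-ins for the real logits.

```python
import tensorflow.compat.v1 as tf

# Build a tiny graph; the placeholder stands in for the project's all_logits.
graph = tf.Graph()
with graph.as_default():
    all_logits = tf.placeholder(tf.float32, shape=[None, 4], name="all_logits")
    probs = tf.nn.softmax(all_logits, name="my_tensor")

# The operation is looked up without a suffix, the tensor with ':0'.
op = graph.get_operation_by_name("my_tensor")
tensor = graph.get_tensor_by_name("my_tensor:0")

op_name = op.name                        # name of the softmax operation
tensor_name = tensor.name                # name of its first output tensor
first_output_name = op.outputs[0].name   # same tensor, reached through the op

print(op_name, tensor_name, first_output_name)
```

The tensor name is just the operation name plus the index of the output, so op.outputs[0] and get_tensor_by_name('my_tensor:0') refer to the same output.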
You'll have to make sure that the tensor exists at that point (it may have been created by code executed before this line, or restored from a meta graph file). If it was created inside a variable scope, you'll also have to put the scope name and a '/' in front of the name, for example 'my_scope/my_tensor:0'.
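Putting the pieces together, here is a rough sketch of naming a tensor inside a variable scope, retrieving it elsewhere by its full name, and adding it to the fetch list of sess.run. It assumes TensorFlow and NumPy are installed and again uses tensorflow.compat.v1. The scope name, placeholder, and input values are made up for illustration and are not taken from the document-qa project.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Model-building code: the tensor is named inside a variable scope.
graph = tf.Graph()
with graph.as_default():
    all_logits = tf.placeholder(tf.float32, shape=[None, 3], name="all_logits")
    with tf.variable_scope("my_scope"):
        probs = tf.nn.softmax(all_logits, name="my_tensor")

# Somewhere else in the project: only the graph and the names are needed.
retrieved_probs = graph.get_tensor_by_name("my_scope/my_tensor:0")
logits_in = graph.get_tensor_by_name("all_logits:0")

with tf.Session(graph=graph) as sess:
    # The retrieved tensor goes into the fetch list like any other tensor.
    fetched = sess.run(
        [retrieved_probs],
        feed_dict={logits_in: np.zeros((2, 3), dtype=np.float32)},
    )
    result = fetched[0]

print(result.shape)
```

Because the lookup goes through the graph, the code that fetches the tensor does not need a Python reference to probs, only the agreed name.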
Upvotes: 1