This question is about the low-level TensorFlow 1.x API. Given a Tensor passed to Session.run(), I am unclear as to how TensorFlow traverses the computation graph.
Say I have some code like this:
import tensorflow as tf
a = tf.constant(1.0)
b = tf.subtract(a, 1.0)
c = tf.add(b, 2.0)
d = tf.multiply(c, 3)
sess = tf.Session()
sess.run(d)
The subtract, add, and multiply operations are not all stored in the Tensor d, right? I know that Tensor objects have graph and op fields; are these fields somehow accessed recursively to get all the operations required to compute d?
EDIT: adding the output of
print(tf.get_default_graph().as_graph_def())
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub"
  op: "Sub"
  input: "Const"
  input: "Sub/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Add/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "Sub"
  input: "Add/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Mul/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "Mul"
  op: "Mul"
  input: "Add"
  input: "Mul/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 38
}
That's the whole point of TensorFlow's static computational graph. Calls such as tf.subtract() do not compute anything; they add an operation to the default graph and record which tensors it takes as inputs. Then, when you pass a tensor to Session.run(), TensorFlow follows those recorded input edges backwards from it, so it knows the exact set of operations that lead to that node and runs only those. This has several benefits:
Use this command to see the inputs of each node:
print(tf.get_default_graph().as_graph_def())
For example, if you execute this on your small graph, you will see the following, starting from the node d = tf.multiply(c, 3):
name: "Mul"
op: "Mul"
input: "Add"
Then c = tf.add(b, 2.0):
name: "Add"
op: "Add"
input: "Sub"
Then b = tf.subtract(a, 1.0):
name: "Sub"
op: "Sub"
input: "Const"
And finally a = tf.constant(1.0):
name: "Const"
op: "Const"
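The manual walk above, from Mul back to Const, is a backward traversal over the input edges. Below is a minimal sketch of that idea in plain Python, run on the edges copied from the question's as_graph_def() output. It is a simplified model, not TensorFlow's own code: as far as I know the actual pruning for Session.run() happens in the C++ runtime, which also has to deal with feeds, control dependencies and device placement. GRAPH, node_name and prune are hypothetical names used only for illustration.

```python
from collections import deque

# Node name -> "input" fields, copied from the as_graph_def() output in the
# question. Attributes are left out: only the edges decide what has to run.
GRAPH = {
    "Const": [],
    "Sub/y": [],
    "Sub": ["Const", "Sub/y"],
    "Add/y": [],
    "Add": ["Sub", "Add/y"],
    "Mul/y": [],
    "Mul": ["Add", "Mul/y"],
}


def node_name(input_str):
    """Map a GraphDef input string to the producing node's name.

    "^foo" denotes a control input and "foo:1" the second output of foo;
    neither occurs in this small graph, but both appear in real GraphDefs.
    """
    return input_str.lstrip("^").split(":")[0]


def prune(graph, fetches):
    """Return every node reachable backwards from `fetches` (breadth-first)."""
    keep = set()
    queue = deque(node_name(f) for f in fetches)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        # Enqueue the producers of this node's inputs.
        queue.extend(node_name(i) for i in graph[name])
    return keep


print(sorted(prune(GRAPH, ["Mul"])))  # all seven nodes
print(sorted(prune(GRAPH, ["Sub"])))  # ['Const', 'Sub', 'Sub/y']
```

Fetching Sub keeps only Const, Sub/y and Sub, so the Add and Mul nodes would never be executed. In the Python API the same edges are exposed directly: d.op is the Mul operation, d.op.inputs are its input tensors, and each of those has its own .op, so the same loop can be written starting from d.op (adding op.control_inputs if control dependencies matter).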