How can I visualize the graph in TensorBoard using TensorFlow's Estimator API without running training or evaluation?
I know how to do this with the Session API when I have access to the Graph object, but I could not find anything for the Estimator API.
Estimators create and manage tf.Graph and tf.Session objects for you, so these objects are not easily accessible. Note that, by default, the graph is written to the events file when you call estimator.train.
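If you have already run estimator.train and want to confirm that the graph landed in the events file, you can scan the events files in the estimator's model directory. This is a minimal sketch, assuming TensorFlow 1.13+ or 2.x so that tf.compat.v1.train.summary_iterator is available; events_files_with_graph is a hypothetical helper name, not a TensorFlow function.

```python
import glob
import os

import tensorflow as tf


def events_files_with_graph(logdir):
    """Hypothetical helper: return the events files under `logdir`
    (searched recursively) that contain a serialized GraphDef."""
    pattern = os.path.join(logdir, '**', 'events.out.tfevents.*')
    found = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        for event in tf.compat.v1.train.summary_iterator(path):
            # Event.graph_def holds the serialized graph; it is empty bytes otherwise.
            if event.graph_def:
                found.append(path)
                break
    return found
```

For example, after estimator.train(...) you could call events_files_with_graph(estimator.model_dir); a non-empty list means TensorBoard's Graphs tab has something to show for that directory.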
What you can do, however, is call your model_fn outside of tf.estimator and then use the classic tf.summary.FileWriter() to export the graph.
Here is a code snippet with a very simple estimator that just applies a dense layer to the input:
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)

# Basic input_fn
def input_fn(x, y, batch_size=4):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size).repeat(1)
    return dataset

# Basic model_fn that just applies a dense layer to the input
def model_fn(features, labels, mode):
    global_step = tf.train.get_or_create_global_step()
    y = tf.layers.dense(features, 1)
    increment_global_step = tf.assign_add(global_step, 1)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={'preds': y},
        loss=tf.constant(0.0, tf.float32),
        train_op=increment_global_step)

# Fake data (cast to float32 so it matches the placeholders below)
x = np.random.normal(size=[10, 100]).astype(np.float32)
y = np.random.normal(size=[10]).astype(np.float32)

# Just to show that the estimator works
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=lambda: input_fn(x, y), steps=1)

# Classic way of exporting the graph using placeholders and an outside call to the model_fn
with tf.Graph().as_default() as g:
    # Placeholders
    features = tf.placeholder(tf.float32, x.shape)
    labels = tf.placeholder(tf.float32, y.shape)
    # Creates the graph (pass a real mode instead of None)
    _ = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    # Export the graph to ./graph
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter('./graph', sess.graph)
        train_writer.close()  # flush the events file to disk
```
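tf.summary.FileWriter also accepts a tf.Graph object directly, so the Session in the snippet above is not strictly needed to export the graph. Below is a minimal sketch of the same idea written against tf.compat.v1, assuming a TensorFlow version that still ships tf.compat.v1. The names export_graph and toy_model_fn are hypothetical, and the toy model builds its dense layer by hand (tf.compat.v1.get_variable plus tf.matmul) instead of calling tf.layers.dense.

```python
import tensorflow as tf


def toy_model_fn(features, labels, mode):
    """Hypothetical stand-in for a model_fn: one hand-built dense layer."""
    del labels, mode  # unused in this sketch
    w = tf.compat.v1.get_variable('w', shape=[int(features.shape[1]), 1])
    b = tf.compat.v1.get_variable(
        'b', shape=[1], initializer=tf.compat.v1.zeros_initializer())
    return tf.matmul(features, w) + b


def export_graph(model_fn, feature_shape, label_shape, logdir, mode=None):
    """Hypothetical helper: build model_fn's graph on placeholders and
    write it to `logdir` without running any op."""
    with tf.Graph().as_default() as g:
        features = tf.compat.v1.placeholder(tf.float32, feature_shape, name='features')
        labels = tf.compat.v1.placeholder(tf.float32, label_shape, name='labels')
        model_fn(features, labels, mode)
        # No Session needed: the writer serializes the Graph object itself.
        writer = tf.compat.v1.summary.FileWriter(logdir, graph=g)
        writer.close()  # flush the graph event to disk
    return g
```

With the answer's model_fn, a call such as export_graph(model_fn, x.shape, y.shape, './graph', tf.estimator.ModeKeys.TRAIN) should produce the same events file, which you can then open with tensorboard --logdir=./graph.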
To be able to visualize the graph in TensorBoard, the graph must be in the events file. If you instantiate the writer with the session graph during training:

```python
train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
```

the graph will be written to the events file.
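A self-contained version of that Session-API pattern might look like the sketch below. It assumes tf.compat.v1 graph mode; summaries_dir stands in for FLAGS.summaries_dir, and the toy loss, optimizer, and train_with_graph_summary function are illustrative only.

```python
import os

import tensorflow as tf


def train_with_graph_summary(summaries_dir, steps=3):
    """Hypothetical example: run a toy graph-mode training loop and write
    the session graph to summaries_dir/train."""
    with tf.Graph().as_default():
        x = tf.compat.v1.placeholder(tf.float32, [None, 1], name='x')
        w = tf.compat.v1.get_variable('w', shape=[1, 1])
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0), name='loss')
        train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
        with tf.compat.v1.Session() as sess:
            # Passing sess.graph makes the writer record the GraphDef in the events file.
            train_writer = tf.compat.v1.summary.FileWriter(
                os.path.join(summaries_dir, 'train'), sess.graph)
            sess.run(tf.compat.v1.global_variables_initializer())
            for _ in range(steps):
                sess.run(train_op, feed_dict={x: [[1.0], [2.0]]})
            train_writer.close()  # flush before the session goes away
    return os.path.join(summaries_dir, 'train')
```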
Given that, just run tensorboard and point it at the directory where the events file is stored:

```shell
tensorboard --logdir=path/to/log-directory
```

and open the Graphs tab.