shoubhik

Reputation: 129

How to visualize TensorFlow graph without running train/evaluate with estimator API?

How can I visualize the graph on TensorBoard using the Estimator API of TensorFlow without running training or evaluation?

I know how to do this with the Session API when I have access to the Graph object, but I could not find anything equivalent for the Estimator API.

Upvotes: 2

Views: 1242

Answers (2)

Olivier Dehaene

Reputation: 1680

Estimators create and manage their tf.Graph and tf.Session objects internally, so these objects are not easily accessible. Note that, by default, the graph is written to the events file in the estimator's model_dir when you call estimator.train.

What you can do, however, is call your model_fn outside of tf.estimator and then use the classic tf.summary.FileWriter() to export the graph.

Here is a code snippet with a very simple estimator that just applies a dense layer to the input:

import tensorflow as tf
import numpy as np

# Basic input_fn
def input_fn(x, y, batch_size=4):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size).repeat(1)
    return dataset

# Basic model_fn that just applies a dense layer to its input
def model_fn(features, labels, mode):
    global_step = tf.train.get_or_create_global_step()

    y = tf.layers.dense(features, 1)

    increment_global_step = tf.assign_add(global_step, 1)

    return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={'preds':y},
            loss=tf.constant(0.0, tf.float32),
            train_op=increment_global_step)

# Fake data
x = np.random.normal(size=[10, 100])
y = np.random.normal(size=[10])

# Just to show that the estimator works
estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=lambda: input_fn(x, y), steps=1)


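# Classic way of exporting the graph using placeholders and an outside call to the model_fn
with tf.Graph().as_default() as g:
    # Placeholders standing in for the tensors input_fn would provide
    features = tf.placeholder(tf.float32, x.shape)
    labels = tf.placeholder(tf.float32, y.shape)

    # Builds the graph only; nothing is trained or evaluated
    _ = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)

    # Export the graph to ./graph (no session is needed just to write a graph)
    train_writer = tf.summary.FileWriter('./graph', graph=g)
    # close() flushes the events file so TensorBoard can actually read the graph
    train_writer.close()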
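If the Graph tab stays empty, one way to check whether the graph actually reached the events file, without starting TensorBoard, is to read the file back. The following is a minimal sketch that assumes TensorFlow 1.13+ or 2.x (it uses the tf.compat.v1 names so it also runs under 2.x). The tiny matmul graph is a hypothetical stand-in for model_fn, and export_graph / events_with_graph are illustrative helper names, not TensorFlow API.

```python
import glob
import os
import tempfile

import tensorflow as tf


def export_graph(logdir):
    """Hypothetical helper: builds a tiny stand-in graph and writes it to logdir."""
    with tf.Graph().as_default() as g:
        # Stand-in for model_fn (assumption): one matmul against a trainable weight
        features = tf.compat.v1.placeholder(tf.float32, [None, 100], name="features")
        w = tf.compat.v1.get_variable("w", shape=[100, 1])
        tf.matmul(features, w, name="preds")
        # Created inside the graph context so it also works when TF 2.x runs eagerly
        writer = tf.compat.v1.summary.FileWriter(logdir, graph=g)
        writer.close()  # flushes the pending graph event to disk


def events_with_graph(logdir):
    """Hypothetical helper: lists events files in logdir that hold a GraphDef."""
    found = []
    for path in glob.glob(os.path.join(logdir, "events.out.tfevents.*")):
        for event in tf.compat.v1.train.summary_iterator(path):
            # Each Event stores exactly one payload in its "what" oneof
            if event.WhichOneof("what") == "graph_def":
                found.append(path)
                break
    return found


logdir = tempfile.mkdtemp()
export_graph(logdir)
print(events_with_graph(logdir))
```

An empty list here usually means the writer was never closed or flushed, or the graph was never passed to it.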

Upvotes: 1

Dmytro Prylipko

Reputation: 5064

To be able to visualize the graph in TensorBoard, it must be in the events file. If you create the writer with the session graph during training:

train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)

then the graph is written there.

Then launch TensorBoard and point it at the directory where the events file is stored:

tensorboard --logdir=path/to/log-directory

and open the Graph tab.
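For a plain session-based training loop, the writer setup above might look like the following minimal sketch. It assumes TensorFlow 1.13+ or 2.x (again via tf.compat.v1); the toy linear-regression model, the random data, and the temporary directory used in place of FLAGS.summaries_dir are all illustrative assumptions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: a temporary directory stands in for FLAGS.summaries_dir
summaries_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    # Toy linear regression model (illustrative only)
    x = tf.compat.v1.placeholder(tf.float32, [None, 3], name="x")
    y = tf.compat.v1.placeholder(tf.float32, [None, 1], name="y")
    w = tf.compat.v1.get_variable("w", shape=[3, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
    loss_summary = tf.compat.v1.summary.scalar("loss", loss)

    with tf.compat.v1.Session() as sess:
        # Passing sess.graph is what puts the graph into the events file
        train_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(summaries_dir, "train"), sess.graph)
        sess.run(tf.compat.v1.global_variables_initializer())
        data_x = np.random.normal(size=[8, 3]).astype(np.float32)
        data_y = np.random.normal(size=[8, 1]).astype(np.float32)
        for step in range(5):
            summary, _ = sess.run([loss_summary, train_op],
                                  feed_dict={x: data_x, y: data_y})
            train_writer.add_summary(summary, step)
        train_writer.close()  # flush the graph and summaries to disk

print(os.listdir(os.path.join(summaries_dir, "train")))
```

Pointing tensorboard --logdir at summaries_dir should then show the graph under the Graph tab and the loss curve under Scalars.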

Upvotes: 1
