yichudu

Reputation: 305

How to print a tensor in a TensorFlow custom Estimator for debugging?

With the low-level API, we can use

print(session.run(xx_tensor_after_xx_operation, feed_dict=feed_dict))

to get the actual values for debugging. But in a custom Estimator, how can these tensors be inspected?

Here is a snippet that illustrates the problem:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


def yichu_dssm_model_fn(
        features,  # This is batch_features from input_fn
        labels,  # This is batch_labels from input_fn
        mode,  # An instance of tf.estimator.ModeKeys
        params):
    # word_id sequence in content
    content_input = tf.feature_column.input_layer(features, params['feature_columns'])
    content_embedding_matrix = tf.get_variable(name='content_embedding_matrix',
                                               shape=[FLAGS.max_vocab_size, FLAGS.word_vec_dim])
    content_embedding = tf.nn.embedding_lookup(content_embedding_matrix, content_input)
    content_embedding = tf.reshape(content_embedding, shape=[-1, FLAGS.max_text_len, FLAGS.word_vec_dim, 1])
    content_conv = tf.layers.Conv2D(filters=100, kernel_size=[3, FLAGS.word_vec_dim])

    content_conv_tensor = content_conv(content_embedding)
    """
      With the low-level API, we can use `print(session.run(content_conv_tensor))` to get the actual values for debugging.
      But in a custom Estimator, how can these tensors be inspected?
    """

Upvotes: 1

Views: 4229

Answers (3)

Jason

Reputation: 2132

tf.Print is deprecated in favor of tf.print, but tf.print is not easy to use.

The best option is a logging hook:

hook = tf.train.LoggingTensorHook({"var is:": var_to_print},
                                  every_n_iter=10)
return tf.estimator.EstimatorSpec(mode, loss=loss,
                                  train_op=train_op,
                                  training_hooks=[hook])
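
For context, here is a minimal end-to-end sketch of the same idea, not taken from the question's model. It assumes TensorFlow 1.14+ (or a 2.x release that still ships tf.estimator) and uses the tensorflow.compat.v1 import so the 1.x names resolve. The toy model, the names w and hidden, and the random input data are all hypothetical stand-ins. Note that the hook writes through TensorFlow's logger at INFO level, so nothing appears unless the verbosity is raised.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with tf.estimator


def model_fn(features, labels, mode, params):
    # Hypothetical toy model: a single matmul stands in for the real network.
    w = tf.get_variable("w", shape=[4, params["hidden_units"]])
    hidden = tf.matmul(features["x"], w)
    prediction = tf.reduce_sum(hidden, axis=1, keepdims=True)
    loss = tf.losses.mean_squared_error(labels, prediction)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_or_create_global_step())

    # Log the intermediate tensor and the loss every 10 training steps.
    hook = tf.train.LoggingTensorHook(
        {"hidden": hidden, "loss": loss}, every_n_iter=10)
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op, training_hooks=[hook])


def input_fn():
    # Random data, purely for illustration.
    x = np.random.rand(32, 4).astype(np.float32)
    y = x.sum(axis=1, keepdims=True)
    dataset = tf.data.Dataset.from_tensor_slices(({"x": x}, y))
    return dataset.repeat().batch(8)


# The hook logs at INFO level; without this line its output may be hidden.
tf.logging.set_verbosity(tf.logging.INFO)

estimator = tf.estimator.Estimator(model_fn, params={"hidden_units": 3})
estimator.train(input_fn, steps=20)
```

The same pattern should carry over to the question's model_fn by passing content_conv_tensor in the dictionary given to the hook.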

Upvotes: 2

OGCheeze

Reputation: 74

sess = tf.InteractiveSession()
test = sess.run(features)
print('features:')
print(test)

This raises an error, but it still prints the tensor values. The error occurs right after the print, so this is only useful for checking tensor values.

Upvotes: 0

dm0_

Reputation: 2156

You can use tf.Print. It adds an operation to the graph that prints the tensors' contents to standard error when executed.

content_conv_tensor = tf.Print(content_conv_tensor, [content_conv_tensor], 'content_conv_tensor: ')
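
Another answer in this thread notes that tf.Print is deprecated in favor of tf.print. In graph mode, tf.print returns an operation rather than a tensor, so it has to be attached with a control dependency to run. The following is a small standalone sketch of that wiring, assuming TensorFlow 1.14+ or 2.x with the tensorflow.compat.v1 import; the tensor x and its values are made up for illustration.

```python
import sys

import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x

graph = tf.Graph()
with graph.as_default():
    # Hypothetical tensor standing in for content_conv_tensor.
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    doubled = x * 2.0

    # In graph mode tf.print returns an op; it runs only if something depends on it.
    print_op = tf.print("doubled:", doubled, output_stream=sys.stderr)
    with tf.control_dependencies([print_op]):
        # tf.identity creates a new tensor that forces print_op to run first.
        doubled = tf.identity(doubled)

with tf.Session(graph=graph) as sess:
    result = sess.run(doubled)
```

Inside a model_fn, the same control_dependencies and tf.identity wrapping would be applied to the tensor being debugged, and the rest of the model would then use the wrapped tensor.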

Upvotes: 3
