Vince Gatto

Reputation: 415

Computing an exact moving average over multiple batches in TensorFlow

During training, I would like to write the average loss over the last N mini-batches to a SummaryWriter as a way of smoothing out the very noisy per-batch loss. It's easy to compute this in Python and print it, but I would like to add it to a summary so that I can see it in TensorBoard. Here's an overly simplified example of what I'm doing now:

losses = []
for i in range(10000):
  _, loss = session.run([train_op, loss_op])
  losses.append(loss)
  if i % 100 == 0:
    # How to produce a scalar_summary here?
    print(sum(losses) / len(losses))
    losses = []
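The loop above averages over disjoint blocks of 100 batches and then resets. If the goal is literally the mean of the last N mini-batches at every step, a fixed-size window with a running sum gives that exactly without rescanning the list. A minimal, framework-independent sketch (the `WindowedMean` class is a hypothetical helper, not part of TensorFlow):

```python
from collections import deque


class WindowedMean:
    """Exact mean of the most recent `size` values (hypothetical helper)."""

    def __init__(self, size):
        if size <= 0:
            raise ValueError("size must be positive")
        self.size = size
        self.values = deque()
        self.total = 0.0

    def update(self, value):
        """Add one batch loss and return the mean over the current window."""
        self.values.append(value)
        self.total += value
        if len(self.values) > self.size:
            # Drop the oldest value so the window never exceeds `size`.
            self.total -= self.values.popleft()
        return self.total / len(self.values)


window = WindowedMean(3)
for loss in [4.0, 2.0, 6.0, 8.0]:
    avg = window.update(loss)
print(avg)  # mean of the last three losses: [2.0, 6.0, 8.0]
```

The running sum keeps each update O(1); for small N, recomputing `sum()` over a `deque(maxlen=N)` is simpler and avoids floating-point drift from repeated adds and subtracts. Either way, the resulting Python float still has to be handed to a summary, as in the answers here.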

I'm aware that I could use ExponentialMovingAverage, but that gives an exponentially weighted average rather than an exact mean, and I would still need some way to reset it every N batches. Honestly, if all I care about is visualizing the loss in TensorBoard, the reset probably isn't necessary, but I'm still curious how one would aggregate values across batches for other purposes (e.g. computing the total accuracy over a test dataset that is too big to run in a single batch).
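For the test-set accuracy case, the exact aggregate is total correct over total examples, not the mean of per-batch accuracies, which is biased whenever batches differ in size (e.g. a smaller final batch). A minimal plain-Python sketch, assuming each test batch yields a `(num_correct, batch_size)` pair from its `session.run` call (the `dataset_accuracy` helper and that input format are hypothetical):

```python
def dataset_accuracy(batch_results):
    """Exact accuracy over many batches.

    `batch_results` is an iterable of (num_correct, batch_size) pairs,
    e.g. collected from one session.run call per test batch (hypothetical
    input format).
    """
    total_correct = 0
    total_examples = 0
    for num_correct, batch_size in batch_results:
        total_correct += num_correct
        total_examples += batch_size
    if total_examples == 0:
        raise ValueError("no examples seen")
    return total_correct / total_examples


# Two full batches of 100 examples and a final partial batch of 10.
batches = [(90, 100), (80, 100), (0, 10)]
exact = dataset_accuracy(batches)
# Averaging per-batch accuracies gives the small last batch too much weight.
naive = sum(c / n for c, n in batches) / len(batches)
print(exact, naive)
```

TensorFlow's streaming metrics follow the same total/count idea: in TF 1.x, `tf.metrics.accuracy` returns an `(accuracy, update_op)` pair backed by local variables that can be reset with `tf.local_variables_initializer()`, and in TF 2.x `tf.keras.metrics.Accuracy` exposes `update_state()`, `result()` and `reset_state()` (`reset_states()` in older releases).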

Upvotes: 3

Views: 3534

Answers (2)

user728291

Reputation: 4138

You can pass a value computed in Python into a graph op such as tf.summary.scalar (tf.scalar_summary in older releases) by feeding it through a placeholder with feed_dict.

import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Scalar placeholder that carries a value computed in Python into the graph.
average_pl = tf.placeholder(tf.float32, shape=[])
average_summary = tf.summary.scalar("average_loss", average_pl)
writer = tf.summary.FileWriter("/tmp/mnist_logs", session.graph)

losses = []
for i in range(10000):
  _, loss = session.run([train_op, loss_op])
  losses.append(loss)
  if i % 100 == 0:
    # Feed the Python-side average into the summary op and write it out.
    feed = {average_pl: sum(losses) / len(losses)}
    summary_str = session.run(average_summary, feed_dict=feed)
    writer.add_summary(summary_str, i)
    losses = []

I haven't tested this, and it was hastily adapted from the "visualizing data" how-to, but I expect something along these lines to work.

Upvotes: 3

Greg McGlynn

Reputation: 71

You can construct the Summary protocol buffer manually and pass it straight to the writer, without adding any ops to the graph:

from tensorflow.core.framework import summary_pb2

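def make_summary(name, val):
    # Build a Summary protobuf holding a single scalar; no graph ops are needed.
    return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name,
                                                                simple_value=float(val))])

# summary_writer is an existing tf.summary.FileWriter; step is the training step.
summary_writer.add_summary(make_summary('myvalue', myvalue), step)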

Upvotes: 7
