André Claudino

Reputation: 588

Passing batches to TensorFlow Structural Time Series

I am building a time series prediction model with TensorFlow Probability, following this tutorial. The examples there pass all the data at once, which is prohibitive when dealing with big data (my case). How should I pass batches, or any other kind of lazily loaded data, to this tool?

Upvotes: 0

Views: 58

Answers (1)

Brian Patton

Reputation: 1076

This is a general problem in probabilistic inference: gradients computed from anything less than the full batch will, in most cases, yield biased samples.

You should be able to write a target_log_prob_fn with a tf.custom_gradient that iterates over a tf.data.Dataset. Since the target log-prob is a scalar, you can accumulate both the log-probs and the gradients as the function proceeds over all minibatches in the dataset.

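import tensorflow as tf
import tensorflow_probability as tfp

ds = build_dataset()  # placeholder: a tf.data.Dataset that yields minibatches
def build_model(params):
  return time_series_model(..)

# A tf.function cannot return a Python callable, so the accumulation loop
# lives in its own traced function and the gradient closure is built outside.
@tf.function  # autograph should turn the dataset loop into a tf.while_loop.
def _accumulate(*params):
  total_lp = 0.
  total_grad = tf.nest.map_structure(tf.zeros_like, params)
  for batch in ds:
    # Log-prob of this minibatch and its gradient with respect to each param.
    lp, grad = tfp.math.value_and_gradient(
        lambda *p: build_model(p).log_prob(batch),
        params)
    total_lp += lp
    total_grad = tf.nest.map_structure(lambda x, y: x + y, total_grad, grad)
  return total_lp, total_grad

@tf.custom_gradient
def log_prob(*params):
  total_lp, total_grad = _accumulate(*params)
  # Chain rule: scale the accumulated gradients by the upstream gradient dy.
  return total_lp, lambda dy: tf.nest.map_structure(lambda g: dy * g, total_grad)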
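
The accumulation idea can be sanity-checked without TensorFlow. The sketch below is only an illustration, not a structural time series model: it assumes the log-density is a plain sum of per-batch terms, uses an i.i.d. Normal(mu, sigma) likelihood as a stand-in, and all function names (normal_log_prob_and_grad, iter_batches, accumulated_log_prob_and_grad) are hypothetical. It streams over lazily generated batches, keeps only running totals, and compares the result with a single full-batch evaluation.

```python
import math
import random


def normal_log_prob_and_grad(batch, mu, sigma):
    """Summed Normal(mu, sigma) log-density of a batch and its gradient
    with respect to (mu, sigma)."""
    lp = d_mu = d_sigma = 0.0
    for x in batch:
        z = (x - mu) / sigma
        lp += -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
        d_mu += z / sigma                  # d/d(mu) of the log-density
        d_sigma += (z * z - 1.0) / sigma   # d/d(sigma) of the log-density
    return lp, (d_mu, d_sigma)


def iter_batches(data, batch_size):
    """Yield consecutive slices lazily (stand-in for a tf.data.Dataset)."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def accumulated_log_prob_and_grad(batches, mu, sigma):
    """Stream over batches, keeping only running totals in memory."""
    total_lp = 0.0
    total_grad = (0.0, 0.0)
    for batch in batches:
        lp, grad = normal_log_prob_and_grad(batch, mu, sigma)
        total_lp += lp
        total_grad = tuple(t + g for t, g in zip(total_grad, grad))
    return total_lp, total_grad


rng = random.Random(0)
data = [rng.gauss(1.0, 2.0) for _ in range(1000)]

# Full-batch evaluation versus accumulation over minibatches of 64 points.
full_lp, full_grad = normal_log_prob_and_grad(data, 0.5, 1.5)
acc_lp, acc_grad = accumulated_log_prob_and_grad(iter_batches(data, 64), 0.5, 1.5)
```

Up to floating-point rounding, acc_lp and acc_grad match full_lp and full_grad, so the sampler sees the full-data log-prob and gradient. The cost is one complete pass over the dataset per evaluation; only the memory footprint shrinks.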

Upvotes: 1
