Reputation: 588
I am creating a model for time series prediction with TensorFlow Probability, following this tutorial. In these examples I have to pass all the data at once, which is prohibitive when dealing with big data (my case). How should I pass batches, or any other kind of lazily loaded data, to this tool?
Upvotes: 0
Views: 58
Reputation: 1076
This is a general problem for most probabilistic inference methods: gradients computed on anything less than the full batch will yield biased samples.
You should be able to write a target_log_prob_fn with a tf.custom_gradient that iterates over a tf.data.Dataset iterator. Since the target log-prob is a scalar, you can accumulate both the gradients and the log-probs as the function proceeds over all minibatches in the dataset.
import tensorflow as tf
import tensorflow_probability as tfp

ds = build_dataset()  # a tf.data.Dataset of minibatches

def build_model(params):
  return time_series_model(...)  # placeholder: construct the model from params

@tf.function  # autograph should turn the dataset loop into a tf.while_loop.
@tf.custom_gradient
def log_prob(*params):
  total_lp = tf.constant(0.)
  total_grad = tf.nest.map_structure(tf.zeros_like, params)
  for batch in ds:
    # Per-batch log-prob and its gradient with respect to the parameters.
    lp, grad = tfp.math.value_and_gradient(
        lambda *p: build_model(p).log_prob(batch),
        params)
    total_lp += lp
    total_grad = tf.nest.map_structure(lambda x, y: x + y, total_grad, grad)
  # Return the accumulated log-prob plus a gradient function that scales the
  # accumulated per-parameter gradients by the incoming cotangent dy.
  return total_lp, lambda dy: tf.nest.map_structure(lambda g: dy * g, total_grad)
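To make the idea concrete, here is a minimal usage sketch (not part of the original answer): build_dataset, observed_series, the parameter shapes in initial_params, and the sampler settings are all assumptions for illustration. The custom gradient defined above is what lets an MCMC kernel such as HamiltonianMonteCarlo differentiate log_prob while each evaluation streams over the dataset.

# Hypothetical sketch: stream minibatches lazily and hand log_prob to an
# MCMC sampler as its target_log_prob_fn. Names and shapes are assumptions.
def build_dataset():
  # In practice, read lazily from disk (e.g. TFRecord files) rather than
  # materializing the full series in memory; this is just for illustration.
  return tf.data.Dataset.from_tensor_slices(observed_series).batch(256)

initial_params = [tf.zeros([]), tf.zeros([])]  # assumed: two scalar parameters

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=log_prob,  # gradients come from the custom gradient
    step_size=0.01,
    num_leapfrog_steps=3)

samples = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=initial_params,
    kernel=kernel,
    trace_fn=None)  # with trace_fn=None, only the chain states are returned

Because the log-prob and its gradient are accumulated one minibatch at a time, peak memory scales with the batch size rather than with the full dataset.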
Upvotes: 1