As the title says, I am trying to create a mixture of multivariate normal distributions using the TensorFlow Probability package.
In my original project, I am feeding the categorical weights, the locations, and the variance from the output of a neural network. However, when building the graph, I get the following error:
components[0] batch shape must be compatible with cat shape and other component batch shapes
I recreated the same problem using placeholders:
import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions
tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()
l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')
log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                    initializer=tf.constant_initializer(1.0),
                                    trainable=True)
mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,1], name='weights')
cat = tfp.distributions.Categorical(probs=[mix, 1.-mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]
# This constructor call is where the batch-shape error quoted above is raised.
bimix_gauss = tfp.distributions.Mixture(
    cat=cat,
    components=components)
So my question is: what am I doing wrong? Looking into the error, it seems that tensorshape_util.is_compatible_with is what raises it, but I don't see why.
Thanks!
Upvotes: 1
Reputation: 1076
When all the components are of the same type, MixtureSameFamily should be more performant.
With it, you pass only a single Categorical instance (with .batch_shape [b1,b2,...,bn]) and a single MultivariateNormalDiag instance (with .batch_shape [b1,b2,...,bn,numcats]).
For only two classes, I wonder if Bernoulli would work?
Upvotes: 1
Reputation: 61
It seems you provided a mis-shaped input to tfp.distributions.Categorical. Its probs parameter should have shape [batch_size, cat_size], while the one you provide is rather [cat_size, batch_size, 1]. So maybe try to parametrize probs with tf.concat([mix, 1-mix], 1).
There may also be a problem with your log_std, which doesn't have the same shape as l1 and l2. In case MultivariateNormalDiag doesn't properly broadcast it, try to specify its shape as (None, 2) or to tile it so that its first dimension corresponds to that of your location parameters.
Upvotes: 0