I am defining several categorical samples as follows:
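```python
import torch
import pyro
import pyro.distributions as dist

K, A = 3, 2  # illustrative sizes; the real values come from the model

probs = torch.ones(K) / K  # uniform prior over the K classes
object_classes = []
for a in pyro.plate('objects', A):  # sequential plate: a is a Python int
    object_class = pyro.sample('object_classes_{}'.format(a),
                               dist.Categorical(probs))
    object_classes.append(object_class)
```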
Then, what I want to do, but just can't figure out, is to use those samples as integers. When running inference, each `object_class` is at first an integer, but after some iterations it becomes a tensor containing all the possible values of the categorical variable, so I can no longer use it as an integer. I tried the following, but it doesn't seem to work:
```python
from pyro.ops.indexing import Vindex

torch.tensor([Vindex(cmm)[object_classes[a]] for a in range(A)])
```
But I get the error "only one element tensors can be converted to Python scalars".
Is it really possible to use categorical samples as integers, or do I need to work with an additional dimension throughout the model?
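One way to avoid converting samples to Python ints is to index with the sample tensor directly. This is a sketch under the assumption that `cmm` is a (K, D) table indexed by class; its real shape is not shown in the question, so the random `cmm` below is a hypothetical stand-in. PyTorch advanced indexing accepts integer tensors of any shape, so the same line handles a 0-dim sample and an enumerated (K, 1) value, and the enumeration dimension carries through to the result. Pyro's `Vindex` helper serves a similar purpose when the indexed tensor has batch dimensions of its own.

```python
import torch

K, D = 3, 4  # categories and feature size (illustrative)
cmm = torch.randn(K, D)  # hypothetical per-class table standing in for `cmm`


def class_row(object_class):
    # Advanced indexing accepts integer tensors of any shape, so this works
    # both for a single sampled class and for an enumerated one.
    return cmm[object_class]


sampled = torch.tensor(2)                   # 0-dim, as without enumeration
enumerated = torch.arange(K).reshape(K, 1)  # shape (K, 1), as under enumeration

print(class_row(sampled).shape)     # torch.Size([4])
print(class_row(enumerated).shape)  # torch.Size([3, 1, 4])
```

Any later computation then has to broadcast over that leading dimension rather than assume a scalar, which amounts to carrying the additional dimension through the model.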