Reputation: 3571
When running y = multivariate_normal(np.zeros(d), np.eye(d)).rvs() we obtain a sample of shape (d,). However, when d=1 we obtain a scalar, which makes sense since it is 1-dimensional. Unfortunately, I have a piece of code that must work for any number of dimensions, including d=1, and that basically takes the dot product of a d-dimensional vector x with y. This breaks for d=1. How can I fix it?
import numpy as np
from scipy.stats import multivariate_normal as MVN

def mwe_function(d, x):
    """Minimal Working Example"""
    y = MVN(np.zeros(d), np.eye(d)).rvs()
    return x @ y

mwe_function(2, np.ones(2))  # This works
mwe_function(1, np.ones(1))  # This doesn't
IMPORTANT: I want to avoid if statements, as they would slow down the code. Otherwise one could simply use scipy.stats.norm for the d=1 case.
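To make the mismatch concrete, the shapes can be inspected directly. This is only a small sketch: the helper name sample_shape is hypothetical and not part of the original code.

```python
import numpy as np
from scipy.stats import multivariate_normal as MVN

def sample_shape(d):
    """Return the shape of a single draw from N(0, I_d)."""
    y = MVN(np.zeros(d), np.eye(d)).rvs()
    # np.shape also accepts scalars, for which it returns ()
    return np.shape(y)

print(sample_shape(2))  # (2,)
print(sample_shape(1))  # (): a scalar, so x @ y fails for d=1
```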
Upvotes: 1
Views: 206
Reputation: 1068
You can use np.reshape to fix the shape of your sample. By passing -1 as the length of the first (and only) dimension, you always get a 1-dimensional array rather than a scalar.
import numpy as np
from scipy.stats import multivariate_normal as MVN

def mwe_function(d, x):
    """Minimal Working Example"""
    # reshape([-1]) flattens the sample, turning a scalar into shape (1,)
    y = MVN(np.zeros(d), np.eye(d)).rvs().reshape([-1])
    return x @ y

v0 = mwe_function(2, np.ones(2))  # This works
print(v0)  # e.g. -0.5718013906409207 (random, varies per run)
v1 = mwe_function(1, np.ones(1))  # This works as well :-)
print(v1)  # e.g. -0.20196038784485093 (random, varies per run)
Here, .reshape([-1]) does the job.
Personally, I prefer reshaping over using np.atleast_1d, since the effect is directly visible, but in the end it is a matter of taste.
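For comparison, the np.atleast_1d variant might look like the sketch below. The helper names sample_1d and dot_with_sample are hypothetical, and the seed argument is an assumption added only so the draw can be made reproducible.

```python
import numpy as np
from scipy.stats import multivariate_normal as MVN

def sample_1d(d, seed=None):
    """Draw one sample from N(0, I_d), always returned with shape (d,)."""
    y = MVN(np.zeros(d), np.eye(d)).rvs(random_state=seed)
    # atleast_1d promotes a scalar to shape (1,) and leaves arrays unchanged
    return np.atleast_1d(y)

def dot_with_sample(d, x, seed=None):
    """Dot product of x with a fresh d-dimensional sample (no if statement)."""
    return x @ sample_1d(d, seed)

for d in (1, 2, 3):
    print(d, sample_1d(d, seed=0).shape)
```

Since np.atleast_1d only changes inputs with zero dimensions, it affects the d=1 case alone, just like .reshape([-1]) does here.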
Upvotes: 1