Reputation: 4318
I have the following architecture for my neural network:
import math

import torch
import torch.distributions as pyd
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution
LOG_STD_MIN = -5
LOG_STD_MAX = 0
class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # Clamp to avoid atanh(+/-1) = +/-inf at the boundaries.
        return self.atanh(y.clamp(-0.99, 0.99))

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log(1 - tanh(x)^2).
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))
def get_spec_means_mags(spec):
    # Center and half-range of the action spec's [minimum, maximum] box.
    means = (spec.maximum + spec.minimum) / 2.0
    mags = (spec.maximum - spec.minimum) / 2.0
    means = torch.tensor(means, dtype=torch.float32, requires_grad=False)
    mags = torch.tensor(mags, dtype=torch.float32, requires_grad=False)
    return means, mags
class Split(torch.nn.Module):
    """Wraps a module and splits its output into n_parts equal chunks."""

    def __init__(self, module, n_parts: int, dim=1):
        super().__init__()
        self._n_parts = n_parts
        self._dim = dim
        self._module = module

    def forward(self, inputs):
        output = self._module(inputs)
        if output.ndim == 1:
            result = torch.hsplit(output, self._n_parts)
        else:
            chunk_size = output.shape[self._dim] // self._n_parts
            result = torch.split(output, chunk_size, dim=self._dim)
        return result
class Network(nn.Module):
    def __init__(
        self,
        state,
        act,
        fc_layer_params=(),
    ):
        super().__init__()
        self._act = act
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Build the MLP trunk; track the previous layer's width so stacked
        # hidden layers of different sizes connect correctly.
        self._layers = nn.ModuleList()
        input_size = state.shape[0]
        for hidden_size in fc_layer_params:
            self._layers.append(nn.Linear(input_size, hidden_size))
            self._layers.append(nn.ReLU())
            input_size = hidden_size
        # Twice the action dimension: one half for the mean, one for log_std.
        output_layer = nn.Linear(input_size, self._act.shape[0] * 2)
        self._layers.append(output_layer)
        self._mean_logvar_layers = Split(output_layer, n_parts=2)
        self._act_means, self._act_mags = get_spec_means_mags(self._act)

    def _get_outputs(self, state):
        h = state
        for layer in self._layers[:-1]:
            h = layer(h)
        mean, log_std = self._mean_logvar_layers(h)
        a_tanh_mode = torch.tanh(mean) * self._act_mags + self._act_means
        # Squash log_std into [LOG_STD_MIN, LOG_STD_MAX].
        log_std = torch.tanh(log_std).to(device=self.device)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
        std = torch.exp(log_std)
        a_distribution = TransformedDistribution(
            base_distribution=Normal(loc=torch.zeros_like(mean).to(device=self.device),
                                     scale=torch.ones_like(mean).to(device=self.device)),
            transforms=tT.ComposeTransform([
                tT.AffineTransform(loc=self._act_means, scale=self._act_mags, event_dim=mean.shape[-1]),
                TanhTransform(),
                tT.AffineTransform(loc=mean, scale=std, event_dim=mean.shape[-1])]))
        return a_distribution, a_tanh_mode

    def get_log_density(self, state, action):
        a_dist, _ = self._get_outputs(state)
        log_density = a_dist.log_prob(action)
        return log_density

    def __call__(self, state):
        a_dist, a_tanh_mode = self._get_outputs(state)
        a_sample = a_dist.sample()
        log_pi_a = a_dist.log_prob(a_sample)
        return a_tanh_mode, a_sample, log_pi_a
When I run my code, I get this error message:
action = self._a_network(latent_states)[1]
File "/home/planner_regularizer.py", line 182, in __call__
a_dist, a_tanh_mode = self._get_outputs(state.to(device=self.device))
File "/home/planner_regularizer.py", line 159, in _get_outputs
a_distribution = TransformedDistribution(
File "/home/dm_control/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py", line 61, in __init__
raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
ValueError: base_distribution needs to have shape with size at least 6, but got torch.Size([6]).
How can I fix this error message?
Update: if I remove event_dim from AffineTransform, I no longer get the above error, but the output of log_prob has size 1, which is not correct. Any suggestions?
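For reference, a standalone sketch (a plain Normal and a single AffineTransform with arbitrary values, not my actual network) of how event_dim changes the shape that log_prob returns:

import torch
from torch.distributions import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution

base = Normal(loc=torch.zeros(6), scale=torch.ones(6))

# event_dim=0 (the default): all dimensions are batch dimensions,
# so log_prob returns one value per coordinate.
d0 = TransformedDistribution(base, [tT.AffineTransform(loc=0.0, scale=2.0)])
print(d0.log_prob(torch.zeros(6)).shape)  # torch.Size([6])

# event_dim=1: the last dimension is treated as a single event,
# so log_prob sums over it.
d1 = TransformedDistribution(base, [tT.AffineTransform(loc=0.0, scale=2.0, event_dim=1)])
print(d1.log_prob(torch.zeros(6)).shape)  # torch.Size([])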
Upvotes: 2
Views: 267
Reputation: 101
The error is telling you exactly what the problem is: TransformedDistribution requires the base distribution's shape to have at least as many dimensions as the largest event_dim of its transforms. You pass event_dim=mean.shape[-1] (6 here) to AffineTransform, but your Normal has shape torch.Size([6]), which is only one dimension. The point to notice is that event_dim is the number of trailing dimensions that make up a single event, not the size of those dimensions, so an affine transform over a vector-valued event needs just 2 dimensions:

- 1 for the batch_shape
- 1 for the event coordinates being transformed (event_dim=1)

Simply construct your Normal distribution with explicit batch and event dimensions, e.g. Normal(loc=torch.zeros(1, 6), scale=torch.ones(1, 6)), and pass event_dim=1 to the AffineTransforms.
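A minimal sketch of that fix (assuming an action dimension of 6; the means, magnitudes, mean, and std below are placeholder values standing in for what your network produces, and torch's built-in TanhTransform stands in for your custom one):

import torch
from torch.distributions import Normal
from torch.distributions import transforms as tT
from torch.distributions.transformed_distribution import TransformedDistribution

act_dim = 6                            # assumed action dimension
mean = torch.zeros(3, act_dim)         # placeholder network output, batch of 3
std = torch.ones(3, act_dim)
act_means = torch.zeros(act_dim)       # placeholder action-spec center
act_mags = torch.ones(act_dim)         # placeholder action-spec half-range

dist = TransformedDistribution(
    base_distribution=Normal(loc=torch.zeros(3, act_dim),
                             scale=torch.ones(3, act_dim)),
    transforms=tT.ComposeTransform([
        tT.AffineTransform(loc=act_means, scale=act_mags, event_dim=1),  # event_dim counts dims, not size
        tT.TanhTransform(),
        tT.AffineTransform(loc=mean, scale=std, event_dim=1),
    ]))

a = dist.sample()
print(a.shape)                 # torch.Size([3, 6])
print(dist.log_prob(a).shape)  # torch.Size([3]): one log-density per batch element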
Upvotes: 2