I have an observation space of type Box, whose bounds are defined as NumPy arrays.
For example:
Box(low=np.array([0, 0, 0]), high=np.array([15, 10,150]))
Now I want to get the Q-values for a single observation, but since the observation space is a Box, the relevant Stable Baselines3 preprocessing code is:
if isinstance(observation_space, spaces.Box):
    return obs.float()
But my observation is a NumPy array, which has no float() method, so this call fails. In this case, how can I access the Q-values of all the actions?
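A minimal reproduction of the error using only NumPy. `preprocess_box_obs` is a hypothetical stand-in for the Box branch quoted above, not the actual Stable Baselines3 function; it assumes its input is already a torch tensor, which is why a raw NumPy array fails.

```python
import numpy as np


def preprocess_box_obs(obs):
    # Hypothetical stand-in for the Box branch shown above: it calls
    # Tensor.float(), which only exists on torch tensors.
    return obs.float()


# A single observation that fits Box(low=[0, 0, 0], high=[15, 10, 150])
obs = np.array([3.0, 7.0, 42.0], dtype=np.float32)

try:
    preprocess_box_obs(obs)
    error = None
except AttributeError as exc:
    # NumPy arrays have astype(), but no float() method
    error = str(exc)

print(error)  # 'numpy.ndarray' object has no attribute 'float'
```

So the observation has to be turned into a torch tensor before it reaches the network.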
I figured out how to resolve it, so I am posting it here in case someone else runs into the same problem.
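import torch as th
from stable_baselines3.common.utils import obs_as_tensor
# Add a batch dimension, e.g. (3,) -> (1, 3)
observation = obs.reshape((-1,) + model.observation_space.shape)
# Convert the NumPy array to a torch tensor on the model's device
observation = obs_as_tensor(observation, model.device)
with th.no_grad():
    q_values = model.q_net(observation)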
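A NumPy-only sketch of the shapes involved. The reshape prepends a batch axis, and for a DQN-style model `q_net` returns one row of Q-values per observation, shaped `(batch_size, n_actions)`. The Q-values below are made-up numbers standing in for `model.q_net(observation).cpu().numpy()`, and the 4-action space is an assumption for illustration.

```python
import numpy as np

obs_shape = (3,)  # shape of Box(low=[0, 0, 0], high=[15, 10, 150])
obs = np.array([3.0, 7.0, 42.0], dtype=np.float32)

# Prepend a batch axis: (3,) -> (1, 3)
batch = obs.reshape((-1,) + obs_shape)
print(batch.shape)  # (1, 3)

# Hypothetical Q-values for 4 discrete actions, shape (batch_size, n_actions)
q_values = np.array([[0.1, 0.5, -0.2, 0.3]], dtype=np.float32)

# Q-values of all actions for the single observation
all_actions_q = q_values[0]

# The greedy action is the argmax over the action axis
greedy_action = int(q_values.argmax(axis=1)[0])
print(greedy_action)  # 1
```

Indexing row 0 gives the Q-values of every action for that one observation; taking the argmax over axis 1 is how a greedy DQN policy picks its action.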