naizz

Reputation: 77

How to get the Q-values of a DQN in Stable Baselines3?

I have an observation space of type Box, defined with NumPy arrays.

For example:

Box(low=np.array([0, 0, 0]), high=np.array([15, 10, 150]))

Now I want to get the Q-values for a single observation. However, because the observation space is a Box, Stable Baselines3 preprocesses the observation with this code:

if isinstance(observation_space, spaces.Box):
    return obs.float()

But my observation is a NumPy array, and NumPy arrays have no float() method (that method belongs to PyTorch tensors). In this case, how can I access the Q-values of all the actions?
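A minimal NumPy-only sketch of the mismatch described above. The observation values are hypothetical, chosen to lie inside the Box bounds from the question. The point is that .float() is a torch.Tensor method, so a plain np.ndarray fails as soon as SB3's preprocessing calls it; NumPy's own dtype conversion is astype().

```python
import numpy as np

# Bounds from the Box in the question
low = np.array([0, 0, 0], dtype=np.float32)
high = np.array([15, 10, 150], dtype=np.float32)

# Hypothetical single observation that lies inside those bounds
obs = np.array([3, 7, 120], dtype=np.float32)
assert np.all((obs >= low) & (obs <= high))

# SB3's Box preprocessing calls obs.float(), which only torch tensors have
print(hasattr(obs, "float"))  # False

# NumPy's equivalent conversion is astype(), not float()
print(obs.astype(np.float64).dtype)  # float64
```

Converting the array to a tensor before it reaches the network avoids the error.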

Upvotes: 4

Views: 857

Answers (1)

naizz

Reputation: 77

I figured out how to resolve it, so I'm posting the solution here in case someone else runs into the same problem.

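import torch as th
from stable_baselines3.common.utils import obs_as_tensor

# Add a batch dimension: shape becomes (1, *observation_space.shape)
observation = obs.reshape((-1,) + model.observation_space.shape)
observation = obs_as_tensor(observation, model.device)  # NumPy array -> torch tensor
with th.no_grad():
    q_values = model.q_net(observation)  # one Q-value per action, shape (1, n_actions)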
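For readers without a trained model at hand, here is a NumPy-only sketch of the shapes involved. The function toy_q_net, its random weights, and n_actions = 4 are hypothetical stand-ins for model.q_net (SB3's DQN uses an MLP by default, not a single linear layer). A batch of shape (1, 3) goes in, Q-values of shape (1, n_actions) come out, and the greedy action is the argmax over the action axis.

```python
import numpy as np

n_actions = 4  # hypothetical number of discrete actions

# Hypothetical linear "Q-network" standing in for model.q_net
rng = np.random.default_rng(0)
W = rng.normal(size=(3, n_actions))
b = np.zeros(n_actions)

def toy_q_net(batch: np.ndarray) -> np.ndarray:
    """Return one Q-value per action for every observation in the batch."""
    return batch @ W + b

obs = np.array([3.0, 7.0, 120.0])
observation = obs.reshape((-1,) + obs.shape)  # add batch dimension -> (1, 3)
q_values = toy_q_net(observation)             # -> (1, n_actions)
greedy_action = q_values.argmax(axis=1)       # index of the largest Q-value per row

print(q_values.shape)       # (1, 4)
print(greedy_action.shape)  # (1,)
```

With the real model, the same argmax over dim=1 of the q_values tensor gives the action a greedy DQN policy would pick.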

Upvotes: 3
