rj21

Reputation: 1

Reinforcement Learning cartpole

I'm trying to get my feet wet in reinforcement learning and came across the CartPole problem. I have already trained the model and got a score of 500 (which is well above the passing score of 200?). I also saved the training file and the model file, but when I test the model and try to predict the action using model.predict, it raises an error.

Code snippet:

environment_name = "CartPole-v1"
env = gym.make(environment_name, render_mode='human')

episodes = 5
for episodes in range(1, episodes + 1):
    obs = env.reset()
    done = False
    score = 0

    while not done:
        env.render()

        # this is the key change
        action, _ = model.predict(obs)  # we're now using our model here!
        obs, reward, done, info = env.step(action)

        score += reward
    print("Episode: {} Score: {}".format(episodes, score))
env.close()

Error in the code:

  ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Input In [14], in <cell line: 2>()
          8 env.render()
         10 # this is the key change
    ---> 11 action, _= model.predict(obs) # we're now using our model here!
         12 obs, reward, done,
    
    ValueError: Error: Unexpected observation shape (2,) for Box environment, please use (4,) or (n_env, 4) for the observation shape.

I am expecting the cartpole to stay balanced, since I am getting a high training score.

I tried looking at obs:

env = gym.make('CartPole-v1')
obs = env.reset()
len(obs)

2

Why does the obs tuple have only 2 elements? Shouldn't it be 4, since model.predict expects an input of shape (4,)?

Upvotes: 0

Views: 650

Answers (1)

lejlot

Reputation: 66815

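In recent Gym versions (0.26 and later), env.reset() does not return just the observation. It returns a tuple (observation, info), which is why len(obs) is 2 and model.predict sees the shape (2,).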

Change your code to

obs, info = env.reset()
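Once reset() is unpacked, the same API version will likely trip on env.step() as well, since it returns five values (obs, reward, terminated, truncated, info) instead of four. Below is a minimal sketch of the whole loop under that assumption. The names run_episode and FakeCartPole are made up for illustration; FakeCartPole is only a stand-in that mimics the return shapes so the snippet runs without Gym installed.

```python
# Sketch of an evaluation loop for the Gym >= 0.26 / Gymnasium API.
# `run_episode` and `FakeCartPole` are hypothetical names, not library code.

def run_episode(env, predict):
    """Run one episode and return the total reward.

    `predict` maps an observation to an action.
    """
    obs, info = env.reset()  # reset() returns (observation, info)
    score = 0.0
    done = False
    while not done:
        action = predict(obs)
        # step() returns five values; the episode ends on either flag
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        score += reward
    return score


class FakeCartPole:
    """Stand-in with the same reset()/step() return shapes as CartPole-v1."""

    def __init__(self, max_steps=3):
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0], {}

    def step(self, action):
        self.t += 1
        truncated = self.t >= self.max_steps  # time limit reached
        return [0.0, 0.0, 0.0, 0.0], 1.0, False, truncated, {}


print(run_episode(FakeCartPole(), lambda obs: 0))
```

With the real environment and a trained model (assuming it is a Stable-Baselines3 model, as the error message suggests), the call would look like run_episode(env, lambda obs: model.predict(obs)[0]).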

Upvotes: 1
