Philip

Reputation: 11

IndexError in SBX (Stable Baselines Jax) with Flax: "tuple index out of range" during Actor Network Initialization

I previously implemented SAC with stable-baselines3 in a custom Gymnasium environment, and it worked. Now, I’m trying to use stable-baselines3 JAX (SBX) in the same environment but encounter this error during SAC model initialization:

  File "/workspaces/ros2_ws_humble/src/rl_node/training_loops_method1/run_method1.py", line 158, in run_test
    model = SAC(
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 112, in __init__
    self._setup_model()
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 127, in _setup_model
    self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 120, in build
    params=self.actor.init(actor_key, obs),
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 35, in __call__
    x = nn.Dense(n_units)(x)
  File "/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py", line 237, in __call__
    (jnp.shape(inputs)[-1], self.features),
IndexError: tuple index out of range
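
To make the last frame concrete: as far as I can tell, the failing expression only takes the last entry of the input's shape tuple, so it would fail whenever the input has an empty shape (a 0-d value). Below is a minimal, stdlib-only sketch of that mechanism. `last_dim` is a hypothetical helper written for illustration, not a Flax function, and both shapes are made up.

```python
def last_dim(shape):
    """Mimic the failing expression: read the last entry of a shape tuple."""
    return shape[-1]


# Hypothetical shapes, for illustration only.
batched_shape = (1, 58)  # one flattened observation with a leading batch axis
empty_shape = ()         # shape of a 0-d value

print(last_dim(batched_shape))  # 58

try:
    last_dim(empty_shape)
except IndexError as err:
    message = str(err)
    print(message)  # tuple index out of range
```

If this reading is right, `x` has no dimensions at all by the time it reaches that `nn.Dense` call, which is not what I would expect for a batched observation.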

This is where I initialize the SAC model:

import numpy as np
import torch as th
from gymnasium.wrappers import RecordEpisodeStatistics
from sbx import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

# make_env() and parent_dir_path are defined elsewhere in the project.
def run_test(config_, rl_node, run_count):
    np.seterr(all='raise')
    th.autograd.set_detect_anomaly(True)
    env = None
    mode = None
    env = make_env(config_, rl_node)
    env = RecordEpisodeStatistics(env, buffer_length=100)
    env = DummyVecEnv([lambda: env])
    model_name = "Agent_Long_absolut_05"
    policy_kwargs = {"activation_fn": th.nn.Mish, "net_arch": {"pi": [32, 32], "qf": [64, 64, 64]}}
    model = SAC("MultiInputPolicy", env, learning_rate=0.01, gamma=0.8, batch_size=128, verbose=1,
                policy_kwargs=policy_kwargs, tensorboard_log=f"{parent_dir_path}/logs", device="cuda")

Below is the observation space:

        obs_space = {
            'obs_long_acc': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
            'obs_long_jerk': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
            'obs_relative_speed': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
            'obs_relative_distance': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
            'obs_time_elapsed': gym.spaces.Box(low=0, high=100000, shape=(1,), dtype=np.float32),
            'obs_parameters_valid': gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            'obs_first_loop_err': gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
            'obs_second_loop_err': gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
            'obs_parameter_score_1': gym.spaces.Box(low=-100, high=0, shape=(24,), dtype=np.float32),
            'obs_parameter_score_2': gym.spaces.Box(low=-100, high=0, shape=(25,), dtype=np.float32),
            'obs_crash': gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
        }

        self.observation_space = gym.spaces.Dict(obs_space)

What could be causing the IndexError in this case? Is there a mismatch between the observation space and the policy architecture in SBX? I've been stuck on this issue for a while and would really appreciate any guidance on resolving it. Thank you.

I tried using a flattening operation to ensure the observation space was correctly flattened, but the error still persists during initialization.
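
For reference, here is a rough sketch of the input size I would expect after flattening. It assumes that flattening a Dict of Box spaces simply concatenates the flattened entries. The `obs_shapes` dict is a hand-copied summary of the shapes above, not part of my environment code.

```python
from math import prod

# Shapes copied by hand from the observation space above (every entry is a float32 Box).
obs_shapes = {
    "obs_long_acc": (1,),
    "obs_long_jerk": (1,),
    "obs_relative_speed": (1,),
    "obs_relative_distance": (1,),
    "obs_time_elapsed": (1,),
    "obs_parameters_valid": (1,),
    "obs_first_loop_err": (1,),
    "obs_second_loop_err": (1,),
    "obs_parameter_score_1": (24,),
    "obs_parameter_score_2": (25,),
    "obs_crash": (1,),
}

# Assumption: the flattened vector is the concatenation of all entries,
# so its length is the sum of the per-entry element counts.
flat_dim = sum(prod(shape) for shape in obs_shapes.values())
print(flat_dim)  # 58
```

So under that assumption, a single batched observation should reach the first dense layer with shape `(1, 58)`, not with an empty shape.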

Upvotes: 1

Views: 14

Answers (0)