I previously implemented SAC with stable-baselines3 in a custom Gymnasium environment, and it worked. Now I’m trying to use stable-baselines3 JAX (SBX) in the same environment, but I get this error during SAC model initialization:

```
  File "/workspaces/ros2_ws_humble/src/rl_node/training_loops_method1/run_method1.py", line 158, in run_test
    model = SAC(
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 112, in __init__
    self._setup_model()
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/sac.py", line 127, in _setup_model
    self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 120, in build
    params=self.actor.init(actor_key, obs),
  File "/usr/local/lib/python3.10/dist-packages/sbx/sac/policies.py", line 35, in __call__
    x = nn.Dense(n_units)(x)
  File "/usr/local/lib/python3.10/dist-packages/flax/linen/linear.py", line 237, in __call__
    (jnp.shape(inputs)[-1], self.features),
IndexError: tuple index out of range
```
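The last frame is where `flax.linen.Dense` sizes its kernel from the last axis of its input, `(jnp.shape(inputs)[-1], self.features)`. A minimal NumPy sketch of that lookup (assuming `jnp.shape` behaves like `np.shape` here; `dense_kernel_shape` is a hypothetical helper, not Flax code) shows that it fails with exactly this IndexError whenever the input's shape is `()`. That covers a 0-d scalar and, under `np.shape`, a plain dict of arrays such as a Dict observation. The 58 below is the flattened size of the Dict observation space in this question (nine 1-element boxes plus 24 plus 25).

```python
import numpy as np


def dense_kernel_shape(inputs, features):
    """Hypothetical mirror of the kernel-shape lookup in flax.linen.Dense."""
    # Dense needs a last (feature) axis on its input to size the kernel.
    return (np.shape(inputs)[-1], features)


# A batched, flat observation: 1 sample with 58 features.
print(dense_kernel_shape(np.zeros((1, 58), dtype=np.float32), 32))  # (58, 32)

# Inputs whose shape is () have no last axis, so [-1] raises IndexError.
for bad_input in (np.float32(0.0), {"obs_crash": np.zeros((1, 1), dtype=np.float32)}):
    try:
        dense_kernel_shape(bad_input, 32)
    except IndexError as err:
        print(type(bad_input).__name__, "->", err)  # ... -> tuple index out of range
```

If the observation that SBX passes to `actor.init` is built from the Dict space without being flattened into one array, this is one way the lookup could end up seeing shape `()`; that is a hypothesis to check, not a confirmed reading of SBX's internals.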
Here I’m initializing the SAC model:
```python
import numpy as np
import torch as th
from gymnasium.wrappers import RecordEpisodeStatistics
from sbx import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

# make_env and parent_dir_path are defined elsewhere in run_method1.py


def run_test(config_, rl_node, run_count):
    np.seterr(all='raise')
    th.autograd.set_detect_anomaly(True)
    env = None
    mode = None
    env = make_env(config_, rl_node)
    env = RecordEpisodeStatistics(env, buffer_length=100)
    env = DummyVecEnv([lambda: env])
    model_name = "Agent_Long_absolut_05"
    policy_kwargs = {"activation_fn": th.nn.Mish, "net_arch": {"pi": [32, 32], "qf": [64, 64, 64]}}
    model = SAC("MultiInputPolicy", env, learning_rate=0.01, gamma=0.8, batch_size=128,
                verbose=1, policy_kwargs=policy_kwargs,
                tensorboard_log=f"{parent_dir_path}/logs", device="cuda")
```
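Separately from the IndexError (the traceback fails in the first `Dense` layer, before any activation runs), `th.nn.Mish` is a PyTorch module class, while SBX builds its actor and critics with Flax and, as far as I can tell, applies `activation_fn` as a plain function to JAX arrays (its default appears to be `flax.linen.relu`). A JAX-native Mish, `x * tanh(softplus(x))`, is short to write. In this sketch, `mish` is a hypothetical name, and it assumes SBX accepts the same dict-style `net_arch` with `"pi"` and `"qf"` keys that PyTorch SB3's SAC does.

```python
import jax
import jax.numpy as jnp


def mish(x: jnp.ndarray) -> jnp.ndarray:
    """Mish activation, x * tanh(softplus(x)), written with JAX ops only."""
    return x * jnp.tanh(jax.nn.softplus(x))


# Same architecture as above, with the PyTorch activation swapped for the JAX function.
# Assumption: this SBX version reads net_arch as {"pi": [...], "qf": [...]}.
policy_kwargs = {
    "activation_fn": mish,
    "net_arch": {"pi": [32, 32], "qf": [64, 64, 64]},
}

print(float(mish(jnp.array(0.0))))  # 0.0
```

Because `mish` only uses `jnp`/`jax.nn` operations, it can be traced and jitted like Flax's built-in activations.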
This is the observation space, defined in the environment's `__init__`:
```python
obs_space = {
    'obs_long_acc': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
    'obs_long_jerk': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
    'obs_relative_speed': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
    'obs_relative_distance': gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32),
    'obs_time_elapsed': gym.spaces.Box(low=0, high=100000, shape=(1,), dtype=np.float32),
    'obs_parameters_valid': gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
    'obs_first_loop_err': gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
    'obs_second_loop_err': gym.spaces.Box(low=0, high=10, shape=(1,), dtype=np.float32),
    'obs_parameter_score_1': gym.spaces.Box(low=-100, high=0, shape=(24,), dtype=np.float32),
    'obs_parameter_score_2': gym.spaces.Box(low=-100, high=0, shape=(25,), dtype=np.float32),
    'obs_crash': gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
}
self.observation_space = gym.spaces.Dict(obs_space)
```
What could be causing this IndexError? Is there a mismatch between the observation space and the policy architecture in SBX? I’ve been stuck on this issue for a while and would really appreciate any guidance. Thank you!
I also tried flattening the observation space, but the same error still occurs during model initialization.