From 2fab9bbe1ad226f1bb8c2d9bcf50aacf703e28fa Mon Sep 17 00:00:00 2001
From: Mayank Mittal
Date: Fri, 11 Oct 2024 12:24:56 +0000
Subject: [PATCH] Fixes device discrepancy for environment and RL agent

Approved-by: Fan Yang
---
 rsl_rl/runners/on_policy_runner.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 9e0a459..4fe58bf 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -45,8 +45,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity()  # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity()  # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
@@ -109,18 +109,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
             with torch.inference_mode():
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
-                    obs, rewards, dones, infos = self.env.step(actions)
-                    obs = self.obs_normalizer(obs)
-                    if "critic" in infos["observations"]:
-                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                    else:
-                        critic_obs = obs
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # move to the right device
                     obs, critic_obs, rewards, dones = (
                         obs.to(self.device),
                         critic_obs.to(self.device),
                         rewards.to(self.device),
                         dones.to(self.device),
                     )
+                    # perform normalization
+                    obs = self.obs_normalizer(obs)
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                    else:
+                        critic_obs = obs
+                    # process the step
                     self.alg.process_env_step(rewards, dones, infos)
 
                     if self.log_dir is not None: