Skip to content

Commit

Permalink
Merge branch 'release'
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayankm96 committed Oct 11, 2024
2 parents a1d25d1 + 2fab9bb commit 73fd7c6
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions rsl_rl/runners/on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
else:
self.obs_normalizer = torch.nn.Identity() # no normalization
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
self.critic_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
# init storage and model
self.alg.init_storage(
self.env.num_envs,
Expand Down Expand Up @@ -109,18 +109,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
with torch.inference_mode():
for i in range(self.num_steps_per_env):
actions = self.alg.act(obs, critic_obs)
obs, rewards, dones, infos = self.env.step(actions)
obs = self.obs_normalizer(obs)
if "critic" in infos["observations"]:
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
else:
critic_obs = obs
obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
# move to the right device
obs, critic_obs, rewards, dones = (
obs.to(self.device),
critic_obs.to(self.device),
rewards.to(self.device),
dones.to(self.device),
)
# perform normalization
obs = self.obs_normalizer(obs)
if "critic" in infos["observations"]:
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
else:
critic_obs = obs
# process the step
self.alg.process_env_step(rewards, dones, infos)

if self.log_dir is not None:
Expand Down

0 comments on commit 73fd7c6

Please sign in to comment.