From a6c12d7c107bde5882751e1cd935d17c067218d2 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Sat, 10 Feb 2024 23:18:12 +0100 Subject: [PATCH 01/18] add accelerate example --- cleanrl/ppo_atari_accelerate.py | 342 ++++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 cleanrl/ppo_atari_accelerate.py diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py new file mode 100644 index 000000000..f0463757e --- /dev/null +++ b/cleanrl/ppo_atari_accelerate.py @@ -0,0 +1,342 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_ataripy +import os +import random +import time +from dataclasses import dataclass + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + +from stable_baselines3.common.atari_wrappers import ( # isort:skip + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, +) + +from accelerate import Accelerator + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + + # Algorithm specific arguments + env_id: str = "BreakoutNoFrameskip-v4" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 8 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 4 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, idx, capture_video, run_name): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, 
render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = ClipRewardEnv(env) + env = gym.wrappers.ResizeObservation(env, (84, 84)) + env = gym.wrappers.GrayScaleObservation(env) + env = gym.wrappers.FrameStack(env, 4) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.network = nn.Sequential( + layer_init(nn.Conv2d(4, 32, 8, stride=4)), + nn.ReLU(), + layer_init(nn.Conv2d(32, 64, 4, stride=2)), + nn.ReLU(), + layer_init(nn.Conv2d(64, 64, 3, stride=1)), + nn.ReLU(), + nn.Flatten(), + layer_init(nn.Linear(64 * 7 * 7, 512)), + nn.ReLU(), + ) + self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(512, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + + accelerator = Accelerator() + + if args.track and accelerator.is_main_process: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)], + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs) + agent = accelerator.prepare(agent) + + device = accelerator.device + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, 
args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset(seed=args.seed) + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + + for iteration in range(1, args.num_iterations + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) + next_done = np.logical_or(terminations, truncations) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) + + + if "final_info" in infos and accelerator.is_main_process: + for info in infos["final_info"]: + if info and "episode" in info: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value 
loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + accelerator.backward(loss) + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + if accelerator.is_main_process: + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + + if accelerator.is_main_process: + writer.close() From a9935256abbb8df4bd620cfe03bd592a74abe105 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:04:01 +0100 Subject: [PATCH 02/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index f0463757e..4f39b32ce 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -50,8 +50,8 @@ class Args: """total timesteps of the experiments""" learning_rate: float = 2.5e-4 """the learning rate of the optimizer""" - num_envs: int = 8 - """the number of parallel game environments""" + local_num_envs: int = 8 + """the number of parallel game environments (in the local rank)""" num_steps: int = 128 """the number of steps to run in each environment per policy rollout""" anneal_lr: bool = True From f7cd97054d38f12e724ac76330899f3efc325ee2 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:05:27 +0100 Subject: [PATCH 03/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 4f39b32ce..9419c0f30 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -80,6 +80,12 @@ class Args: """the target KL divergence threshold""" # to be filled in runtime + local_batch_size: int = 0 + """the local batch size in the local rank (computed in runtime)""" + local_minibatch_size: int = 0 + """the local mini-batch size in 
the local rank (computed in runtime)""" + num_envs: int = 0 + """the number of parallel game environments (computed in runtime)""" batch_size: int = 0 """the batch size (computed in runtime)""" minibatch_size: int = 0 From ea6fc9994297109dc868a5ac20e6f96e2bd83202 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:05:33 +0100 Subject: [PATCH 04/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 9419c0f30..175cb1b61 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -92,6 +92,8 @@ class Args: """the mini-batch size (computed in runtime)""" num_iterations: int = 0 """the number of iterations (computed in runtime)""" + world_size: int = 0 + """the number of processes (computed in runtime)""" def make_env(env_id, idx, capture_video, run_name): From e3d2c446f62c34aad53de83bb1e4583f4738e0ee Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:06:02 +0100 Subject: [PATCH 05/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 175cb1b61..2f339d2f2 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -184,9 +184,11 @@ def get_action_and_value(self, x, action=None): ) # TRY NOT TO MODIFY: seeding + # CRUCIAL: note that we needed to pass a different seed for each data parallelism worker + args.seed += accelerator.process_index * 100003 # Prime random.seed(args.seed) np.random.seed(args.seed) - torch.manual_seed(args.seed) + torch.manual_seed(args.seed - local_rank) torch.backends.cudnn.deterministic = args.torch_deterministic # env setup From 9a8451e5362695bf82b46b910403a4d00e2720ee Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:06:13 +0100 Subject: [PATCH 06/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 2f339d2f2..d07ee566a 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -193,7 +193,7 @@ def get_action_and_value(self, x, action=None): # env setup envs = gym.vector.SyncVectorEnv( - [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.num_envs)], + [make_env(args.env_id, i, args.capture_video, run_name) for i in range(args.local_num_envs)], ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" From ba8fbd85c5b6950d29d9e7db1373c46b77c2b05a Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:06:49 +0100 Subject: [PATCH 07/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index d07ee566a..0084a1c3e 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -243,8 +243,10 @@ def get_action_and_value(self, x, action=None): rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), 
torch.Tensor(next_done).to(device) - - if "final_info" in infos and accelerator.is_main_process: + if not writer: + continue + + if "final_info" in infos: for info in infos["final_info"]: if info and "episode" in info: print(f"global_step={global_step}, episodic_return={info['episode']['r']}") From 03d1a1c2fe4fef5b13911f05cbdda0c7f8850a49 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:07:17 +0100 Subject: [PATCH 08/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 0084a1c3e..f4eeac2ce 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -278,12 +278,13 @@ def get_action_and_value(self, x, action=None): b_values = values.reshape(-1) # Optimizing the policy and value network - b_inds = np.arange(args.batch_size) + b_inds = np.arange(args.local_batch_size) clipfracs = [] for epoch in range(args.update_epochs): np.random.shuffle(b_inds) - for start in range(0, args.batch_size, args.minibatch_size): - end = start + args.minibatch_size + for start in range(0, args.local_batch_size, args.local_minibatch_size): + end = start + args.local_minibatch_size + mb_inds = b_inds[start:end] mb_inds = b_inds[start:end] _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) From ea79777260c852d3c8777a13e110fd582c64834a Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:07:33 +0100 Subject: [PATCH 09/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index f4eeac2ce..76f330573 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -353,3 +353,5 @@ def get_action_and_value(self, x, action=None): if accelerator.is_main_process: writer.close() + if args.track: + wandb.finish() From 655267cca797c40cc04b4d30e75ca61a3089912e Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:07:43 +0100 Subject: [PATCH 10/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 76f330573..f0e0858f9 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -326,7 +326,7 @@ def get_action_and_value(self, x, action=None): optimizer.zero_grad() accelerator.backward(loss) - nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() if args.target_kl is not None and approx_kl > args.target_kl: From 3afd7e1fa5e18afdad1c5a148e485532095f4843 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:07:57 +0100 Subject: [PATCH 11/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index f0e0858f9..8f866af0b 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -253,6 +253,9 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("charts/episodic_return", 
info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + print( + f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {agent.actor.weight.sum()}" + ) # bootstrap value if not done with torch.no_grad(): next_value = agent.get_value(next_obs).reshape(1, -1) From 25c5a3a82185f7771db09549f7c1bb1daadbdf2e Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:08:09 +0100 Subject: [PATCH 12/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 8f866af0b..b9279dfb7 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -158,6 +158,12 @@ def get_action_and_value(self, x, action=None): if __name__ == "__main__": args = tyro.cli(Args) + accelerator = Accelerator() + local_rank = accelerator.process_index + args.world_size = accelerator.num_processes + args.local_batch_size = int(args.local_num_envs * args.num_steps) + args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches) + args.num_envs = args.local_num_envs * args.world_size args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) args.num_iterations = args.total_timesteps // args.batch_size From a6f884f9c788e9b9b6b14afa8e9e623b50b787d1 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:08:38 +0100 Subject: [PATCH 13/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index b9279dfb7..9cd1586a1 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -203,11 +203,11 @@ def get_action_and_value(self, x, action=None): ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" - agent = Agent(envs) - agent = accelerator.prepare(agent) - - device = accelerator.device + agent = Agent(envs).to(device) + torch.manual_seed(args.seed) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + agent, optimizer = accelerator.prepare(agent, optimizer) + device = accelerator.device # ALGO Logic: Storage setup obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) From 95e690077237d94c6b9a538a78ae420d2b5cdf9d Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:08:47 +0100 Subject: [PATCH 14/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 9cd1586a1..22d49b2eb 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -169,7 +169,6 @@ def get_action_and_value(self, x, action=None): args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" - accelerator = Accelerator() if args.track and accelerator.is_main_process: import wandb From 26dd0e3df4c25d39e73a3cc2c235e298f1bc8874 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:09:00 +0100 Subject: [PATCH 15/18] 
Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 22d49b2eb..fb80b6890 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -209,12 +209,12 @@ def get_action_and_value(self, x, action=None): device = accelerator.device # ALGO Logic: Storage setup - obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) - actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) - logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) - rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) - dones = torch.zeros((args.num_steps, args.num_envs)).to(device) - values = torch.zeros((args.num_steps, args.num_envs)).to(device) + obs = torch.zeros((args.num_steps, args.local_num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.local_num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.local_num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.local_num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.local_num_envs)).to(device) + values = torch.zeros((args.num_steps, args.local_num_envs)).to(device) # TRY NOT TO MODIFY: start the game global_step = 0 From 6232f50393fab6e4a2d8b7c5edd98e8be63bc9a7 Mon Sep 17 00:00:00 2001 From: Edward Beeching Date: Tue, 13 Feb 2024 10:09:12 +0100 Subject: [PATCH 16/18] Update cleanrl/ppo_atari_accelerate.py Co-authored-by: Costa Huang --- cleanrl/ppo_atari_accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index fb80b6890..55a3a1ee8 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -221,7 +221,7 @@ def get_action_and_value(self, x, action=None): start_time = time.time() next_obs, _ = envs.reset(seed=args.seed) next_obs = torch.Tensor(next_obs).to(device) - next_done = torch.zeros(args.num_envs).to(device) + next_done = torch.zeros(args.local_num_envs).to(device) for iteration in range(1, args.num_iterations + 1): # Annealing the rate if instructed to do so. 
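Patches 12 through 16 split the rollout sizing into per-process ("local") quantities and global ones, and offset each data-parallel worker's seed. A minimal sketch of how those runtime fields relate, assuming the default Args values above and a two-process launch (the variable names mirror the Args fields; `Accelerator` and `num_processes` are the standard accelerate API):

# Sketch of the local/global sizing introduced in patches 12-16. Assumes the
# default Args values and that the script was started with `accelerate launch`.
from accelerate import Accelerator

accelerator = Accelerator()
world_size = accelerator.num_processes              # e.g. 2 with two GPUs

local_num_envs = 8                                  # envs stepped by this process
num_steps = 128
num_minibatches = 4

local_batch_size = local_num_envs * num_steps       # 1024 transitions per process
local_minibatch_size = local_batch_size // num_minibatches   # 256
num_envs = local_num_envs * world_size              # 16 envs across all processes
batch_size = num_envs * num_steps                   # 2048 transitions per PPO iteration
minibatch_size = batch_size // num_minibatches      # 512, with DDP averaging the gradients

# Each worker also offsets its seed so the environments decorrelate, as in patch 05:
seed = 1 + accelerator.process_index * 100003       # 100003 is prime

Because DDP averages gradients across processes, each minibatch step effectively covers `minibatch_size` transitions even though every process only holds `local_minibatch_size` of them.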
From 6d61b990dfc04c9f34536207f03203bafe7ae3bf Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 22 Feb 2024 22:25:36 +0000 Subject: [PATCH 17/18] fixes grads not being synced due to forward not being called with DDP --- cleanrl/ppo_atari_accelerate.py | 48 ++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 55a3a1ee8..0a2668359 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -155,8 +155,11 @@ def get_action_and_value(self, x, action=None): action = probs.sample() return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + # required due to how DistributedDataParallel wraps the model + def forward(self, x, action): + return self.get_action_and_value(x, action) -if __name__ == "__main__": +def main(): args = tyro.cli(Args) accelerator = Accelerator() local_rank = accelerator.process_index @@ -202,11 +205,12 @@ def get_action_and_value(self, x, action=None): ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + device = accelerator.device agent = Agent(envs).to(device) torch.manual_seed(args.seed) optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) agent, optimizer = accelerator.prepare(agent, optimizer) - device = accelerator.device + # ALGO Logic: Storage setup obs = torch.zeros((args.num_steps, args.local_num_envs) + envs.single_observation_space.shape).to(device) @@ -237,7 +241,7 @@ def get_action_and_value(self, x, action=None): # ALGO LOGIC: action logic with torch.no_grad(): - action, logprob, _, value = agent.get_action_and_value(next_obs) + action, logprob, _, value = accelerator.unwrap_model(agent).get_action_and_value(next_obs) values[step] = value.flatten() actions[step] = action logprobs[step] = logprob @@ -248,7 +252,7 @@ def get_action_and_value(self, x, action=None): rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) - if not writer: + if not (args.track and accelerator.is_main_process): continue if "final_info" in infos: @@ -259,11 +263,11 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) print( - f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {agent.actor.weight.sum()}" + f"local_rank: {local_rank}, action.sum(): {action.sum()}, iteration: {iteration}, agent.actor.weight.sum(): {accelerator.unwrap_model(agent).actor.weight.sum()}" ) # bootstrap value if not done with torch.no_grad(): - next_value = agent.get_value(next_obs).reshape(1, -1) + next_value = accelerator.unwrap_model(agent).get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 for t in reversed(range(args.num_steps)): @@ -293,9 +297,8 @@ def get_action_and_value(self, x, action=None): for start in range(0, args.local_batch_size, args.local_minibatch_size): end = start + args.local_minibatch_size mb_inds = b_inds[start:end] - mb_inds = b_inds[start:end] - _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + _, newlogprob, entropy, newvalue = agent(b_obs[mb_inds], b_actions.long()[mb_inds]) logratio = newlogprob - b_logprobs[mb_inds] ratio = logratio.exp() @@ -332,10 +335,11 @@ def get_action_and_value(self, x, action=None): 
entropy_loss = entropy.mean() loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef - optimizer.zero_grad() + accelerator.backward(loss) - accelerator.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + #accelerator.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() + optimizer.zero_grad() if args.target_kl is not None and approx_kl > args.target_kl: break @@ -346,16 +350,18 @@ def get_action_and_value(self, x, action=None): # TRY NOT TO MODIFY: record rewards for plotting purposes if accelerator.is_main_process: - writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) - writer.add_scalar("losses/value_loss", v_loss.item(), global_step) - writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) - writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) - writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) - writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) - writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) - writer.add_scalar("losses/explained_variance", explained_var, global_step) print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + if args.track: + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() @@ -363,3 +369,7 @@ def get_action_and_value(self, x, action=None): writer.close() if args.track: wandb.finish() + +if __name__ == "__main__": + main() + \ No newline at end of file From 181d48bec6a150263b978ad0e534e76c08030b94 Mon Sep 17 00:00:00 2001 From: edbeeching Date: Thu, 22 Feb 2024 22:28:21 +0000 Subject: [PATCH 18/18] adds back grad clip --- cleanrl/ppo_atari_accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo_atari_accelerate.py b/cleanrl/ppo_atari_accelerate.py index 0a2668359..8c0cb82a4 100644 --- a/cleanrl/ppo_atari_accelerate.py +++ b/cleanrl/ppo_atari_accelerate.py @@ -337,7 +337,7 @@ def main(): accelerator.backward(loss) - #accelerator.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() optimizer.zero_grad()
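The last two patches center on gradient synchronization: DistributedDataParallel only installs its gradient all-reduce hooks on the module's forward()/__call__ path, which is why patch 17 adds a forward() that delegates to get_action_and_value, calls the wrapped agent directly during the minibatch update, and routes rollout and bootstrap calls through accelerator.unwrap_model under torch.no_grad(). A minimal sketch of that split (compute_loss, update_step, and act are hypothetical helper names, not part of the patch; the accelerator calls are the standard accelerate API):

# Sketch of the call pattern established by patches 17-18. `compute_loss`,
# `update_step`, and `act` are illustrative helpers, not functions from the patch.
import torch

def update_step(accelerator, agent, optimizer, mb_obs, mb_actions, compute_loss, max_grad_norm=0.5):
    # Training pass: call the DDP-wrapped module itself so its backward hooks
    # fire and gradients are all-reduced across processes.
    _, newlogprob, entropy, newvalue = agent(mb_obs, mb_actions)
    loss = compute_loss(newlogprob, entropy, newvalue)
    accelerator.backward(loss)                                    # syncs gradients across ranks
    accelerator.clip_grad_norm_(agent.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()

@torch.no_grad()
def act(accelerator, agent, obs):
    # Rollout pass: no gradients are needed, so the underlying module can be
    # called directly through its custom method.
    return accelerator.unwrap_model(agent).get_action_and_value(obs)

With this in place the script would be started through the launcher, e.g. `accelerate launch cleanrl/ppo_atari_accelerate.py`, optionally with `--num_processes` or an accelerate config to choose the number of workers, while tyro still parses the Args fields from the remaining command-line flags.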