From 961bd56a1e847422b35163e95b705b7b9d19056b Mon Sep 17 00:00:00 2001
From: LTluttmann
Date: Thu, 5 Sep 2024 11:47:49 +0200
Subject: [PATCH] [Feat] changed structure of stepwise ppo

---
 configs/experiment/scheduling/base.yaml       |  1 -
 configs/experiment/scheduling/matnet-ppo.yaml |  2 +-
 rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py    | 83 ++++++++++++-------
 rl4co/models/zoo/l2d/policy.py                |  2 +-
 4 files changed, 53 insertions(+), 35 deletions(-)

diff --git a/configs/experiment/scheduling/base.yaml b/configs/experiment/scheduling/base.yaml
index c15a6c45..0252a7e0 100644
--- a/configs/experiment/scheduling/base.yaml
+++ b/configs/experiment/scheduling/base.yaml
@@ -9,7 +9,6 @@ defaults:
 logger:
   wandb:
     project: "rl4co"
-    log_model: "all"
     group: "${env.name}-${env.generator_params.num_jobs}-${env.generator_params.num_machines}"
     tags: ???
     name: ???
diff --git a/configs/experiment/scheduling/matnet-ppo.yaml b/configs/experiment/scheduling/matnet-ppo.yaml
index 3abb9c4c..f489ae6b 100644
--- a/configs/experiment/scheduling/matnet-ppo.yaml
+++ b/configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,6 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
-  n_start: 8
+  n_start: 4
 env:
   stepwise_reward: True
\ No newline at end of file
diff --git a/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
index 5aff5496..a14df6a1 100644
--- a/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
+++ b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
@@ -45,14 +45,14 @@ def __init__(
         policy: nn.Module,
         n_start: int = 0,
         clip_range: float = 0.2,  # epsilon of PPO
-        update_timestep: int = 1,
         buffer_size: int = 100_000,
         ppo_epochs: int = 2,  # inner epoch, K
         batch_size: int = 256,
         mini_batch_size: int = 256,
+        rollout_batch_size: int = 256,
         vf_lambda: float = 0.5,  # lambda of Value function fitting
-        entropy_lambda: float = 0.01,  # lambda of entropy bonus
-        max_grad_norm: float = 0.5,  # max gradient norm
+        entropy_lambda: float = 0.0,  # lambda of entropy bonus
+        max_grad_norm: float = 1.0,  # max gradient norm
         buffer_storage_device: str = "gpu",
         metrics: dict = {
             "train": ["loss", "surrogate_loss", "value_loss", "entropy"],
@@ -67,10 +67,10 @@ def __init__(
         self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
         self.scaler = RewardScaler(reward_scale)
         self.n_start = n_start
+        self.rollout_batch_size = rollout_batch_size
         self.ppo_cfg = {
             "clip_range": clip_range,
             "ppo_epochs": ppo_epochs,
-            "update_timestep": update_timestep,
             "mini_batch_size": mini_batch_size,
             "vf_lambda": vf_lambda,
             "entropy_lambda": entropy_lambda,
@@ -133,7 +133,7 @@ def update(self, device):
             "reward": previous_reward.mean(),
             "loss": loss,
             "surrogate_loss": surrogate_loss,
-            "value_loss": value_loss,
+            # "value_loss": value_loss,
             "entropy": entropy.mean(),
         }
 
@@ -147,47 +147,66 @@
     def shared_step(
         self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
     ):
-        next_td = self.env.reset(batch)
-        device = next_td.device
-
         if phase == "train":
-            if self.n_start > 1:
-                next_td = batchify(next_td, self.n_start)
+            for i in range(0, batch.shape[0], self.rollout_batch_size):
+
+                mini_batch = batch[i : i + self.rollout_batch_size]
+                rollout_td_buffer = []
+                next_td = self.env.reset(mini_batch)
+                device = next_td.device
+
+                if self.n_start > 1:
+                    next_td = batchify(next_td, self.n_start)
+
+                n_steps = 0
+                while not next_td["done"].all():
 
-            td_buffer = []
-            while not next_td["done"].all():
+                    with torch.no_grad():
+                        td = self.policy_old.act(
+                            next_td, self.env, phase="train", temp=2.0
+                        )
 
-                with torch.no_grad():
-                    td = self.policy_old.act(next_td, self.env, phase="train")
+                    rollout_td_buffer.append(td)
+                    # get next state
+                    next_td = self.env.step(td)["next"]
+                    n_steps += 1
 
-                td_buffer.append(td)
-                # get next state
-                next_td = self.env.step(td)["next"]
-            # get reward of action
-            reward = self.env.get_reward(next_td, 1)
+                # get rewards
+                reward = self.env.get_reward(next_td, 1) / n_steps
 
-            if self.n_start > 1:
-                reward_unbatched = unbatchify(reward, self.n_start)
-                advantage = reward - batchify(reward_unbatched.mean(-1), self.n_start)
-                advantage = self.scaler(advantage)
-                td_buffer = [td.set("advantage", advantage) for td in td_buffer]
-            else:
-                reward = self.scaler(reward)
-                td_buffer = [td.set("reward", reward) for td in td_buffer]
+                if self.n_start > 1:
+                    reward_unbatched = unbatchify(reward, self.n_start)
+                    advantage = (
+                        reward
+                        - batchify(reward_unbatched.mean(-1), self.n_start).detach()
+                    )
+                    advantage = self.scaler(advantage)
+                    rollout_td_buffer = [
+                        td.set("advantage", advantage) for td in rollout_td_buffer
+                    ]
+
+                else:
+                    reward = self.scaler(reward)
+                    rollout_td_buffer = [
+                        td.set("reward", reward) for td in rollout_td_buffer
+                    ]
 
-            # add tensordict with action, logprobs and reward information to buffer
-            self.rb.extend(torch.cat(td_buffer, dim=0))
+                # add tensordict with action, logprobs and reward information to buffer
+                self.rb.extend(torch.cat(rollout_td_buffer, dim=0))
 
             # if iter mod x = 0 then update the policy (x = 1 in paper)
-            if batch_idx % self.ppo_cfg["update_timestep"] == 0:
-                out = self.update(device)
-                self.rb.empty()
+            out = self.update(device)
+
+            self.rb.empty()
+            torch.cuda.empty_cache()
 
         else:
+            next_td = self.env.reset(batch)
             out = self.policy.generate(
                 next_td, self.env, phase=phase, select_best=phase != "train"
             )
 
         metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
+
         return {"loss": out.get("loss", None), **metrics}
 
diff --git a/rl4co/models/zoo/l2d/policy.py b/rl4co/models/zoo/l2d/policy.py
index 0cfac356..8368ae71 100644
--- a/rl4co/models/zoo/l2d/policy.py
+++ b/rl4co/models/zoo/l2d/policy.py
@@ -225,7 +225,7 @@ def evaluate(self, td):
 
         return action_logprobs, value_pred, dist_entropys
 
-    def act(self, td, env, phase: str = "train"):
+    def act(self, td, env, phase: str = "train", temp: float = 1.0):
         logits, mask = self.decoder(td, hidden=None, num_starts=0)
 
         logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping)
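
The shared_step rewrite above replaces the single rollout of the full batch with chunked rollouts: the training batch is sliced into pieces of rollout_batch_size, each slice is rolled out to completion and written to the replay buffer, and one PPO update runs after all slices are collected. The following is a minimal, self-contained sketch of that loop structure under toy assumptions; ToyEnv, collect_rollouts, and the fixed horizon are illustrative stand-ins, not rl4co APIs.

import torch


class ToyEnv:
    """Stand-in environment: every instance finishes after `horizon` steps."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon

    def reset(self, batch: torch.Tensor) -> dict:
        return {
            "obs": batch,
            "t": torch.zeros(batch.shape[0]),
            "done": torch.zeros(batch.shape[0], dtype=torch.bool),
        }

    def step(self, td: dict) -> dict:
        t = td["t"] + 1
        return {"obs": td["obs"], "t": t, "done": t >= self.horizon}


def collect_rollouts(batch: torch.Tensor, env: ToyEnv, rollout_batch_size: int) -> list:
    """Roll out the batch slice by slice, collecting one record per step."""
    buffer = []  # stands in for the replay buffer filled via self.rb.extend(...)
    for i in range(0, batch.shape[0], rollout_batch_size):
        td = env.reset(batch[i : i + rollout_batch_size])
        n_steps = 0
        while not td["done"].all():
            # a real policy would act here (policy_old.act in the patch);
            # this sketch only records the current observation per step
            buffer.append({"obs": td["obs"].clone(), "step": td["t"].clone()})
            td = env.step(td)
            n_steps += 1
        # the patch divides the episodic reward by n_steps at this point
    return buffer


rollouts = collect_rollouts(torch.arange(10.0), ToyEnv(horizon=3), rollout_batch_size=4)
print(len(rollouts))  # 3 slices (4 + 4 + 2 instances), 3 steps each -> 9 records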
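
In the n_start > 1 branch, the advantage of each rollout is its episodic reward minus the mean reward over the decoding starts of the same instance, i.e. a shared multi-start baseline instead of a learned value baseline. The sketch below reproduces that computation on an already unbatchified [batch, n_start] reward tensor, so it does not depend on rl4co's batchify/unbatchify layout; the function name and the toy numbers are made up for illustration.

import torch


def multistart_advantage(reward: torch.Tensor) -> torch.Tensor:
    """reward: [batch, n_start] episodic returns, one per decoding start.

    Each start is compared against the average return over all starts of the
    same instance, so the advantage is zero-mean per instance.
    """
    baseline = reward.mean(dim=-1, keepdim=True)  # [batch, 1]; detached, as in the patch
    return reward - baseline.detach()             # [batch, n_start]


# toy usage: 2 instances, n_start = 4 decoding starts (as in the matnet-ppo config)
rewards = torch.tensor([[-10.0, -12.0, -9.0, -11.0], [-20.0, -18.0, -22.0, -20.0]])
print(multistart_advantage(rewards))  # rows sum to zero; above-average starts come out positive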