From f3808c8cb7a84223ca3655c4be6a2e82e078a7e7 Mon Sep 17 00:00:00 2001
From: Laurin Luttmann
Date: Thu, 11 Jul 2024 14:08:04 +0200
Subject: [PATCH] [Feat] draft: added stepwise ppo with sparse reward

---
 .../experiment/routing/tsp-stepwise-ppo.yaml  |   2 +-
 configs/experiment/scheduling/matnet-ppo.yaml |   2 +-
 rl4co/models/__init__.py                      |   2 +-
 rl4co/models/rl/__init__.py                   |   1 +
 rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py    | 176 ++++++++++++++++++
 5 files changed, 180 insertions(+), 3 deletions(-)
 create mode 100644 rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py

diff --git a/configs/experiment/routing/tsp-stepwise-ppo.yaml b/configs/experiment/routing/tsp-stepwise-ppo.yaml
index 678223ec..60e313dc 100644
--- a/configs/experiment/routing/tsp-stepwise-ppo.yaml
+++ b/configs/experiment/routing/tsp-stepwise-ppo.yaml
@@ -7,7 +7,7 @@ defaults:
   - override /logger: wandb.yaml
 
 env:
-  _target_: rl4co.envs.TSPEnv4PPO
+  _target_: rl4co.envs.DenseRewardTSPEnv
   generator_params:
     num_loc: 20
 
diff --git a/configs/experiment/scheduling/matnet-ppo.yaml b/configs/experiment/scheduling/matnet-ppo.yaml
index c88d2c64..cd07b802 100644
--- a/configs/experiment/scheduling/matnet-ppo.yaml
+++ b/configs/experiment/scheduling/matnet-ppo.yaml
@@ -11,7 +11,7 @@ logger:
 embed_dim: 256
 
 model:
-  _target_: rl4co.models.StepwisePPO
+  _target_: rl4co.models.StepwisePPOwEpisodicReward
   policy:
     _target_: rl4co.models.L2DPolicy4PPO
     decoder:
diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py
index 339c3b01..b23ec692 100644
--- a/rl4co/models/__init__.py
+++ b/rl4co/models/__init__.py
@@ -14,7 +14,7 @@
     NonAutoregressivePolicy,
 )
 from rl4co.models.common.transductive import TransductiveModel
-from rl4co.models.rl import StepwisePPO
+from rl4co.models.rl import StepwisePPO, StepwisePPOwEpisodicReward
 from rl4co.models.rl.a2c.a2c import A2C
 from rl4co.models.rl.common.base import RL4COLitModule
 from rl4co.models.rl.ppo.ppo import PPO
diff --git a/rl4co/models/rl/__init__.py b/rl4co/models/rl/__init__.py
index 1a3bf7e2..2ad6470a 100644
--- a/rl4co/models/rl/__init__.py
+++ b/rl4co/models/rl/__init__.py
@@ -3,4 +3,5 @@
 from rl4co.models.rl.ppo.n_step_ppo import n_step_PPO
 from rl4co.models.rl.ppo.ppo import PPO
 from rl4co.models.rl.ppo.stepwise_ppo import StepwisePPO
+from rl4co.models.rl.ppo.stepwise_ppo_ep_rew import StepwisePPOwEpisodicReward
 from rl4co.models.rl.reinforce.reinforce import REINFORCE
diff --git a/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
new file mode 100644
index 00000000..cffe7462
--- /dev/null
+++ b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
@@ -0,0 +1,176 @@
+import copy
+
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchrl.data.replay_buffers import (
+    LazyMemmapStorage,
+    ListStorage,
+    SamplerWithoutReplacement,
+    TensorDictReplayBuffer,
+)
+
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.models.rl.common.base import RL4COLitModule
+from rl4co.models.rl.common.utils import RewardScaler
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+def make_replay_buffer(buffer_size, batch_size, device="cpu"):
+    if device == "cpu":
+        storage = LazyMemmapStorage(buffer_size, device="cpu")
+        prefetch = 3
+    else:
+        storage = ListStorage(buffer_size)
+        prefetch = None
+    return TensorDictReplayBuffer(
+        storage=storage,
+        batch_size=batch_size,
+        sampler=SamplerWithoutReplacement(drop_last=True),
+        pin_memory=False,
+        prefetch=prefetch,
+    )
+
+
+class StepwisePPOwEpisodicReward(RL4COLitModule):
+    def __init__(
+        self,
+        env: RL4COEnvBase,
+        policy: nn.Module,
+        clip_range: float = 0.2,  # epsilon of PPO
+        update_timestep: int = 1,
+        buffer_size: int = 100_000,
+        ppo_epochs: int = 2,  # inner epoch, K
+        batch_size: int = 256,
+        mini_batch_size: int = 256,
+        vf_lambda: float = 0.5,  # lambda of value function fitting
+        entropy_lambda: float = 0.01,  # lambda of entropy bonus
+        max_grad_norm: float = 0.5,  # max gradient norm
+        buffer_storage_device: str = "gpu",
+        metrics: dict = {
+            "train": ["loss", "surrogate_loss", "value_loss", "entropy"],
+        },
+        reward_scale: Union[str, int] = None,
+        **kwargs,
+    ):
+        super().__init__(env, policy, metrics=metrics, batch_size=batch_size, **kwargs)
+
+        self.policy_old = copy.deepcopy(self.policy)
+        self.automatic_optimization = False  # PPO uses a custom optimization routine
+        self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
+        self.scaler = RewardScaler(reward_scale)
+
+        self.ppo_cfg = {
+            "clip_range": clip_range,
+            "ppo_epochs": ppo_epochs,
+            "update_timestep": update_timestep,
+            "mini_batch_size": mini_batch_size,
+            "vf_lambda": vf_lambda,
+            "entropy_lambda": entropy_lambda,
+            "max_grad_norm": max_grad_norm,
+        }
+
+    def update(self, device):
+        outs = []
+        # PPO inner epochs
+        for _ in range(self.ppo_cfg["ppo_epochs"]):
+            for sub_td in self.rb:
+                sub_td = sub_td.to(device)
+                previous_reward = sub_td["reward"].view(-1, 1)
+                previous_logp = sub_td["logprobs"]
+
+                logprobs, value_pred, entropy = self.policy.evaluate(sub_td)
+
+                ratios = torch.exp(logprobs - previous_logp)
+
+                advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
+                surr1 = ratios * advantages
+                surr2 = (
+                    torch.clamp(
+                        ratios,
+                        1 - self.ppo_cfg["clip_range"],
+                        1 + self.ppo_cfg["clip_range"],
+                    )
+                    * advantages
+                )
+                surrogate_loss = -torch.min(surr1, surr2).mean()
+
+                # compute value function loss
+                value_loss = F.mse_loss(value_pred, previous_reward)
+
+                # compute total loss
+                loss = (
+                    surrogate_loss
+                    + self.ppo_cfg["vf_lambda"] * value_loss
+                    - self.ppo_cfg["entropy_lambda"] * entropy.mean()
+                )
+
+                # perform manual optimization following the Lightning routine
+                # https://lightning.ai/docs/pytorch/stable/common/optimization.html
+
+                opt = self.optimizers()
+                opt.zero_grad()
+                self.manual_backward(loss)
+                if self.ppo_cfg["max_grad_norm"] is not None:
+                    self.clip_gradients(
+                        opt,
+                        gradient_clip_val=self.ppo_cfg["max_grad_norm"],
+                        gradient_clip_algorithm="norm",
+                    )
+                opt.step()
+
+                out = {
+                    "reward": previous_reward.mean(),
+                    "loss": loss,
+                    "surrogate_loss": surrogate_loss,
+                    "value_loss": value_loss,
+                    "entropy": entropy.mean(),
+                }
+
+                outs.append(out)
+        # copy new weights into the old (rollout) policy
+        self.policy_old.load_state_dict(self.policy.state_dict())
+        outs = {k: torch.stack([dic[k] for dic in outs], dim=0) for k in outs[0]}
+        return outs
+
+    def shared_step(
+        self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
+    ):
+        next_td = self.env.reset(batch)
+        device = next_td.device
+        if phase == "train":
+            td_buffer = []
+            while not next_td["done"].all():
+
+                with torch.no_grad():
+                    td = self.policy_old.act(next_td, self.env, phase="train")
+
+                td_buffer.append(td)
+                # get next state
+                next_td = self.env.step(td)["next"]
+            # compute the episodic (sparse) reward from the terminal state
+            reward = self.env.get_reward(next_td, 1)
+            reward = self.scaler(reward)
+
+            # assign the episodic reward to every step collected in this episode
+            td_buffer = [td.set("reward", reward) for td in td_buffer]
+            # add tensordicts with action, logprob and reward information to the buffer
+            self.rb.extend(torch.cat(td_buffer, dim=0))
+
+            # update the policy every `update_timestep` batches (1 in the paper)
+            if batch_idx % self.ppo_cfg["update_timestep"] == 0:
+                out = self.update(device)
+                self.rb.empty()
+
+        else:
+            out = self.policy.generate(
+                next_td, self.env, phase=phase, select_best=phase != "train"
+            )
+
+        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
+        return {"loss": out.get("loss", None), **metrics}
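
Usage note: a minimal sketch of how the new module could be driven directly from Python instead of through the Hydra configs touched above. It assumes the scheduling setup referenced in this patch (FJSPEnv with L2DPolicy4PPO, trained via RL4COTrainer); the generator and policy arguments below are illustrative placeholders, not values taken from the repository.

from rl4co.envs import FJSPEnv
from rl4co.models import L2DPolicy4PPO, StepwisePPOwEpisodicReward
from rl4co.utils.trainer import RL4COTrainer

# Environment and policy mirror configs/experiment/scheduling/matnet-ppo.yaml;
# the generator parameters here are placeholders for illustration only.
env = FJSPEnv(generator_params={"num_jobs": 5, "num_machines": 5})
policy = L2DPolicy4PPO(env_name=env.name, embed_dim=128)

# The episodic-reward variant rolls out whole episodes with policy_old, assigns the
# terminal reward to every collected step, then runs clipped PPO updates on the buffer.
model = StepwisePPOwEpisodicReward(
    env,
    policy,
    batch_size=64,
    mini_batch_size=256,
    buffer_storage_device="cpu",  # LazyMemmapStorage keeps the replay buffer on CPU
)

trainer = RL4COTrainer(max_epochs=3, devices=1)
trainer.fit(model)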