[Feat] draft: added stepwise ppo with sparse reward
1 parent c694414 · commit f3808c8
Showing 5 changed files with 180 additions and 3 deletions.
@@ -0,0 +1,176 @@
import copy
from typing import Any, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchrl.data.replay_buffers import (
    LazyMemmapStorage,
    ListStorage,
    SamplerWithoutReplacement,
    TensorDictReplayBuffer,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.common.utils import RewardScaler
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


def make_replay_buffer(buffer_size, batch_size, device="cpu"):
    """Build a TensorDictReplayBuffer whose storage lives on the given device."""
    if device == "cpu":
        # memory-mapped storage holds large buffers on disk rather than in RAM
        storage = LazyMemmapStorage(buffer_size, device="cpu")
        prefetch = 3
    else:
        storage = ListStorage(buffer_size)
        prefetch = None
    return TensorDictReplayBuffer(
        storage=storage,
        batch_size=batch_size,
        sampler=SamplerWithoutReplacement(drop_last=True),
        pin_memory=False,
        prefetch=prefetch,
    )
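
# Note: iterating over the returned buffer (as `update` does below) draws
# mini-batches of `batch_size` transitions without replacement until the buffer
# is exhausted; `drop_last=True` discards a final incomplete mini-batch.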


class StepwisePPOwEpisodicReward(RL4COLitModule):
    """Stepwise PPO trained with a sparse, episodic reward: every step of a
    rollout is stored with the same final reward of its instance. The policy
    must implement `act(td, env, phase)` for rollouts and `evaluate(td)`
    returning (logprobs, value_pred, entropy).
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: nn.Module,
        clip_range: float = 0.2,  # epsilon of PPO
        update_timestep: int = 1,
        buffer_size: int = 100_000,
        ppo_epochs: int = 2,  # inner epochs, K
        batch_size: int = 256,
        mini_batch_size: int = 256,
        vf_lambda: float = 0.5,  # lambda of the value function loss
        entropy_lambda: float = 0.01,  # lambda of the entropy bonus
        max_grad_norm: float = 0.5,  # max gradient norm
        buffer_storage_device: str = "gpu",
        metrics: dict = {
            "train": ["loss", "surrogate_loss", "value_loss", "entropy"],
        },
        reward_scale: Union[str, int, None] = None,
        **kwargs,
    ):
        super().__init__(env, policy, metrics=metrics, batch_size=batch_size, **kwargs)

        self.policy_old = copy.deepcopy(self.policy)
        self.automatic_optimization = False  # PPO uses a custom optimization routine
        self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
        self.scaler = RewardScaler(reward_scale)

        self.ppo_cfg = {
            "clip_range": clip_range,
            "ppo_epochs": ppo_epochs,
            "update_timestep": update_timestep,
            "mini_batch_size": mini_batch_size,
            "vf_lambda": vf_lambda,
            "entropy_lambda": entropy_lambda,
            "max_grad_norm": max_grad_norm,
        }

    def update(self, device):
        outs = []
        # PPO inner epochs over the collected rollouts (K in the paper)
        for _ in range(self.ppo_cfg["ppo_epochs"]):
            for sub_td in self.rb:
                sub_td = sub_td.to(device)
                previous_reward = sub_td["reward"].view(-1, 1)
                previous_logp = sub_td["logprobs"]

                # re-evaluate the stored actions under the current policy
                logprobs, value_pred, entropy = self.policy.evaluate(sub_td)

                # importance ratio between the current and the old policy
                ratios = torch.exp(logprobs - previous_logp)

                # advantage: episodic reward minus the critic's value estimate
                advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
                surr1 = ratios * advantages
                surr2 = (
                    torch.clamp(
                        ratios,
                        1 - self.ppo_cfg["clip_range"],
                        1 + self.ppo_cfg["clip_range"],
                    )
                    * advantages
                )
                surrogate_loss = -torch.min(surr1, surr2).mean()
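
                # Worked example of the clipping above (clip_range = 0.2):
                #   ratio 1.5, advantage -1.0: surr1 = -1.5, surr2 = 1.2 * -1.0 = -1.2
                #   min(-1.5, -1.2) = -1.5, so the pessimistic term survives the min
                #   and the objective never rewards pushing the ratio further off-policy.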

                # compute value function loss
                value_loss = F.mse_loss(value_pred, previous_reward)

                # compute total loss
                loss = (
                    surrogate_loss
                    + self.ppo_cfg["vf_lambda"] * value_loss
                    - self.ppo_cfg["entropy_lambda"] * entropy.mean()
                )

                # perform manual optimization following the Lightning routine
                # https://lightning.ai/docs/pytorch/stable/common/optimization.html
                opt = self.optimizers()
                opt.zero_grad()
                self.manual_backward(loss)
                if self.ppo_cfg["max_grad_norm"] is not None:
                    self.clip_gradients(
                        opt,
                        gradient_clip_val=self.ppo_cfg["max_grad_norm"],
                        gradient_clip_algorithm="norm",
                    )
                opt.step()

                out = {
                    "reward": previous_reward.mean(),
                    "loss": loss,
                    "surrogate_loss": surrogate_loss,
                    "value_loss": value_loss,
                    "entropy": entropy.mean(),
                }
                outs.append(out)

        # copy the new weights into the old (rollout) policy
        self.policy_old.load_state_dict(self.policy.state_dict())
        outs = {k: torch.stack([dic[k] for dic in outs], dim=0) for k in outs[0]}
        return outs

    def shared_step(
        self, batch: Any, batch_idx: int, phase: str, dataloader_idx: Union[int, None] = None
    ):
        next_td = self.env.reset(batch)
        device = next_td.device
        if phase == "train":
            td_buffer = []
            # roll out the old policy until all instances are done
            while not next_td["done"].all():
                with torch.no_grad():
                    td = self.policy_old.act(next_td, self.env, phase="train")
                td_buffer.append(td)
                # get next state
                next_td = self.env.step(td)["next"]

            # get the sparse reward of the finished rollout
            reward = self.env.get_reward(next_td, 1)
            reward = self.scaler(reward)

            # broadcast the episodic reward to every stored step of the rollout
            td_buffer = [td.set("reward", reward) for td in td_buffer]
            # add tensordicts with action, logprobs and reward information to the buffer
            self.rb.extend(torch.cat(td_buffer, dim=0))
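            # Note: torch.cat flattens the per-step TensorDicts into one batch of
            # transitions, so the buffer mixes steps from all instances of the batch.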

            out = {}  # stays empty on batches that skip the update below
            # if batch_idx mod x == 0 then update the policy (x = 1 in the paper)
            if batch_idx % self.ppo_cfg["update_timestep"] == 0:
                out = self.update(device)
                self.rb.empty()

        else:
            out = self.policy.generate(
                next_td, self.env, phase=phase, select_best=phase != "train"
            )

        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
        return {"loss": out.get("loss", None), **metrics}
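
For reference, a minimal training sketch with the new module is below. TSPEnv and RL4COTrainer are existing rl4co components; MyStepwisePolicy is a hypothetical stand-in for any nn.Module that implements the act(td, env, phase) and evaluate(td) interface the class expects, and the hyperparameters are illustrative, not prescribed by this commit.

from rl4co.envs import TSPEnv
from rl4co.utils.trainer import RL4COTrainer

env = TSPEnv(generator_params={"num_loc": 20})
policy = MyStepwisePolicy()  # hypothetical: must provide act() and evaluate()

model = StepwisePPOwEpisodicReward(
    env,
    policy,
    mini_batch_size=128,
    buffer_storage_device="gpu",
)

# Lightning disallows Trainer-side gradient clipping with manual optimization,
# so clipping stays inside update() and gradient_clip_val is explicitly unset.
trainer = RL4COTrainer(max_epochs=10, gradient_clip_val=None)
trainer.fit(model)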