[Feat] draft: added stepwise ppo with sparse reward and shared baseline
LTluttmann committed Jul 11, 2024
1 parent f3808c8 commit ffeffba
Showing 2 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,6 @@ model:
val_batch_size: 512
test_batch_size: 64
mini_batch_size: 512

n_start: 8
env:
stepwise_reward: True
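The experiment config now enables the environment's per-step reward interface (`stepwise_reward: True`) and requests eight rollouts per instance (`n_start: 8`) inside the `model:` block, which is what feeds the shared baseline added in the module below. A minimal OmegaConf sketch of just the keys shown in this hunk (the rest of the Hydra config tree is omitted):

```python
from omegaconf import OmegaConf

# Only the keys visible in the hunk above; everything else in
# matnet-ppo.yaml is left out of this sketch.
cfg = OmegaConf.create(
    {
        "model": {"mini_batch_size": 512, "n_start": 8},
        "env": {"stepwise_reward": True},
    }
)
print(OmegaConf.to_yaml(cfg))
```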
35 changes: 26 additions & 9 deletions rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
@@ -16,6 +16,7 @@
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.common.utils import RewardScaler
from rl4co.utils.ops import batchify, unbatchify
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
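The new `batchify` / `unbatchify` import is what connects the `n_start` rollouts back to their instances further down: `batchify` repeats every instance `n` times along the batch dimension, and `unbatchify` recovers a `(batch, n)` view. A plain-PyTorch sketch of that correspondence (the exact ordering of the copies, tiled or interleaved, is an implementation detail of rl4co's ops and may differ from this sketch):

```python
import torch

batch_size, n_start = 2, 3
x = torch.tensor([10.0, 20.0])               # one value per instance, shape (2,)

# batchify-style expansion: each instance appears n_start times in the flat batch.
x_rep = x.repeat_interleave(n_start)         # shape (6,)

# unbatchify-style view: recover the (batch_size, n_start) grouping.
x_grouped = x_rep.view(batch_size, n_start)  # shape (2, 3)

assert torch.equal(x_grouped.mean(-1), x)    # per-instance statistics are easy again
```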
@@ -42,6 +43,7 @@ def __init__(
self,
env: RL4COEnvBase,
policy: nn.Module,
n_start: int = 0,
clip_range: float = 0.2, # epsilon of PPO
update_timestep: int = 1,
buffer_size: int = 100_000,
@@ -64,7 +66,7 @@ def __init__(
self.automatic_optimization = False # PPO uses custom optimization routine
self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
self.scaler = RewardScaler(reward_scale)

self.n_start = n_start
self.ppo_cfg = {
"clip_range": clip_range,
"ppo_epochs": ppo_epochs,
@@ -87,8 +89,14 @@ def update(self, device):
logprobs, value_pred, entropy = self.policy.evaluate(sub_td)

ratios = torch.exp(logprobs - previous_logp)

advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
try:
advantages = sub_td["advantage"]
value_loss = 0

except KeyError:
advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
# compute value function loss
value_loss = F.mse_loss(value_pred, previous_reward)
surr1 = ratios * advantages
surr2 = (
torch.clamp(
@@ -100,9 +108,6 @@ def update(self, device):
)
surrogate_loss = -torch.min(surr1, surr2).mean()

# compute value function loss
value_loss = F.mse_loss(value_pred, previous_reward)

# compute total loss
loss = (
surrogate_loss
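With this change, `update()` has two advantage paths: when a buffer entry already carries an `advantage` (the sparse-reward, shared-baseline case), the critic is skipped and the value loss is zero; otherwise the advantage falls back to reward minus the detached value prediction and the MSE value loss is kept. The clipping around it is standard PPO. A condensed, self-contained sketch of that logic (tensor and coefficient names such as `vf_lambda` are illustrative, and the module's entropy bonus is left out):

```python
import torch
import torch.nn.functional as F

def ppo_loss(logprobs, old_logprobs, reward, value_pred=None, advantage=None,
             clip_range=0.2, vf_lambda=0.5):
    """Clipped PPO surrogate with an optional critic term.

    If `advantage` is given (e.g. the shared multi-start baseline), no value
    function is trained; otherwise a critic prediction `value_pred` is required.
    """
    ratios = torch.exp(logprobs - old_logprobs)

    if advantage is not None:
        value_loss = torch.tensor(0.0)
    else:
        advantage = (reward - value_pred.detach()).squeeze(-1)
        value_loss = F.mse_loss(value_pred, reward)

    surr1 = ratios * advantage
    surr2 = torch.clamp(ratios, 1 - clip_range, 1 + clip_range) * advantage
    surrogate_loss = -torch.min(surr1, surr2).mean()

    return surrogate_loss + vf_lambda * value_loss
```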
@@ -141,9 +146,15 @@ def update(self, device):
def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):

next_td = self.env.reset(batch)
device = next_td.device

if phase == "train":

if self.n_start > 1:
next_td = batchify(next_td, self.n_start)

td_buffer = []
while not next_td["done"].all():

@@ -155,10 +166,16 @@
next_td = self.env.step(td)["next"]
# get reward of action
reward = self.env.get_reward(next_td, 1)
reward = self.scaler(reward)

# add reward to prior state
td_buffer = [td.set("reward", reward) for td in td_buffer]
if self.n_start > 1:
reward_unbatched = unbatchify(reward, self.n_start)
advantage = reward - batchify(reward_unbatched.mean(-1), self.n_start)
advantage = self.scaler(advantage)
td_buffer = [td.set("advantage", advantage) for td in td_buffer]
else:
reward = self.scaler(reward)
td_buffer = [td.set("reward", reward) for td in td_buffer]

# add tensordict with action, logprobs and reward information to buffer
self.rb.extend(torch.cat(td_buffer, dim=0))
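The training branch above is the core of the shared baseline: with `n_start > 1` every instance is solved `n_start` times, and each rollout's episodic reward is centred by the mean reward of that instance's own rollouts (akin to the POMO shared baseline), so no critic is needed in the sparse-reward case. A small numeric check of that computation in plain PyTorch (grouping is done with `view` here; rl4co's `batchify`/`unbatchify` may order the copies differently):

```python
import torch

batch_size, n_start = 2, 4

# Episodic reward of each rollout over the expanded batch; values are made up.
reward = torch.tensor([3.0, 5.0, 4.0, 8.0,   # rollouts of instance 0
                       1.0, 1.0, 2.0, 0.0])  # rollouts of instance 1

reward_grouped = reward.view(batch_size, n_start)      # unbatchify-style view
baseline = reward_grouped.mean(dim=-1, keepdim=True)   # shared per-instance baseline
advantage = (reward_grouped - baseline).view(-1)       # back to the flat layout

print(advantage)                                     # [-2., 0., -1., 3., 0., 0., 1., -1.]
print(advantage.view(batch_size, n_start).sum(-1))   # zeros: centred per instance
```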

