diff --git a/configs/experiment/scheduling/matnet-ppo.yaml b/configs/experiment/scheduling/matnet-ppo.yaml
index cd07b802..3abb9c4c 100644
--- a/configs/experiment/scheduling/matnet-ppo.yaml
+++ b/configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,6 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
-
+  n_start: 8
 env:
   stepwise_reward: True
\ No newline at end of file
diff --git a/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
index cffe7462..5aff5496 100644
--- a/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
+++ b/rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
@@ -16,6 +16,7 @@
 from rl4co.envs.common.base import RL4COEnvBase
 from rl4co.models.rl.common.base import RL4COLitModule
 from rl4co.models.rl.common.utils import RewardScaler
+from rl4co.utils.ops import batchify, unbatchify
 from rl4co.utils.pylogger import get_pylogger
 
 log = get_pylogger(__name__)
@@ -42,6 +43,7 @@ def __init__(
         self,
         env: RL4COEnvBase,
         policy: nn.Module,
+        n_start: int = 0,
         clip_range: float = 0.2,  # epsilon of PPO
         update_timestep: int = 1,
         buffer_size: int = 100_000,
@@ -64,7 +66,7 @@
         self.automatic_optimization = False  # PPO uses custom optimization routine
         self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
         self.scaler = RewardScaler(reward_scale)
-
+        self.n_start = n_start
         self.ppo_cfg = {
             "clip_range": clip_range,
             "ppo_epochs": ppo_epochs,
@@ -87,8 +89,14 @@
 
                 logprobs, value_pred, entropy = self.policy.evaluate(sub_td)
                 ratios = torch.exp(logprobs - previous_logp)
-
-                advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
+                try:
+                    advantages = sub_td["advantage"]
+                    value_loss = 0
+
+                except KeyError:
+                    advantages = torch.squeeze(previous_reward - value_pred.detach(), 1)
+                    # compute value function loss
+                    value_loss = F.mse_loss(value_pred, previous_reward)
                 surr1 = ratios * advantages
                 surr2 = (
                     torch.clamp(
@@ -100,9 +108,6 @@
                 )
                 surrogate_loss = -torch.min(surr1, surr2).mean()
-
-                # compute value function loss
-                value_loss = F.mse_loss(value_pred, previous_reward)
 
                 # compute total loss
                 loss = (
                     surrogate_loss
@@ -141,9 +146,15 @@
     def shared_step(
         self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
     ):
+
         next_td = self.env.reset(batch)
         device = next_td.device
 
+        if phase == "train":
+
+            if self.n_start > 1:
+                next_td = batchify(next_td, self.n_start)
+
         td_buffer = []
 
         while not next_td["done"].all():
@@ -155,10 +166,16 @@
             next_td = self.env.step(td)["next"]
 
             # get reward of action
             reward = self.env.get_reward(next_td, 1)
-            reward = self.scaler(reward)
-
-            # add reward to prior state
-            td_buffer = [td.set("reward", reward) for td in td_buffer]
+            if self.n_start > 1:
+                reward_unbatched = unbatchify(reward, self.n_start)
+                advantage = reward - batchify(reward_unbatched.mean(-1), self.n_start)
+                advantage = self.scaler(advantage)
+                td_buffer = [td.set("advantage", advantage) for td in td_buffer]
+            else:
+                reward = self.scaler(reward)
+                td_buffer = [td.set("reward", reward) for td in td_buffer]
+
             # add tensordict with action, logprobs and reward information to buffer
             self.rb.extend(torch.cat(td_buffer, dim=0))
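Note (illustration, not part of the patch): when n_start > 1, the new branch in shared_step replaces the critic-based advantage with a shared multi-start baseline: each instance is rolled out n_start times, and a rollout's advantage is its reward minus the mean reward over the n_start rollouts of the same instance. The sketch below shows that computation in isolation; batchify_/unbatchify_ are local stand-ins that assume a [n_start * batch] tiling, which may differ from the exact layout used by rl4co.utils.ops.batchify/unbatchify. Because this baseline needs no value network, update() sets value_loss = 0 whenever an "advantage" key is found in the buffered tensordicts.

# Minimal sketch of the shared-baseline advantage under the assumptions above.
import torch


def batchify_(x: torch.Tensor, n: int) -> torch.Tensor:
    # [batch, ...] -> [n * batch, ...]: repeat every instance n times (rollout-major tiling)
    return x.unsqueeze(0).expand(n, *x.shape).reshape(n * x.shape[0], *x.shape[1:])


def unbatchify_(x: torch.Tensor, n: int) -> torch.Tensor:
    # [n * batch, ...] -> [batch, n, ...]: group the n rollouts belonging to each instance
    return x.reshape(n, -1, *x.shape[1:]).transpose(0, 1)


batch_size, n_start = 4, 8
reward = torch.randn(batch_size * n_start)  # one episodic reward per (instance, start) rollout

# baseline = mean reward of the n_start rollouts of the same instance
reward_per_instance = unbatchify_(reward, n_start)   # [batch, n_start]
baseline = reward_per_instance.mean(-1)              # [batch]
advantage = reward - batchify_(baseline, n_start)    # [n_start * batch]

# each instance's group of rollouts is centred around zero, so no critic is required
assert torch.allclose(
    unbatchify_(advantage, n_start).mean(-1), torch.zeros(batch_size), atol=1e-6
)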