Commit b93018b (parent: 4c12437). Showing 7 changed files with 524 additions and 0 deletions.
@@ -0,0 +1,225 @@
from typing import Any, Union

import torch
import torch.nn as nn
from lightning import LightningModule
from torch.utils.data import DataLoader

from rl4co.data.dataset import tensordict_collate_fn
from rl4co.data.generate_data import generate_default_datasets
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.lr_scheduler_helpers import create_scheduler
from rl4co.utils.optim_helpers import create_optimizer
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)

class RL4COLitModule(LightningModule):
    def __init__(
        self,
        env: RL4COEnvBase,
        policy: nn.Module,
        batch_size: int = 512,
        train_batch_size: int = None,  # batch size for training (overrides batch_size if set)
        val_batch_size: int = None,  # batch size for validation (overrides batch_size if set)
        test_batch_size: int = None,  # batch size for testing (overrides batch_size if set)
        train_dataset_size: int = 1_280_000,
        val_dataset_size: int = 10_000,
        test_dataset_size: int = 10_000,
        optimizer: Union[str, torch.optim.Optimizer] = "Adam",
        optimizer_kwargs: dict = {"lr": 1e-4},
        lr_scheduler: Union[str, torch.optim.lr_scheduler.LRScheduler] = "MultiStepLR",
        lr_scheduler_kwargs: dict = {
            "milestones": [80, 95],
            "gamma": 0.1,
        },
        lr_scheduler_interval: str = "epoch",
        lr_scheduler_monitor: str = "val/reward",
        generate_data: bool = True,
        shuffle_train_dataloader: bool = True,
        dataloader_num_workers: int = 0,
        data_dir: str = "data/",
        disable_profiling: bool = True,
        metrics: dict = {},
        **litmodule_kwargs,
    ):
        super().__init__(**litmodule_kwargs)

        if disable_profiling:
            # Disable profiling executor. This reduces memory and increases speed.
            # https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17
            try:
                torch._C._jit_set_profiling_executor(False)
                torch._C._jit_set_profiling_mode(False)
            except AttributeError:
                pass

        self.env = env
        self.policy = policy

        self.instantiate_metrics(metrics)

        self.data_config = {
            "batch_size": batch_size,
            "train_batch_size": train_batch_size,
            "val_batch_size": val_batch_size,
            "test_batch_size": test_batch_size,
            "generate_data": generate_data,
            "data_dir": data_dir,
            "train_dataset_size": train_dataset_size,
            "val_dataset_size": val_dataset_size,
            "test_dataset_size": test_dataset_size,
        }

        self._optimizer_name_or_cls: Union[str, torch.optim.Optimizer] = optimizer
        self.optimizer_kwargs: dict = optimizer_kwargs
        self._lr_scheduler_name_or_cls: Union[
            str, torch.optim.lr_scheduler.LRScheduler
        ] = lr_scheduler
        self.lr_scheduler_kwargs: dict = lr_scheduler_kwargs
        self.lr_scheduler_interval: str = lr_scheduler_interval
        self.lr_scheduler_monitor: str = lr_scheduler_monitor

        self.shuffle_train_dataloader = shuffle_train_dataloader
        self.dataloader_num_workers = dataloader_num_workers
        self.save_hyperparameters()

    def instantiate_metrics(self, metrics: dict):
        """Dictionary of metrics to be logged at each phase"""
        if not metrics:
            log.info("No metrics specified, using default")
        self.train_metrics = metrics.get("train", ["loss", "reward"])
        self.val_metrics = metrics.get("val", ["reward"])
        self.test_metrics = metrics.get("test", ["reward"])
        self.log_on_step = metrics.get("log_on_step", True)

    def setup(self, stage="fit"):
        log.info("Setting up batch sizes for train/val/test")

        batch_size = self.data_config["batch_size"]
        if self.data_config["train_batch_size"] is not None:
            train_batch_size = self.data_config["train_batch_size"]
            if batch_size is not None:
                log.warning(
                    f"`train_batch_size`={train_batch_size} specified, ignoring `batch_size`={batch_size}"
                )
        elif batch_size is not None:
            train_batch_size = batch_size
        else:
            train_batch_size = 64
            log.warning(f"No batch size specified, using default of {train_batch_size}")

        # Store batch sizes; val/test fall back to the train batch size when not given
        self.train_batch_size = train_batch_size
        self.val_batch_size = self.data_config["val_batch_size"] or train_batch_size
        self.test_batch_size = self.data_config["test_batch_size"] or train_batch_size

        log.info("Setting up datasets")

        # Create datasets automatically. If they are already found on disk, generation is skipped
        if self.data_config["generate_data"]:
            generate_default_datasets(data_dir=self.data_config["data_dir"])

        self.train_dataset = self.wrap_dataset(
            self.env.dataset(self.data_config["train_dataset_size"], phase="train")
        )
        self.val_dataset = self.env.dataset(self.data_config["val_dataset_size"], phase="val")
        self.test_dataset = self.env.dataset(self.data_config["test_dataset_size"], phase="test")

        if hasattr(self.policy, "setup"):
            self.policy.setup(self)
        self.post_setup_hook()

    def post_setup_hook(self):
        pass

    def configure_optimizers(self):
        """
        TODO: design a cleaner way to pass fully user-defined optimizers and schedulers
        """

        # Instantiate optimizer
        log.info(f"Instantiating optimizer <{self._optimizer_name_or_cls}>")
        if isinstance(self._optimizer_name_or_cls, str):
            optimizer = create_optimizer(
                self.policy, self._optimizer_name_or_cls, **self.optimizer_kwargs
            )
        else:  # user-defined optimizer class
            opt_cls = self._optimizer_name_or_cls
            assert issubclass(opt_cls, torch.optim.Optimizer)
            optimizer = opt_cls(self.policy.parameters(), **self.optimizer_kwargs)

        # Instantiate LR scheduler
        if self._lr_scheduler_name_or_cls is None:
            return optimizer
        else:
            log.info(f"Instantiating LR scheduler <{self._lr_scheduler_name_or_cls}>")
            if isinstance(self._lr_scheduler_name_or_cls, str):
                scheduler = create_scheduler(
                    optimizer, self._lr_scheduler_name_or_cls, **self.lr_scheduler_kwargs
                )
            else:  # user-defined scheduler class
                scheduler_cls = self._lr_scheduler_name_or_cls
                assert issubclass(scheduler_cls, torch.optim.lr_scheduler.LRScheduler)
                scheduler = scheduler_cls(optimizer, **self.lr_scheduler_kwargs)
            return [optimizer], [
                {
                    "scheduler": scheduler,
                    "interval": self.lr_scheduler_interval,
                    "monitor": self.lr_scheduler_monitor,
                }
            ]
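
    # Usage note (assumption, not part of this commit): instead of a name, a user-defined
    # optimizer/scheduler class can be passed directly and is instantiated above, e.g.
    #   RL4COLitModule(env, policy,
    #                  optimizer=torch.optim.AdamW,
    #                  optimizer_kwargs={"lr": 3e-4, "weight_decay": 1e-5},
    #                  lr_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
    #                  lr_scheduler_kwargs={"T_max": 100})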

    def log_metrics(self, metric_dict: dict, phase: str):
        metrics = getattr(self, f"{phase}_metrics")
        metrics = {f"{phase}/{k}": v.mean() for k, v in metric_dict.items() if k in metrics}

        log_on_step = self.log_on_step if phase == "train" else False
        on_epoch = phase != "train"
        self.log_dict(
            metrics,
            on_step=log_on_step,
            on_epoch=on_epoch,
            prog_bar=True,
            sync_dist=True,
            add_dataloader_idx=False,
        )
        return metrics

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        raise NotImplementedError("shared_step must be implemented in a subclass")

    def training_step(self, batch: Any, batch_idx: int):
        # To use new data every epoch, pass `reload_dataloaders_every_n_epochs=1` to the Trainer
        return self.shared_step(batch, batch_idx, phase="train")

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase="val")

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase="test")

    def train_dataloader(self):
        return self._dataloader(
            self.train_dataset, self.train_batch_size, self.shuffle_train_dataloader
        )

    def val_dataloader(self):
        return self._dataloader(self.val_dataset, self.val_batch_size)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset, self.test_batch_size)

    def on_train_epoch_end(self):
        if hasattr(self.policy, "on_train_epoch_end"):
            self.policy.on_train_epoch_end(self)
        # Regenerate the training dataset so that new instances are used every epoch
        train_dataset = self.env.dataset(self.data_config["train_dataset_size"], phase="train")
        self.train_dataset = self.wrap_dataset(train_dataset)

    def wrap_dataset(self, dataset):
        """Allow the policy to wrap the dataset (e.g., to attach extra values); no-op by default"""
        if hasattr(self.policy, "wrap_dataset"):
            dataset = self.policy.wrap_dataset(self, dataset)
        return dataset

    def _dataloader(self, dataset, batch_size, shuffle=False):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.dataloader_num_workers,
            collate_fn=tensordict_collate_fn,
        )
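For reference, a subclass only has to implement shared_step; data setup, optimizers, and logging are inherited. A minimal sketch, assuming the policy is callable as policy(td, phase, extra) and returns "reward" and "log_likelihood" as in the REINFORCE module below; the class name and the baseline-free loss are hypothetical, for illustration only:

from typing import Any

from rl4co.algo.base import RL4COLitModule  # import path assumed from the REINFORCE file below


class MyAlgorithm(RL4COLitModule):
    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(batch)  # reset the env with the sampled instances
        out = self.policy(td, phase, td.get("extra", None))  # rollout: reward + log_likelihood
        if phase == "train":
            # naive REINFORCE-style loss without a baseline, for illustration only
            out["loss"] = -(out["reward"] * out["log_likelihood"]).mean()
        metrics = self.log_metrics(out, phase)
        return {"loss": out.get("loss", None), **metrics}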
@@ -0,0 +1,67 @@
from typing import Any

from rl4co.algo.base import RL4COLitModule
from rl4co.utils.lightning import get_lightning_device


class REINFORCE(RL4COLitModule):
    def __init__(self, baseline=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.baseline = baseline

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(batch)
        extra = td.get("extra", None)

        # Perform forward pass (i.e., construct the solution and compute log-likelihoods)
        out: dict = self.policy(td, phase, extra)

        # Compute loss
        if phase == "train":
            bl_val, bl_neg_loss = (
                self.baseline.eval(td, "train", extra) if self.baseline is not None else (extra, 0)
            )

            advantage = out["reward"] - bl_val  # advantage = reward - baseline
            reinforce_loss = -(advantage * out["log_likelihood"]).mean()
            loss = reinforce_loss - bl_neg_loss
            out.update(
                {
                    "loss": loss,
                    "reinforce_loss": reinforce_loss,
                    "bl_loss": -bl_neg_loss,
                    "bl_val": bl_val,
                }
            )

        metrics = self.log_metrics(out, phase)
        return {"loss": out.get("loss", None), **metrics}

    def post_setup_hook(self):
        # Give the baseline access to the policy/env and this module's validation settings
        if self.baseline is not None:
            self.baseline.setup(
                self.policy,
                self.env,
                batch_size=self.val_batch_size,
                device=get_lightning_device(self),
                dataset_size=self.data_config["val_dataset_size"],
            )

    def on_train_epoch_end(self):
        if self.baseline is not None:
            self.baseline.epoch_callback(
                self.policy,
                env=self.env,
                batch_size=self.val_batch_size,
                device=get_lightning_device(self),
                epoch=self.current_epoch,
                dataset_size=self.data_config["val_dataset_size"],
            )
        super().on_train_epoch_end()  # regenerate the training dataset

    def wrap_dataset(self, dataset):
        """Wrap the training dataset for baseline evaluation"""
        if self.baseline is None:
            return dataset
        return self.baseline.wrap_dataset(
            dataset,
            self.env,
            batch_size=self.val_batch_size,
            device=get_lightning_device(self),
        )
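
A minimal end-to-end usage sketch. Assumptions: MyEnv and MyPolicy are hypothetical placeholders for a concrete RL4COEnvBase subclass and a policy network; the Trainer flag follows the note in training_step above:

from lightning import Trainer

env = MyEnv()        # hypothetical RL4COEnvBase subclass
policy = MyPolicy()  # hypothetical nn.Module policy

# a baseline object can be passed via `baseline=`; it defaults to None
model = REINFORCE(env=env, policy=policy, batch_size=256)

# reload dataloaders each epoch so the regenerated training dataset is picked up
trainer = Trainer(max_epochs=100, reload_dataloaders_every_n_epochs=1)
trainer.fit(model)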