From b93018b72586a4c190f869b937dc21dc5219bb4d Mon Sep 17 00:00:00 2001
From: junyoungpark
Date: Wed, 12 Jul 2023 23:45:09 +0900
Subject: [PATCH] Refactor algorithm base and decoder

---
 rl4co/algo/__init__.py              |   0
 rl4co/algo/base.py                  | 225 ++++++++++++++++++++++++++++
 rl4co/algo/reinforce.py             |  67 +++++++++
 rl4co/envs/__init__.py              |  15 ++
 rl4co/models/nn/decoder.py          | 183 ++++++++++++++++++++++
 rl4co/utils/lr_scheduler_helpers.py |  14 ++
 rl4co/utils/optim_helpers.py        |  20 +++
 7 files changed, 524 insertions(+)
 create mode 100644 rl4co/algo/__init__.py
 create mode 100644 rl4co/algo/base.py
 create mode 100644 rl4co/algo/reinforce.py
 create mode 100644 rl4co/models/nn/decoder.py
 create mode 100644 rl4co/utils/lr_scheduler_helpers.py
 create mode 100644 rl4co/utils/optim_helpers.py

diff --git a/rl4co/algo/__init__.py b/rl4co/algo/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/rl4co/algo/base.py b/rl4co/algo/base.py
new file mode 100644
index 00000000..b1fe2c20
--- /dev/null
+++ b/rl4co/algo/base.py
@@ -0,0 +1,225 @@
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from rl4co.data.dataset import tensordict_collate_fn
+from rl4co.data.generate_data import generate_default_datasets
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.utils.lr_scheduler_helpers import create_scheduler
+from rl4co.utils.optim_helpers import create_optimizer
+from rl4co.utils.pylogger import get_pylogger
+
+log = get_pylogger(__name__)
+
+
+class RL4COLitModule(LightningModule):
+    def __init__(
+        self,
+        env: RL4COEnvBase,
+        policy: nn.Module,
+        batch_size: int = 512,
+        train_batch_size: int = None,  # batch size for training; overrides `batch_size` if set
+        val_batch_size: int = None,  # batch size for validation; overrides `batch_size` if set
+        test_batch_size: int = None,  # batch size for testing; overrides `batch_size` if set
+        train_dataset_size: int = 1_280_000,
+        val_dataset_size: int = 10_000,
+        test_dataset_size: int = 10_000,
+        optimizer: Union[str, torch.optim.Optimizer] = "Adam",
+        optimizer_kwargs: dict = {"lr": 1e-4},
+        lr_scheduler: Union[str, torch.optim.lr_scheduler.LRScheduler] = "MultiStepLR",
+        lr_scheduler_kwargs: dict = {
+            "milestones": [80, 95],
+            "gamma": 0.1,
+        },
+        lr_scheduler_interval: str = "epoch",
+        lr_scheduler_monitor: str = "val/reward",
+        generate_data: bool = True,
+        shuffle_train_dataloader: bool = True,
+        dataloader_num_workers: int = 0,
+        data_dir: str = "data/",
+        disable_profiling: bool = True,
+        metrics: dict = {},
+        **litmodule_kwargs,
+    ):
+        super().__init__(**litmodule_kwargs)
+
+        if disable_profiling:
+            # Disable profiling executor. This reduces memory and increases speed.
+            # https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17
+            try:
+                torch._C._jit_set_profiling_executor(False)
+                torch._C._jit_set_profiling_mode(False)
+            except AttributeError:
+                pass
+
+        self.env = env
+        self.policy = policy
+
+        self.instantiate_metrics(metrics)
+
+        self.data_config = {
+            "batch_size": batch_size,
+            "train_batch_size": train_batch_size,
+            "val_batch_size": val_batch_size,
+            "test_batch_size": test_batch_size,
+            "generate_data": generate_data,
+            "data_dir": data_dir,
+            "train_dataset_size": train_dataset_size,
+            "val_dataset_size": val_dataset_size,
+            "test_dataset_size": test_dataset_size,
+        }
+
+        self._optimizer_name_or_cls: Union[str, torch.optim.Optimizer] = optimizer
+        self.optimizer_kwargs: dict = optimizer_kwargs
+        self._lr_scheduler_name_or_cls: Union[
+            str, torch.optim.lr_scheduler.LRScheduler
+        ] = lr_scheduler
+        self.lr_scheduler_kwargs: dict = lr_scheduler_kwargs
+        self.lr_scheduler_interval: str = lr_scheduler_interval
+        self.lr_scheduler_monitor: str = lr_scheduler_monitor
+
+        self.shuffle_train_dataloader = shuffle_train_dataloader
+        self.dataloader_num_workers = dataloader_num_workers
+        self.save_hyperparameters()
+
+    def instantiate_metrics(self, metrics: dict):
+        """Dictionary of metrics to be logged at each phase"""
+        if not metrics:
+            log.info("No metrics specified, using default")
+        self.train_metrics = metrics.get("train", ["loss", "reward"])
+        self.val_metrics = metrics.get("val", ["reward"])
+        self.test_metrics = metrics.get("test", ["reward"])
+        self.log_on_step = metrics.get("log_on_step", True)
+
+    def setup(self, stage="fit"):
+        log.info("Setting up batch sizes for train/val/test")
+
+        batch_size = self.data_config["batch_size"]
+        if self.data_config["train_batch_size"] is not None:
+            train_batch_size = self.data_config["train_batch_size"]
+            if batch_size is not None:
+                log.warning(
+                    f"`train_batch_size`={train_batch_size} specified, ignoring `batch_size`={batch_size}"
+                )
+        elif batch_size is not None:
+            train_batch_size = batch_size
+        else:
+            train_batch_size = 64
+            log.warning(f"No batch size specified, using default of {train_batch_size}")
+
+        # Store the batch sizes used by the dataloaders below;
+        # val/test fall back to their own setting, then to `batch_size`, then to the train batch size
+        self.train_batch_size = train_batch_size
+        self.val_batch_size = self.data_config["val_batch_size"] or batch_size or train_batch_size
+        self.test_batch_size = self.data_config["test_batch_size"] or self.val_batch_size
+
+        log.info("Setting up datasets")
+
+        # Create datasets automatically; if the data files are already found, generation is skipped
+        if self.data_config["generate_data"]:
+            generate_default_datasets(data_dir=self.data_config["data_dir"])
+
+        self.train_dataset = self.wrap_dataset(
+            self.env.dataset(self.data_config["train_dataset_size"], phase="train")
+        )
+        self.val_dataset = self.env.dataset(self.data_config["val_dataset_size"], phase="val")
+        self.test_dataset = self.env.dataset(self.data_config["test_dataset_size"], phase="test")
+
+        if hasattr(self.policy, "setup"):
+            self.policy.setup(self)
+        self.post_setup_hook()
+
+    def post_setup_hook(self):
+        pass
+
+    def configure_optimizers(self):
+        """
+        TODO: design a cleaner interface for passing user-defined optimizers and schedulers
+        """
+
+        # Instantiate optimizer
+        log.info(f"Instantiating optimizer <{self._optimizer_name_or_cls}>")
+        if isinstance(self._optimizer_name_or_cls, str):
+            optimizer = create_optimizer(
+                self.policy, self._optimizer_name_or_cls, **self.optimizer_kwargs
+            )
+        else:  # user-defined optimizer class
+            opt_cls = self._optimizer_name_or_cls
+            assert issubclass(opt_cls, torch.optim.Optimizer)
+            optimizer = opt_cls(self.policy.parameters(), **self.optimizer_kwargs)
+
+        # Instantiate LR scheduler
+        if self._lr_scheduler_name_or_cls is None:
+            return optimizer
+        else:
+            log.info(f"Instantiating LR scheduler <{self._lr_scheduler_name_or_cls}>")
+            if isinstance(self._lr_scheduler_name_or_cls, str):
+                scheduler = create_scheduler(
+                    optimizer, self._lr_scheduler_name_or_cls, **self.lr_scheduler_kwargs
+                )
+            else:  # user-defined scheduler class
+                scheduler_cls = self._lr_scheduler_name_or_cls
+                assert issubclass(scheduler_cls, torch.optim.lr_scheduler.LRScheduler)
+                scheduler = scheduler_cls(optimizer, **self.lr_scheduler_kwargs)
+            return [optimizer], {
+                "scheduler": scheduler,
+                "interval": self.lr_scheduler_interval,
+                "monitor": self.lr_scheduler_monitor,
+            }
+
+    def log_metrics(self, metric_dict: dict, phase: str):
+        metrics = getattr(self, f"{phase}_metrics")
+        metrics = {f"{phase}/{k}": v.mean() for k, v in metric_dict.items() if k in metrics}
+
+        log_on_step = self.log_on_step if phase == "train" else False
+        on_epoch = False if phase == "train" else True
+        self.log_dict(
+            metrics,
+            on_step=log_on_step,
+            on_epoch=on_epoch,
+            prog_bar=True,
+            sync_dist=True,
+            add_dataloader_idx=False,
+        )
+        return metrics
+
+    def shared_step(self, batch: Any, batch_idx: int, phase: str):
+        raise NotImplementedError("shared_step must be implemented in a subclass")
+
+    def training_step(self, batch: Any, batch_idx: int):
+        # To use new data every epoch, set reload_dataloaders_every_n_epochs=1 in the Trainer
+        return self.shared_step(batch, batch_idx, phase="train")
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        return self.shared_step(batch, batch_idx, phase="val")
+
+    def test_step(self, batch: Any, batch_idx: int):
+        return self.shared_step(batch, batch_idx, phase="test")
+
+    def train_dataloader(self):
+        return self._dataloader(
+            self.train_dataset, self.train_batch_size, self.shuffle_train_dataloader
+        )
+
+    def val_dataloader(self):
+        return self._dataloader(self.val_dataset, self.val_batch_size)
+
+    def test_dataloader(self):
+        return self._dataloader(self.test_dataset, self.test_batch_size)
+
+    def on_train_epoch_end(self):
+        if hasattr(self.policy, "on_train_epoch_end"):
+            self.policy.on_train_epoch_end(self)
+        # Regenerate the training dataset so that each epoch sees new instances
+        train_dataset = self.env.dataset(self.data_config["train_dataset_size"], "train")
+        self.train_dataset = self.wrap_dataset(train_dataset)
+
+    def wrap_dataset(self, dataset):
+        if hasattr(self.policy, "wrap_dataset") and not getattr(self, "disable_wrap_dataset", False):
+            dataset = self.policy.wrap_dataset(self, dataset)
+        return dataset
+
+    def _dataloader(self, dataset, batch_size, shuffle=False):
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=self.dataloader_num_workers,
+            collate_fn=tensordict_collate_fn,
+        )
diff --git a/rl4co/algo/reinforce.py b/rl4co/algo/reinforce.py
new file mode 100644
index 00000000..c57d376f
--- /dev/null
+++ b/rl4co/algo/reinforce.py
@@ -0,0 +1,67 @@
+from typing import Any
+
+from rl4co.algo.base import RL4COLitModule
+from rl4co.utils.lightning import get_lightning_device
+
+
+class REINFORCE(RL4COLitModule):
+    def __init__(self, baseline=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.baseline = baseline
+
+    def shared_step(self, batch: Any, batch_idx: int, phase: str):
+        td = self.env.reset(batch)
+        extra = td.get("extra", None)
+
+        # Perform forward pass (i.e., constructing solution and computing log-likelihoods)
+        out: dict = self.policy(td, phase, extra)
+
+        # Compute loss
+        if phase == "train":
+            bl_val, bl_neg_loss = (
+                self.baseline.eval(td, "train", extra) if self.baseline is not None else (extra, 0)
+            )
+
+            advantage = out["reward"] - bl_val  # advantage = reward - baseline
+            reinforce_loss = -(advantage * out["log_likelihood"]).mean()
+            loss = reinforce_loss - bl_neg_loss
+            out.update(
+                {
+                    "loss": loss,
+                    "reinforce_loss": reinforce_loss,
+                    "bl_loss": -bl_neg_loss,
+                    "bl_val": bl_val,
+                }
+            )
+
+        metrics = self.log_metrics(out, phase)
+        return {"loss": out.get("loss", None), **metrics}
+
+    def post_setup_hook(self):
+        # Give the baseline access to the policy, the environment, and the validation settings
+        if self.baseline is not None:
+            self.baseline.setup(
+                self.policy,
+                self.env,
+                batch_size=self.val_batch_size,
+                device=get_lightning_device(self),
+                dataset_size=self.data_config["val_dataset_size"],
+            )
+
+    def on_train_epoch_end(self):
+        if self.baseline is not None:
+            self.baseline.epoch_callback(
+                self.policy,
+                env=self.env,
+                batch_size=self.val_batch_size,
+                device=get_lightning_device(self),
+                epoch=self.current_epoch,
+                dataset_size=self.data_config["val_dataset_size"],
+            )
+        # Let the base class regenerate the training dataset
+        super().on_train_epoch_end()
+
+    def wrap_dataset(self, dataset):
+        """Wrap dataset for baseline evaluation"""
+        if self.baseline is None:
+            return dataset
+        return self.baseline.wrap_dataset(
+            dataset,
+            self.env,
+            batch_size=self.val_batch_size,
+            device=get_lightning_device(self),
+        )
diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py
index 455adde3..b06874b5 100644
--- a/rl4co/envs/__init__.py
+++ b/rl4co/envs/__init__.py
@@ -12,3 +12,18 @@
 from rl4co.envs.sdvrp import SDVRPEnv
 from rl4co.envs.spctsp import SPCTSPEnv
 from rl4co.envs.tsp import TSPEnv
+
+# Register environments
+ENV_REGISTRY = {
+    "atsp": ATSPEnv,
+    "cvrp": CVRPEnv,
+    "dpp": DPPEnv,
+    "mdpp": MDPPEnv,
+    "mtsp": MTSPEnv,
+    "op": OPEnv,
+    "pctsp": PCTSPEnv,
+    "pdp": PDPEnv,
+    "sdvrp": SDVRPEnv,
+    "spctsp": SPCTSPEnv,
+    "tsp": TSPEnv,
+}
diff --git a/rl4co/models/nn/decoder.py b/rl4co/models/nn/decoder.py
new file mode 100644
index 00000000..e850338f
--- /dev/null
+++ b/rl4co/models/nn/decoder.py
@@ -0,0 +1,183 @@
+from dataclasses import dataclass
+from typing import Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from tensordict import TensorDict
+
+from rl4co.envs import ENV_REGISTRY
+from rl4co.envs.common.base import RL4COEnvBase
+from rl4co.models.nn.attention import LogitAttention
+from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
+from rl4co.models.nn.utils import decode_probs
+from rl4co.utils.ops import batchify, select_start_nodes, unbatchify
+
+
+@dataclass
+class PrecomputedCache:
+    node_embeddings: torch.Tensor
+    graph_context: torch.Tensor
+    glimpse_key: torch.Tensor
+    glimpse_val: torch.Tensor
+    logit_key: torch.Tensor
+
+
+class AutoregressiveDecoder(nn.Module):
+    """Auto-regressive decoder for the Attention Model that constructs solutions one node at a time.
+    We additionally support greedy multi-start decoding during inference (as in POMO).
+
+    Args:
+        env_name: Name of the environment to solve
+        embedding_dim: Dimension of the embeddings
+        num_heads: Number of attention heads
+    """
+
+    def __init__(self, env_name: str, embedding_dim: int, num_heads: int, **logit_attn_kwargs):
+        super().__init__()
+
+        self.env_name = env_name
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+
+        assert embedding_dim % num_heads == 0
+
+        self.context = env_context_embedding(self.env_name, {"embedding_dim": embedding_dim})
+        self.dynamic_embedding = env_dynamic_embedding(
+            self.env_name, {"embedding_dim": embedding_dim}
+        )
+
+        # For each node we compute (glimpse key, glimpse value, logit key), hence 3 * embedding_dim
+        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
+        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
+        # MHA
+        self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)
+
+    def forward(
+        self,
+        td: TensorDict,
+        embeddings: torch.Tensor,
+        decode_type: str = "sampling",
+        softmax_temp: float = None,
+        num_starts: int = None,
+        calc_reward: bool = True,
+        env: Union[str, RL4COEnvBase] = None,
+    ):
+        # Greedy multi-start decoding if num_starts > 1
+        num_starts = 0 if num_starts is None else num_starts
+        assert not (
+            "multistart" in decode_type and num_starts <= 1
+        ), "Multi-start decoding requires `num_starts` > 1"
+
+        # Compute keys and values for the glimpse and keys for the logits once, since they are reused at every step
+        cached_embeds = self._precompute(embeddings, num_starts=num_starts)
+
+        # Collect outputs
+        outputs = []
+        actions = []
+
+        # Construct the environment if needed; default to this decoder's env_name
+        if env is None or isinstance(env, str):
+            env_name = self.env_name if env is None else env
+            env = ENV_REGISTRY[env_name]()
+        elif not isinstance(env, RL4COEnvBase):
+            raise ValueError(f"env must be either str or RL4COEnvBase, got {type(env)}")
+
+        # Multi-start decoding: first action is chosen by ad-hoc node selection
+        if num_starts > 1 or "multistart" in decode_type:
+            action = select_start_nodes(td, num_starts, env)
+
+            # Expand td to batch_size * num_starts
+            td = batchify(td, num_starts)
+
+            td.set("action", action)
+            td = env.step(td)["next"]
+            log_p = torch.zeros_like(
+                td["action_mask"], device=td.device
+            )  # first log_p is 0, so p = log_p.exp() = 1
+
+            outputs.append(log_p)
+            actions.append(action)
+
+        # Main decoding loop
+        while not td["done"].all():
+            log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)
+
+            # Select the indices of the next nodes in the sequences, result (batch_size) long
+            action = decode_probs(log_p.exp(), mask, decode_type=decode_type)
+
+            td.set("action", action)
+            td = env.step(td)["next"]
+
+            # Collect output of step
+            outputs.append(log_p)
+            actions.append(action)
+
+        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
+        if calc_reward:
+            td.set("reward", env.get_reward(td, actions))
+
+        return outputs, actions, td
+
+    def _precompute(self, embeddings, num_starts=0):
+        # The projection of the node embeddings for the attention is calculated once up front
+        (
+            glimpse_key_fixed,
+            glimpse_val_fixed,
+            logit_key_fixed,
+        ) = self.project_node_embeddings(
+            embeddings
+        ).chunk(3, dim=-1)
+
+        # Batchify and unbatchify have no effect if num_starts = 0.
+        # Otherwise, we need to batchify the embeddings to match the key/value shapes (i.e. the length of the queries)
+        graph_context = unbatchify(
+            batchify(self.project_fixed_context(embeddings.mean(1)), num_starts),
+            num_starts,
+        )
+
+        # Organize in a dataclass for easy access
+        cached_embeds = PrecomputedCache(
+            node_embeddings=embeddings,
+            graph_context=graph_context,
+            glimpse_key=glimpse_key_fixed,
+            glimpse_val=glimpse_val_fixed,
+            logit_key=logit_key_fixed,
+        )
+
+        return cached_embeds
+
+    def _get_log_p(self, cached, td, softmax_temp=None, num_starts=0):
+        # Compute the query based on the context (automatically computes the first- and last-node context)
+
+        # Unbatchify to [batch_size, num_starts, ...]. Has no effect if num_starts = 0
+        td_unbatch = unbatchify(td, num_starts)
+
+        step_context = self.context(cached.node_embeddings, td_unbatch)
+        glimpse_q = step_context + cached.graph_context
+        glimpse_q = glimpse_q.unsqueeze(1) if glimpse_q.ndim == 2 else glimpse_q
+
+        # Compute keys and values for the nodes
+        (
+            glimpse_key_dynamic,
+            glimpse_val_dynamic,
+            logit_key_dynamic,
+        ) = self.dynamic_embedding(td_unbatch)
+        glimpse_k = cached.glimpse_key + glimpse_key_dynamic
+        glimpse_v = cached.glimpse_val + glimpse_val_dynamic
+        logit_k = cached.logit_key + logit_key_dynamic
+
+        # Get the mask
+        mask = ~td_unbatch["action_mask"]
+
+        # Compute logits
+        log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)
+
+        # Reshape log_p and mask back to [batch_size * num_starts, num_nodes]
+        # Note that the rearranging order is important here
+        log_p = rearrange(log_p, "b s l -> (s b) l") if num_starts > 1 else log_p
+        mask = rearrange(mask, "b s l -> (s b) l") if num_starts > 1 else mask
+        return log_p, mask
diff --git a/rl4co/utils/lr_scheduler_helpers.py b/rl4co/utils/lr_scheduler_helpers.py
new file mode 100644
index 00000000..d679df55
--- /dev/null
+++ b/rl4co/utils/lr_scheduler_helpers.py
@@ -0,0 +1,14 @@
+import torch
+from torch.optim import Optimizer
+
+
+def get_pytorch_lr_schedulers():
+    return torch.optim.lr_scheduler.__all__
+
+
+def create_scheduler(optimizer: Optimizer, scheduler_name: str, **scheduler_kwargs):
+    if scheduler_name in get_pytorch_lr_schedulers():
+        scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
+        return scheduler_cls(optimizer, **scheduler_kwargs)
+    else:
+        raise ValueError(f"Scheduler {scheduler_name} not found.")
diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py
new file mode 100644
index 00000000..fe2f00cf
--- /dev/null
+++ b/rl4co/utils/optim_helpers.py
@@ -0,0 +1,20 @@
+import inspect
+
+import torch
+import torch.nn as nn
+
+
+def get_pytorch_optimizers():
+    optimizers = []
+    for name, obj in inspect.getmembers(torch.optim):
+        if inspect.isclass(obj) and issubclass(obj, torch.optim.Optimizer):
+            optimizers.append(name)
+    return optimizers
+
+
+def create_optimizer(model: nn.Module, optimizer_name: str, **optimizer_kwargs):
+    if optimizer_name in get_pytorch_optimizers():
+        optimizer_cls = getattr(torch.optim, optimizer_name)
+        return optimizer_cls(model.parameters(), **optimizer_kwargs)
+    else:
+        raise ValueError(f"Optimizer {optimizer_name} not found.")
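
Example usage (a minimal sketch of how the refactored pieces fit together; `my_policy` and `my_baseline` below are placeholders for objects not included in this patch — a policy callable as `policy(td, phase, extra)` returning a dict with "reward" and "log_likelihood", and a baseline implementing `setup`/`eval`/`epoch_callback`/`wrap_dataset`):

    # Illustrative only; see the placeholder note above.
    from lightning import Trainer

    from rl4co.algo.reinforce import REINFORCE
    from rl4co.envs import ENV_REGISTRY

    env = ENV_REGISTRY["tsp"]()  # environments are looked up by name in the new registry

    model = REINFORCE(
        env=env,
        policy=my_policy,      # placeholder policy
        baseline=my_baseline,  # placeholder baseline; may be None to train without one
        batch_size=512,
        optimizer="Adam",
        optimizer_kwargs={"lr": 1e-4},
        lr_scheduler="MultiStepLR",
        lr_scheduler_kwargs={"milestones": [80, 95], "gamma": 0.1},
    )

    # reload_dataloaders_every_n_epochs=1 picks up the dataset regenerated in on_train_epoch_end
    trainer = Trainer(max_epochs=100, reload_dataloaders_every_n_epochs=1)
    trainer.fit(model)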