From d8fe35c444d4d004bb2d105c283ebc35e7dc3867 Mon Sep 17 00:00:00 2001
From: Simeon Manolov
Date: Thu, 14 Sep 2023 13:20:40 +0300
Subject: [PATCH] support for SB3 callbacks in adversarial training

---
 .../algorithms/adversarial/common.py        | 33 +++++++++++------
 src/imitation/scripts/train_adversarial.py  | 28 ++++++++++++--
 tests/algorithms/test_adversarial.py        | 37 +++++++++++++++++++
 3 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 48129fa67..1351c66ca 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -2,13 +2,15 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, Mapping, Optional, Type, overload
 
 import numpy as np
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
 from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
+from stable_baselines3.common.type_aliases import MaybeCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
 
@@ -386,6 +388,7 @@ def train_gen(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
+        callback: MaybeCallback = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.
 
@@ -398,17 +401,27 @@
                 `self.gen_train_timesteps`.
             learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
                 method.
+            callback: additional callback(s) passed to the generator's `learn` method.
         """
         if total_timesteps is None:
             total_timesteps = self.gen_train_timesteps
         if learn_kwargs is None:
             learn_kwargs = {}
 
+        callbacks = [self.gen_callback]
+
+        if isinstance(callback, list):
+            callbacks.extend(callback)
+        elif isinstance(callback, BaseCallback):
+            callbacks.append(callback)
+        elif callback is not None:
+            callbacks.append(ConvertCallback(callback))
+
         with self.logger.accumulate_means("gen"):
             self.gen_algo.learn(
                 total_timesteps=total_timesteps,
                 reset_num_timesteps=False,
-                callback=self.gen_callback,
+                callback=callbacks,
                 **learn_kwargs,
             )
         self._global_step += 1
@@ -421,12 +434,12 @@
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: MaybeCallback = None,
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
-        a call to `train_disc`, and finally a call to `callback(round)`.
+        Every "round" consists of a call to
+        `train_gen(self.gen_train_timesteps, callback)`, then a call to `train_disc`.
 
         Training ends once an additional "round" would cause the number of
         transitions sampled from the environment to exceed `total_timesteps`.
@@ -434,9 +447,7 @@ def train(
         Args:
            total_timesteps: An upper bound on the number of transitions to sample
                from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: callback(s) passed to the generator's `learn` method.
""" n_rounds = total_timesteps // self.gen_train_timesteps assert n_rounds >= 1, ( @@ -444,14 +455,12 @@ def train( f"{self.gen_train_timesteps} timesteps, have only " f"total_timesteps={total_timesteps})!" ) - for r in tqdm.tqdm(range(0, n_rounds), desc="round"): - self.train_gen(self.gen_train_timesteps) + for _r in tqdm.tqdm(range(0, n_rounds), desc="round"): + self.train_gen(self.gen_train_timesteps, callback=callback) for _ in range(self.n_disc_updates_per_round): with networks.training(self.reward_train): # switch to training mode (affects dropout, normalization) self.train_disc() - if callback: - callback(r) self.logger.dump(self._global_step) @overload diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 26c8d7bcf..8de7d22b7 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -8,6 +8,7 @@ import sacred.commands import torch as th from sacred.observers import FileStorageObserver +from stable_baselines3.common.callbacks import BaseCallback from imitation.algorithms.adversarial import airl as airl_algo from imitation.algorithms.adversarial import common @@ -22,6 +23,28 @@ logger = logging.getLogger("imitation.scripts.train_adversarial") +class CheckpointCallback(BaseCallback): + def __init__( + self, + trainer: common.AdversarialTrainer, + log_dir: pathlib.Path, + interval: int + ): + super().__init__(self) + self.trainer = trainer + self.log_dir = log_dir + self.interval = interval + self.round_num = 0 + + def _on_step(self) -> bool: + return True + + def _on_training_end(self) -> None: + self.round_num += 1 + if self.interval > 0 and self.round_num % self.interval == 0: + save(self.trainer, self.log_dir / "checkpoints" / f"{self.round_num:05d}") + + def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path): """Save discriminator and generator.""" # We implement this here and not in Trainer since we do not want to actually @@ -153,10 +176,7 @@ def train_adversarial( **algorithm_kwargs, ) - def callback(round_num: int, /) -> None: - if checkpoint_interval > 0 and round_num % checkpoint_interval == 0: - save(trainer, log_dir / "checkpoints" / f"{round_num:05d}") - + callback = CheckpointCallback(trainer, log_dir, checkpoint_interval) trainer.train(total_timesteps, callback) imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train) diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py index d3609efaa..5153e98db 100644 --- a/tests/algorithms/test_adversarial.py +++ b/tests/algorithms/test_adversarial.py @@ -10,6 +10,7 @@ import stable_baselines3 import torch as th from stable_baselines3.common import policies +from stable_baselines3.common.callbacks import BaseCallback from torch.utils import data as th_data from imitation.algorithms.adversarial import airl, common, gail @@ -464,3 +465,39 @@ def test_regression_gail_with_sac( reward_net=reward_net, ) gail_trainer.train(8) + + +def test_gen_callback(trainer: common.AdversarialTrainer): + learner = stable_baselines3.PPO("MlpPolicy", env=trainer.venv) + + def make_fn_callback(calls, key): + def cb(_a, _b): + calls[key] += 1 + return cb + + class SB3Callback(BaseCallback): + def __init__(self, calls, key): + super().__init__(self) + self.calls = calls + self.key = key + + def _on_step(self): + self.calls[self.key] += 1 + return True + + n_steps = trainer.gen_train_timesteps * 2 + calls = {"fn": 0, "sb3": 0, "list.0": 0, "list.1": 0} + + trainer.train(n_steps, 
callback=make_fn_callback(calls, "fn")) + trainer.train(n_steps, callback=SB3Callback(calls, "sb3")) + trainer.train(n_steps, callback=[ + SB3Callback(calls, "list.0"), + SB3Callback(calls, "list.1") + ]) + + # Env steps for off-plicy algos (DQN) may exceed `total_timesteps`, + # so we check if the callback was called *at least* that many times. + assert calls["fn"] >= n_steps + assert calls["sb3"] >= n_steps + assert calls["list.0"] >= n_steps + assert calls["list.1"] >= n_steps
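
Usage sketch (not part of the patch above): with this change, `AdversarialTrainer.train` forwards an SB3 callback, a list of callbacks, or a bare function to the generator's `learn()` call on every round, alongside the trainer's internal `gen_callback`. The snippet below is a minimal illustration only; `gail_trainer` (an already-constructed GAIL trainer) and `eval_venv` (a vectorized evaluation environment) are assumed names, while `EvalCallback` is the standard Stable-Baselines3 callback.

    from stable_baselines3.common.callbacks import EvalCallback

    # Sketch only: `gail_trainer` and `eval_venv` are assumed to exist already.
    # Evaluate the generator policy every 1000 environment steps of each round.
    eval_callback = EvalCallback(eval_venv, eval_freq=1_000, n_eval_episodes=5)

    # Forwarded to gen_algo.learn() together with the trainer's internal gen_callback.
    gail_trainer.train(total_timesteps=100_000, callback=eval_callback)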