From 1935c99ce6acdf5c421668bb1009e3736e02eba1 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Tue, 4 Jul 2023 20:46:08 +0200 Subject: [PATCH 01/57] Initial version of the SQIL implementation --- src/imitation/algorithms/sqil.py | 321 +++++++++++++++++++++++++++++++ src/imitation/data/rollout.py | 3 +- 2 files changed, 323 insertions(+), 1 deletion(-) create mode 100644 src/imitation/algorithms/sqil.py diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py new file mode 100644 index 000000000..6e6538cde --- /dev/null +++ b/src/imitation/algorithms/sqil.py @@ -0,0 +1,321 @@ +"""Soft Q Imitation Learning (SQIL). + +Trains a policy via DQN-style Q-learning, +replacing half the buffer with expert demonstrations and adjusting the rewards. +""" +from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +import torch.nn.functional as F +from stable_baselines3 import dqn +from stable_baselines3.common import policies, vec_env +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import ( + MaybeCallback, + ReplayBufferSamples, + Schedule, +) +from stable_baselines3.dqn.policies import DQNPolicy + +from imitation.algorithms import base as algo_base +from imitation.algorithms.base import AnyTransitions +from imitation.data import types +from imitation.data.rollout import flatten_trajectories +from imitation.data.types import Transitions +from imitation.util import logger as imit_logger +from imitation.util.util import get_first_iter_element + + +class SQIL(algo_base.DemonstrationAlgorithm): + """Soft Q Imitation Learning (SQIL). + + Trains a policy via DQN-style Q-learning, + replacing half the buffer with expert demonstrations and adjusting the rewards. + """ + + expert_buffer: ReplayBuffer + + def __init__( + self, + *, + venv: vec_env.VecEnv, + demonstrations: Transitions, + policy: Union[str, Type[DQNPolicy]], + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, + learning_rate: Union[float, Schedule] = 1e-4, + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 50000, + batch_size: int = 32, + tau: float = 1.0, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 4, + gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.05, + max_grad_norm: float = 10, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + """Builds SQIL. + + Args: + venv: The vectorized environment to train on. + demonstrations: Demonstrations to use for training. + policy: The policy model to use (SB3). + custom_logger: Where to log to; if None (default), creates a new logger. + learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0). + buffer_size: Size of the replay buffer. + learning_starts: How many steps of the model to collect transitions for + before learning starts. + batch_size: Minibatch size for each gradient update. + tau: The soft update coefficient ("Polyak update", between 0 and 1), + default 1 for hard update. + gamma: The discount factor. + train_freq: Update the model every ``train_freq`` steps. Alternatively + pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + gradient_steps: How many gradient steps to do after each + rollout (see ``train_freq``). + Set to ``-1`` means to do as many gradient steps as steps done + in the environment during the rollout. + replay_buffer_class: Replay buffer class to use + (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + replay_buffer_kwargs: Keyword arguments to pass + to the replay buffer on creation. + optimize_memory_usage: Enable a memory efficient variant of the + replay buffer at a cost of more complexity. + target_update_interval: Update the target network every + ``target_update_interval`` environment steps. + exploration_fraction: Fraction of entire training period over + which the exploration rate is reduced. + exploration_initial_eps: Initial value of random action probability. + exploration_final_eps: Final value of random action probability. + max_grad_norm: The maximum value for the gradient clipping. + tensorboard_log: The log location for tensorboard (if None, no logging). + policy_kwargs: Additional arguments to be passed to the policy on creation. + verbose: Verbosity level: 0 for no output, 1 for info messages + (such as device or wrappers used), 2 for debug messages. + seed: Seed for the pseudo random generators. + device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + _init_setup_model: Whether or not to build the network + at the creation of the instance. + + + """ + self.venv = venv + + super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) + + self.orig_train_freq = train_freq + + self.dqn = dqn.DQN( + policy=policy, + env=venv, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + train_freq=train_freq, + gradient_steps=gradient_steps, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + optimize_memory_usage=optimize_memory_usage, + target_update_interval=target_update_interval, + exploration_fraction=exploration_fraction, + exploration_initial_eps=exploration_initial_eps, + exploration_final_eps=exploration_final_eps, + max_grad_norm=max_grad_norm, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=_init_setup_model, + ) + + def set_demonstrations(self, demonstrations: AnyTransitions) -> None: + # If demonstrations is a list of trajectories, + # flatten it into a list of transitions + if isinstance(demonstrations, Iterable): + item, demonstrations = get_first_iter_element( # type: ignore[assignment] + demonstrations, # type: ignore[assignment] + ) + if isinstance(item, types.Trajectory): + demonstrations = flatten_trajectories( + demonstrations, # type: ignore[arg-type] + ) + + n_samples = len(demonstrations) # type: ignore[arg-type] + self.expert_buffer = ReplayBuffer( + n_samples, + self.venv.observation_space, + self.venv.action_space, + handle_timeout_termination=False, + ) + + for transition in demonstrations: + self.expert_buffer.add( + obs=np.array(transition["obs"]), # type: ignore[index] + next_obs=np.array(transition["next_obs"]), # type: ignore[index] + action=np.array(transition["acts"]), # type: ignore[index] + done=np.array(transition["dones"]), # type: ignore[index] + reward=np.array(1), + infos=[{}], + ) + + def train(self, *, total_timesteps: int): + self.learn_dqn(total_timesteps=total_timesteps) + + @property + def policy(self) -> policies.BasePolicy: + assert isinstance(self.dqn.policy, policies.BasePolicy) + return self.dqn.policy + + def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: + + # Needed to make mypy happy, because SB3 typing is shoddy + assert isinstance(self.dqn.policy, policies.BasePolicy) + + # Switch to train mode (this affects batch norm / dropout) + self.dqn.policy.set_training_mode(True) + # Update learning rate according to schedule + self.dqn._update_learning_rate(self.dqn.policy.optimizer) + + losses = [] + for _ in range(gradient_steps): + # Sample replay buffer + new_data = self.dqn.replay_buffer.sample( + batch_size // 2, + env=self.dqn._vec_normalize_env, + ) + new_data.rewards.zero_() # Zero out the rewards + + expert_data = self.expert_buffer.sample( + batch_size // 2, + env=self.dqn._vec_normalize_env, + ) + + # Concatenate the two batches of data + replay_data = ReplayBufferSamples( + *( + th.cat((getattr(new_data, name), getattr(expert_data, name))) + for name in new_data._fields + ), + ) + + with th.no_grad(): + # Compute the next Q-values using the target network + next_q_values = self.dqn.q_net_target(replay_data.next_observations) + # Follow greedy policy: use the one with the highest value + next_q_values, _ = next_q_values.max(dim=1) + # Avoid potential broadcast issue + next_q_values = next_q_values.reshape(-1, 1) + # 1-step TD target + target_q_values = ( + replay_data.rewards + + (1 - replay_data.dones) * self.dqn.gamma * next_q_values + ) + + # Get current Q-values estimates + current_q_values = self.dqn.q_net(replay_data.observations) + + # Retrieve the q-values for the actions from the replay buffer + current_q_values = th.gather( + current_q_values, + dim=1, + index=replay_data.actions.long(), + ) + + # Compute Huber loss (less sensitive to outliers) + loss = F.smooth_l1_loss(current_q_values, target_q_values) + losses.append(loss.item()) + + # Optimize the policy + self.dqn.policy.optimizer.zero_grad() + loss.backward() + # Clip gradient norm + # For some reason pytype doesn't see nn.utils, so adding a type ignore + th.nn.utils.clip_grad_norm_( # type: ignore[module-attr] + self.dqn.policy.parameters(), + self.dqn.max_grad_norm, + ) + self.dqn.policy.optimizer.step() + + # Increase update counter + self.dqn._n_updates += gradient_steps + + self.dqn.logger.record( + "train/n_updates", + self.dqn._n_updates, + exclude="tensorboard", + ) + self.dqn.logger.record("train/loss", np.mean(losses)) + + def learn_dqn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "run", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> None: + + total_timesteps, callback = self.dqn._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + while self.dqn.num_timesteps < total_timesteps: + rollout = self.dqn.collect_rollouts( + self.dqn.env, # type: ignore[arg-type] # This is from SB3 code + train_freq=self.dqn.train_freq, # type: ignore[arg-type] # SB3 + action_noise=self.dqn.action_noise, + callback=callback, + learning_starts=self.dqn.learning_starts, + replay_buffer=self.dqn.replay_buffer, # type: ignore[arg-type] # SB3 + log_interval=log_interval, + ) + + if rollout.continue_training is False: + break + + if ( + self.dqn.num_timesteps > 0 + and self.dqn.num_timesteps > self.dqn.learning_starts + ): + # If no `gradient_steps` is specified, + # do as many gradients steps as steps performed during the rollout + gradient_steps = ( + self.dqn.gradient_steps + if self.dqn.gradient_steps >= 0 + else rollout.episode_timesteps + ) + # Special case when the user passes `gradient_steps=0` + if gradient_steps > 0: + self.train_dqn( + batch_size=self.dqn.batch_size, + gradient_steps=gradient_steps, + ) + + callback.on_training_end() diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 9a7f7794f..add281a65 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -8,6 +8,7 @@ Callable, Dict, Hashable, + Iterable, List, Mapping, Optional, @@ -527,7 +528,7 @@ def rollout_stats( def flatten_trajectories( - trajectories: Sequence[types.Trajectory], + trajectories: Iterable[types.Trajectory], ) -> types.Transitions: """Flatten a series of trajectory dictionaries into arrays. From 2d4151ed6cfea5e758a93f65187381153f9bd3f2 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Tue, 4 Jul 2023 21:42:03 +0200 Subject: [PATCH 02/57] Pin SB3 version to 1.7.0 (#738) (#745) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 73ffd00ac..fa1d03f31 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ "autorom[accept-rom-license]~=0.6.0", ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] -STABLE_BASELINES3 = "stable-baselines3>=1.7.0" +STABLE_BASELINES3 = "stable-baselines3>=1.7.0,<2.0.0" # pinned to 0.21 until https://github.com/DLR-RM/stable-baselines3/pull/780 goes # upstream. GYM_VERSION_SPECIFIER = "==0.21.0" From 993a0d7c578ec15bda4ab9fb3328fbfa20cf189c Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Tue, 4 Jul 2023 22:42:39 +0200 Subject: [PATCH 03/57] Another redundant type warning --- src/imitation/algorithms/sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 6e6538cde..7ef85811d 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -199,7 +199,7 @@ def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: losses = [] for _ in range(gradient_steps): # Sample replay buffer - new_data = self.dqn.replay_buffer.sample( + new_data = self.dqn.replay_buffer.sample( # type: ignore[union-attr] batch_size // 2, env=self.dqn._vec_normalize_env, ) From 899a5d8d3ddfee9588260784ce8cecce82a5a13f Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Wed, 5 Jul 2023 16:51:38 +0200 Subject: [PATCH 04/57] Correctly set the expert rewards to 1 Remove redundant parameter --- src/imitation/algorithms/sqil.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 7ef85811d..ef4d44330 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -119,8 +119,6 @@ def __init__( super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) - self.orig_train_freq = train_freq - self.dqn = dqn.DQN( policy=policy, env=venv, @@ -210,6 +208,8 @@ def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: env=self.dqn._vec_normalize_env, ) + expert_data.rewards.fill_(1) # Fill the rewards with 1 + # Concatenate the two batches of data replay_data = ReplayBufferSamples( *( From 73064acd9513a01454733ae69ff5c236871e6b61 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 09:58:29 +0200 Subject: [PATCH 05/57] Update typing, add some tests --- src/imitation/algorithms/sqil.py | 3 +- tests/algorithms/test_sqil.py | 71 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 tests/algorithms/test_sqil.py diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index ef4d44330..2b1a94dfd 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -22,7 +22,6 @@ from imitation.algorithms.base import AnyTransitions from imitation.data import types from imitation.data.rollout import flatten_trajectories -from imitation.data.types import Transitions from imitation.util import logger as imit_logger from imitation.util.util import get_first_iter_element @@ -40,7 +39,7 @@ def __init__( self, *, venv: vec_env.VecEnv, - demonstrations: Transitions, + demonstrations: Optional[AnyTransitions], policy: Union[str, Type[DQNPolicy]], custom_logger: Optional[imit_logger.HierarchicalLogger] = None, learning_rate: Union[float, Schedule] = 1e-4, diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py new file mode 100644 index 000000000..dcc2f95dd --- /dev/null +++ b/tests/algorithms/test_sqil.py @@ -0,0 +1,71 @@ +import gym +import numpy as np +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.vec_env import DummyVecEnv + +from imitation.algorithms.sqil import SQIL +from imitation.data import rollout, types, wrappers + + +def test_sqil_demonstration_buffer(rng): + env = gym.make("CartPole-v1") + venv = DummyVecEnv([lambda: env]) + policy = "MlpPolicy" + + sampling_agent = SQIL( + venv=venv, + demonstrations=None, + policy=policy, + ) + + rollouts = rollout.rollout( + sampling_agent.policy, + DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]), + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + demonstrations = rollout.flatten_trajectories(rollouts) + + model = SQIL( + venv=venv, + demonstrations=demonstrations, + policy=policy, + ) + + # Check that demonstrations are stored in the replay buffer correctly + for i in range(len(demonstrations)): + obs = model.expert_buffer.observations[i] + act = model.expert_buffer.actions[i] + next_obs = model.expert_buffer.next_observations[i] + done = model.expert_buffer.dones[i] + + np.testing.assert_array_equal(obs[0], demonstrations.obs[i]) + np.testing.assert_array_equal(act[0], demonstrations.acts[i]) + np.testing.assert_array_equal(next_obs[0], demonstrations.next_obs[i]) + np.testing.assert_array_equal(done, demonstrations.dones[i]) + + +def test_sqil_cartpole_no_crash(rng): + env = gym.make("CartPole-v1") + venv = DummyVecEnv([lambda: env]) + + policy = "MlpPolicy" + sampling_agent = SQIL( + venv=venv, + demonstrations=None, + policy=policy, + ) + + rollouts = rollout.rollout( + sampling_agent.policy, + DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]), + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + demonstrations = rollout.flatten_trajectories(rollouts) + model = SQIL( + venv=venv, + demonstrations=demonstrations, + policy=policy, + ) + model.train(total_timesteps=100) From b6c9d2612f8666544f48bfb620f580c8bec4d3ea Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 11:22:00 +0200 Subject: [PATCH 06/57] Update sqil.py --- src/imitation/algorithms/sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 2b1a94dfd..21bc339a3 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -1,4 +1,4 @@ -"""Soft Q Imitation Learning (SQIL). +"""Soft Q Imitation Learning (SQIL) (https://arxiv.org/abs/1905.11108). Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards. From 42d5468f719c773e1933299588b0e88d03790d98 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 14:06:58 +0200 Subject: [PATCH 07/57] Style fixes --- src/imitation/algorithms/sqil.py | 39 ++++++++++++-------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 21bc339a3..7eb159608 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -9,21 +9,12 @@ import torch as th import torch.nn.functional as F from stable_baselines3 import dqn -from stable_baselines3.common import policies, vec_env -from stable_baselines3.common.buffers import ReplayBuffer -from stable_baselines3.common.type_aliases import ( - MaybeCallback, - ReplayBufferSamples, - Schedule, -) +from stable_baselines3.common import buffers, policies, type_aliases, vec_env from stable_baselines3.dqn.policies import DQNPolicy from imitation.algorithms import base as algo_base -from imitation.algorithms.base import AnyTransitions -from imitation.data import types -from imitation.data.rollout import flatten_trajectories -from imitation.util import logger as imit_logger -from imitation.util.util import get_first_iter_element +from imitation.data import types, rollout +from imitation.util import logger, util class SQIL(algo_base.DemonstrationAlgorithm): @@ -33,16 +24,16 @@ class SQIL(algo_base.DemonstrationAlgorithm): replacing half the buffer with expert demonstrations and adjusting the rewards. """ - expert_buffer: ReplayBuffer + expert_buffer: buffers.ReplayBuffer def __init__( self, *, venv: vec_env.VecEnv, - demonstrations: Optional[AnyTransitions], + demonstrations: Optional[algo_base.AnyTransitions], policy: Union[str, Type[DQNPolicy]], - custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - learning_rate: Union[float, Schedule] = 1e-4, + custom_logger: Optional[logger.HierarchicalLogger] = None, + learning_rate: Union[float, type_aliases.Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 50000, batch_size: int = 32, @@ -50,7 +41,7 @@ def __init__( gamma: float = 0.99, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_class: Optional[Type[buffers.ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, @@ -145,20 +136,20 @@ def __init__( _init_setup_model=_init_setup_model, ) - def set_demonstrations(self, demonstrations: AnyTransitions) -> None: + def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: # If demonstrations is a list of trajectories, # flatten it into a list of transitions if isinstance(demonstrations, Iterable): - item, demonstrations = get_first_iter_element( # type: ignore[assignment] + item, demonstrations = util.get_first_iter_element( # type: ignore[assignment] demonstrations, # type: ignore[assignment] ) if isinstance(item, types.Trajectory): - demonstrations = flatten_trajectories( + demonstrations = rollout.flatten_trajectories( demonstrations, # type: ignore[arg-type] ) n_samples = len(demonstrations) # type: ignore[arg-type] - self.expert_buffer = ReplayBuffer( + self.expert_buffer = buffers.ReplayBuffer( n_samples, self.venv.observation_space, self.venv.action_space, @@ -190,7 +181,7 @@ def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: # Switch to train mode (this affects batch norm / dropout) self.dqn.policy.set_training_mode(True) - # Update learning rate according to schedule + # Update learning rate according to type_aliases.Schedule self.dqn._update_learning_rate(self.dqn.policy.optimizer) losses = [] @@ -210,7 +201,7 @@ def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: expert_data.rewards.fill_(1) # Fill the rewards with 1 # Concatenate the two batches of data - replay_data = ReplayBufferSamples( + replay_data = type_aliases.ReplayBufferSamples( *( th.cat((getattr(new_data, name), getattr(expert_data, name))) for name in new_data._fields @@ -268,7 +259,7 @@ def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: def learn_dqn( self, total_timesteps: int, - callback: MaybeCallback = None, + callback: type_aliases.MaybeCallback = None, log_interval: int = 4, tb_log_name: str = "run", reset_num_timesteps: bool = True, From 86825d861a7f401c1781c93e3eb089f1184f23e7 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 14:25:25 +0200 Subject: [PATCH 08/57] Test updates --- tests/algorithms/test_sqil.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index dcc2f95dd..42831bb98 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,18 +1,17 @@ import gym import numpy as np -from stable_baselines3.common.buffers import ReplayBuffer -from stable_baselines3.common.vec_env import DummyVecEnv +import stable_baselines3.common.vec_env as vec_env -from imitation.algorithms.sqil import SQIL -from imitation.data import rollout, types, wrappers +from imitation.algorithms import sqil +from imitation.data import rollout, wrappers def test_sqil_demonstration_buffer(rng): env = gym.make("CartPole-v1") - venv = DummyVecEnv([lambda: env]) + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) policy = "MlpPolicy" - sampling_agent = SQIL( + sampling_agent = sqil.SQIL( venv=venv, demonstrations=None, policy=policy, @@ -20,13 +19,13 @@ def test_sqil_demonstration_buffer(rng): rollouts = rollout.rollout( sampling_agent.policy, - DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]), + venv, rollout.make_sample_until(min_timesteps=None, min_episodes=50), rng=rng, ) demonstrations = rollout.flatten_trajectories(rollouts) - model = SQIL( + model = sqil.SQIL( venv=venv, demonstrations=demonstrations, policy=policy, @@ -47,10 +46,10 @@ def test_sqil_demonstration_buffer(rng): def test_sqil_cartpole_no_crash(rng): env = gym.make("CartPole-v1") - venv = DummyVecEnv([lambda: env]) + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) policy = "MlpPolicy" - sampling_agent = SQIL( + sampling_agent = sqil.SQIL( venv=venv, demonstrations=None, policy=policy, @@ -58,12 +57,12 @@ def test_sqil_cartpole_no_crash(rng): rollouts = rollout.rollout( sampling_agent.policy, - DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]), + venv, rollout.make_sample_until(min_timesteps=None, min_episodes=50), rng=rng, ) demonstrations = rollout.flatten_trajectories(rollouts) - model = SQIL( + model = sqil.SQIL( venv=venv, demonstrations=demonstrations, policy=policy, From 95a26619a33dbc6b68a54eef596269c49b001a14 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 14:39:07 +0200 Subject: [PATCH 09/57] Add a test to check the buffer --- tests/algorithms/test_sqil.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 42831bb98..ac29050f2 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,6 +1,7 @@ import gym import numpy as np import stable_baselines3.common.vec_env as vec_env +import stable_baselines3.common.buffers as buffers from imitation.algorithms import sqil from imitation.data import rollout, wrappers @@ -43,6 +44,31 @@ def test_sqil_demonstration_buffer(rng): np.testing.assert_array_equal(next_obs[0], demonstrations.next_obs[i]) np.testing.assert_array_equal(done, demonstrations.dones[i]) +def test_sqil_demonstration_without_flatten(rng): + env = gym.make("CartPole-v1") + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) + policy = "MlpPolicy" + + sampling_agent = sqil.SQIL( + venv=venv, + demonstrations=None, + policy=policy, + ) + + rollouts = rollout.rollout( + sampling_agent.policy, + venv, + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + + model = sqil.SQIL( + venv=venv, + demonstrations=rollouts, + policy=policy, + ) + + assert isinstance(model.expert_buffer, buffers.ReplayBuffer) def test_sqil_cartpole_no_crash(rng): env = gym.make("CartPole-v1") From 67662b47b183d60d4cddf3abfb396c21e6c6fcd8 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 14:43:57 +0200 Subject: [PATCH 10/57] Formatting, docstring --- src/imitation/algorithms/sqil.py | 7 +++++-- tests/algorithms/test_sqil.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 7eb159608..456682cbd 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -13,7 +13,7 @@ from stable_baselines3.dqn.policies import DQNPolicy from imitation.algorithms import base as algo_base -from imitation.data import types, rollout +from imitation.data import rollout, types from imitation.util import logger, util @@ -140,7 +140,10 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: # If demonstrations is a list of trajectories, # flatten it into a list of transitions if isinstance(demonstrations, Iterable): - item, demonstrations = util.get_first_iter_element( # type: ignore[assignment] + ( + item, + demonstrations, + ) = util.get_first_iter_element( # type: ignore[assignment] demonstrations, # type: ignore[assignment] ) if isinstance(item, types.Trajectory): diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index ac29050f2..bd8033523 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,7 +1,9 @@ +"""Tests `imitation.algorithms.sqil`.""" + import gym import numpy as np -import stable_baselines3.common.vec_env as vec_env import stable_baselines3.common.buffers as buffers +import stable_baselines3.common.vec_env as vec_env from imitation.algorithms import sqil from imitation.data import rollout, wrappers @@ -44,6 +46,7 @@ def test_sqil_demonstration_buffer(rng): np.testing.assert_array_equal(next_obs[0], demonstrations.next_obs[i]) np.testing.assert_array_equal(done, demonstrations.dones[i]) + def test_sqil_demonstration_without_flatten(rng): env = gym.make("CartPole-v1") venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) @@ -70,6 +73,7 @@ def test_sqil_demonstration_without_flatten(rng): assert isinstance(model.expert_buffer, buffers.ReplayBuffer) + def test_sqil_cartpole_no_crash(rng): env = gym.make("CartPole-v1") venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) From 68f693b4f48bb6582b77afdf0bdf53639199b2b8 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 14:48:12 +0200 Subject: [PATCH 11/57] Improve test coverage --- tests/algorithms/test_sqil.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index bd8033523..457a3b73e 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -96,5 +96,6 @@ def test_sqil_cartpole_no_crash(rng): venv=venv, demonstrations=demonstrations, policy=policy, + learning_starts=1000, ) - model.train(total_timesteps=100) + model.train(total_timesteps=10_000) From c4b0521a047b49f3d18185e8e8ea297b0fa2ae95 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 16:58:41 +0200 Subject: [PATCH 12/57] Update branch to master (#749) * Pin SB3 version to 1.7.0 (#738) * Update conftest.py (#742) * Custom environment tutorial (#746) * Custom environment tutorial draft * Update the docs website * Clean notebook * Text clarification and new environment * Decrease training duration to hopefully make CI happy * Clarify that BC itself does not learn rewards --------- Co-authored-by: Ariel Kwiatkowski * Tutorial on comparing algorithm performance (#747) * Add a new tutorial * Update index.rst * Improvements to the tutorial * Some more caution words * Fix typos --------- Co-authored-by: Ariel Kwiatkowski --------- Co-authored-by: Adam Gleave --- docs/index.rst | 2 + docs/tutorials/8_train_custom_env.ipynb | 366 +++++++++++++++++ docs/tutorials/9_compare_baselines.ipynb | 481 +++++++++++++++++++++++ tests/conftest.py | 2 +- 4 files changed, 850 insertions(+), 1 deletion(-) create mode 100644 docs/tutorials/8_train_custom_env.ipynb create mode 100644 docs/tutorials/9_compare_baselines.ipynb diff --git a/docs/index.rst b/docs/index.rst index f3ec53b03..0c516c58b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,6 +74,8 @@ If you use ``imitation`` in your research project, please cite our paper to help tutorials/5a_train_preference_comparisons_with_cnn tutorials/6_train_mce tutorials/7_train_density + tutorials/8_train_custom_env + tutorials/9_compare_baselines tutorials/trajectories .. toctree:: diff --git a/docs/tutorials/8_train_custom_env.ipynb b/docs/tutorials/8_train_custom_env.ipynb new file mode 100644 index 000000000..6c9e28726 --- /dev/null +++ b/docs/tutorials/8_train_custom_env.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_custom_env.ipynb)\n", + "# Train Behavior Cloning in a Custom Environment\n", + "\n", + "You can use `imitation` to train a policy (and, for many imitation learning algorithm, learn rewards) in a custom environment.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Define the environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use a simple ObservationMatching environment as an example. The premise is simple -- the agent receives a vector of observations, and must output a vector of actions that matches the observations as closely as possible.\n", + "\n", + "If you have your own environment that you'd like to use, you can replace the code below with your own environment. Make sure it complies with the standard Gym API, and that the observation and action spaces are specified correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import gym\n", + "\n", + "from gym.spaces import Box\n", + "from gym.utils import seeding\n", + "\n", + "\n", + "class ObservationMatchingEnv(gym.Env):\n", + " def __init__(self, num_options: int = 2):\n", + " self.num_options = num_options\n", + " self.observation_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n", + " self.action_space = Box(0, 1, shape=(num_options,), dtype=np.float32)\n", + " self.seed()\n", + "\n", + " def seed(self, seed=None):\n", + " self.np_random, seed = seeding.np_random(seed)\n", + " return [seed]\n", + "\n", + " def reset(self):\n", + " self.state = self.np_random.uniform(size=self.num_options)\n", + " return self.state\n", + "\n", + " def step(self, action):\n", + " reward = -np.abs(self.state - action).mean()\n", + " self.state = self.np_random.uniform(size=self.num_options)\n", + " return self.state, reward, False, {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Step 2: create the environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From here, we have two options:\n", + "- Add the environment to the gym registry, and use it with existing utilities (e.g. `make`)\n", + "- Use the environment directly\n", + "\n", + "You only need to execute the cells in step 2a, or step 2b to proceed.\n", + "\n", + "At the end of these steps, we want to have:\n", + "- `env`: a single environment that we can use for training an expert with SB3\n", + "- `venv`: a vectorized environment where each individual environment is wrapped in `RolloutInfoWrapper`, that we can use for collecting rollouts with `imitation`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2a (recommended): add the environment to the gym registry" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The standard approach is adding the environment to the gym registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gym.register(\n", + " id=\"custom/ObservationMatching-v0\",\n", + " entry_point=ObservationMatchingEnv, # This can also be the path to the class, e.g. `observation_matching:ObservationMatchingEnv`\n", + " max_episode_steps=500,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After registering, you can create an environment is `gym.make(env_id)` which automatically handles the `TimeLimit` wrapper.\n", + "\n", + "To create a vectorized env, you can use the `make_vec_env` helper function (Option A), or create it directly (Options B1 and B2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gym.wrappers import TimeLimit\n", + "from imitation.data import rollout\n", + "from imitation.data.wrappers import RolloutInfoWrapper\n", + "from imitation.util.util import make_vec_env\n", + "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", + "\n", + "# Create a single environment for training an expert with SB3\n", + "env = gym.make(\"custom/ObservationMatching-v0\")\n", + "\n", + "\n", + "# Create a vectorized environment for training with `imitation`\n", + "\n", + "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`\n", + "venv = make_vec_env(\n", + " \"custom/ObservationMatching-v0\",\n", + " rng=np.random.default_rng(),\n", + " n_envs=4,\n", + " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + ")\n", + "\n", + "\n", + "# Option B1: use a custom env creator, and create VecEnv directly\n", + "# def _make_env():\n", + "# \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n", + "# _env = gym.make(\"custom/ObservationMatching-v0\")\n", + "# _env = RolloutInfoWrapper(_env)\n", + "# return _env\n", + "#\n", + "# venv = DummyVecEnv([_make_env for _ in range(4)])\n", + "#\n", + "# # Option B2: we can also use a parallel VecEnv implementation\n", + "# venv = SubprocVecEnv([_make_env for _ in range(4)])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Step 2b: directly use the environment\n", + "\n", + "Alternatively, we can directly initialize the environment by instantiating the class we created earlier, and handle all the additional logic ourselves." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gym.wrappers import TimeLimit\n", + "from imitation.data import rollout\n", + "from imitation.data.wrappers import RolloutInfoWrapper\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", + "\n", + "# Create a single environment for training with SB3\n", + "env = ObservationMatchingEnv()\n", + "env = TimeLimit(env, max_episode_steps=500)\n", + "\n", + "# Create a vectorized environment for training with `imitation`\n", + "\n", + "\n", + "# Option A: use a helper function to create multiple environments\n", + "def _make_env():\n", + " \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n", + " _env = ObservationMatchingEnv()\n", + " _env = TimeLimit(_env, max_episode_steps=500)\n", + " _env = RolloutInfoWrapper(_env)\n", + " return _env\n", + "\n", + "\n", + "venv = DummyVecEnv([_make_env for _ in range(4)])\n", + "\n", + "\n", + "# Option B: use a single environment\n", + "# env = FixedHorizonCartPoleEnv()\n", + "# venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)]) # Wrap a single environment -- only useful for simple testing like this\n", + "\n", + "# Option C: use multiple environments\n", + "# venv = DummyVecEnv([lambda: RolloutInfoWrapper(ObservationMatchingEnv()) for _ in range(4)]) # Wrap multiple environments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now we're just about done! Whether you used step 2a or 2b, your environment should now be ready to use with SB3 and `imitation`.\n", + "\n", + "For the sake of completeness, we'll train a BC model, the same way as in the first tutorial, but with our custom environment.\n", + "\n", + "Keep in mind that while we're using BC in this tutorial, you can just as easily use any of the other algorithms with the environment prepared in this way." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3 import PPO\n", + "from stable_baselines3.ppo import MlpPolicy\n", + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "from gym.wrappers import TimeLimit\n", + "\n", + "expert = PPO(\n", + " policy=MlpPolicy,\n", + " env=env,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "\n", + "reward, _ = evaluate_policy(expert, env, 10)\n", + "print(f\"Reward before training: {reward}\")\n", + "\n", + "\n", + "# Note: if you followed step 2a, i.e. registered the environment, you can use the environment name directly\n", + "\n", + "# expert = PPO(\n", + "# policy=MlpPolicy,\n", + "# env=\"custom/ObservationMatching-v0\",\n", + "# seed=0,\n", + "# batch_size=64,\n", + "# ent_coef=0.0,\n", + "# learning_rate=0.0003,\n", + "# n_epochs=10,\n", + "# n_steps=64,\n", + "# )\n", + "expert.learn(10_000) # Note: set to 100000 to train a proficient expert\n", + "\n", + "reward, _ = evaluate_policy(expert, env, 10)\n", + "print(f\"Expert reward: {reward}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng()\n", + "rollouts = rollout.rollout(\n", + " expert,\n", + " venv,\n", + " rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n", + " rng=rng,\n", + ")\n", + "transitions = rollout.flatten_trajectories(rollouts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.algorithms import bc\n", + "\n", + "bc_trainer = bc.BC(\n", + " observation_space=env.observation_space,\n", + " action_space=env.action_space,\n", + " demonstrations=transitions,\n", + " rng=rng,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As before, the untrained policy only gets poor rewards:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n", + "print(f\"Reward before training: {reward_before_training}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training, we can get much closer to the expert's performance:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bc_trainer.train(n_epochs=1)\n", + "reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n", + "print(f\"Reward after training: {reward_after_training}\")" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/9_compare_baselines.ipynb b/docs/tutorials/9_compare_baselines.ipynb new file mode 100644 index 000000000..c9bc0481a --- /dev/null +++ b/docs/tutorials/9_compare_baselines.ipynb @@ -0,0 +1,481 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/9_compare_baselines.ipynb)\n", + "# Reliably compare algorithm performance\n", + "\n", + "Did we actually match the expert performance or was it just luck? Did this hyperparameter change actually improve the performance of our algorithm? These are questions that we need to answer when we want to compare the performance of different algorithms or hyperparameters.\n", + "\n", + "`imitation` provides some tools to help you answer these questions. For demonstration purposes, we will use Behavior Cloning on the CartPole-v1 environment. We will compare different variants of the trained algorithm, and also compare it with a more sophisticated algorithm, DAgger." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As in the first tutorial, we will start by training an expert." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.ppo import MlpPolicy\n", + "\n", + "env = gym.make(\"CartPole-v1\")\n", + "expert = PPO(\n", + " policy=MlpPolicy,\n", + " env=env,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "expert.learn(10_000) # set to 100_000 for better performance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For comparison, let's also train a not-quite-expert." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "not_expert = PPO(\n", + " policy=MlpPolicy,\n", + " env=env,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "\n", + "not_expert.learn(1_000) # set to 10_000 for slightly better performance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So are they any good? Let's quickly get a point estimate of their performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", + "env.seed(0)\n", + "\n", + "expert_reward, _ = evaluate_policy(expert, env, 1)\n", + "not_expert_reward, _ = evaluate_policy(not_expert, env, 1)\n", + "\n", + "print(f\"Expert reward: {expert_reward:.2f}\")\n", + "print(f\"Not expert reward: {not_expert_reward:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But wait! We only ran the evaluation once. What if we got lucky? Let's run the evaluation a few more times and see what happens." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "expert_reward, _ = evaluate_policy(expert, env, 10)\n", + "not_expert_reward, _ = evaluate_policy(not_expert, env, 10)\n", + "\n", + "print(f\"Expert reward: {expert_reward:.2f}\")\n", + "print(f\"Not expert reward: {not_expert_reward:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Seems a bit more robust now, but how certain are we? Fortunately, `imitation` provides us with tools to answer this.\n", + "\n", + "We will perform a permutation test using the `is_significant_reward_improvement` function. We want to be very certain -- let's set the bar high and require a p-value of 0.001." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.testing.reward_improvement import is_significant_reward_improvement\n", + "\n", + "expert_rewards, _ = evaluate_policy(expert, env, 10, return_episode_rewards=True)\n", + "not_expert_rewards, _ = evaluate_policy(\n", + " not_expert, env, 10, return_episode_rewards=True\n", + ")\n", + "\n", + "significant = is_significant_reward_improvement(\n", + " not_expert_rewards, expert_rewards, 0.001\n", + ")\n", + "\n", + "print(\n", + " f\"The expert is {'NOT ' if not significant else ''}significantly better than the not-expert.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Huh, turns out we set the bar too high. We could lower our standards, but that's for cowards.\n", + "Instead, we can collect more data and try again." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.testing.reward_improvement import is_significant_reward_improvement\n", + "\n", + "expert_rewards, _ = evaluate_policy(expert, env, 100, return_episode_rewards=True)\n", + "not_expert_rewards, _ = evaluate_policy(\n", + " not_expert, env, 100, return_episode_rewards=True\n", + ")\n", + "\n", + "significant = is_significant_reward_improvement(\n", + " not_expert_rewards, expert_rewards, 0.001\n", + ")\n", + "\n", + "print(\n", + " f\"The expert is {'NOT ' if not significant else ''}significantly better than the not-expert.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we go! We can now be 99.9% confident that the expert is better than the not-expert -- in this specific case, with these specific trained models. It might still be an extraordinary stroke of luck, or a conspiracy to make us choose the wrong algorithm, but outside of that, we can be pretty sure our data's correct." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use the same principle to with imitation learning algorithms. Let's train a behavior cloning algorithm and see how it compares to the expert. This time, we can lower the bar to the standard \"scientific\" threshold of 0.05." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Like in the first tutorial, we will start by collecting some expert data. But to spice it up, let's also get some data from the not-quite-expert." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.data import rollout\n", + "from imitation.data.wrappers import RolloutInfoWrapper\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", + "\n", + "rng = np.random.default_rng()\n", + "expert_rollouts = rollout.rollout(\n", + " expert,\n", + " DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n", + " rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n", + " rng=rng,\n", + ")\n", + "expert_transitions = rollout.flatten_trajectories(expert_rollouts)\n", + "\n", + "\n", + "not_expert_rollouts = rollout.rollout(\n", + " not_expert,\n", + " DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n", + " rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n", + " rng=rng,\n", + ")\n", + "not_expert_transitions = rollout.flatten_trajectories(not_expert_rollouts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try cloning an expert and a non-expert, and see how they compare." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.algorithms import bc\n", + "\n", + "expert_bc_trainer = bc.BC(\n", + " observation_space=env.observation_space,\n", + " action_space=env.action_space,\n", + " demonstrations=expert_transitions,\n", + " rng=rng,\n", + ")\n", + "\n", + "not_expert_bc_trainer = bc.BC(\n", + " observation_space=env.observation_space,\n", + " action_space=env.action_space,\n", + " demonstrations=not_expert_transitions,\n", + " rng=rng,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "expert_bc_trainer.train(n_epochs=2)\n", + "not_expert_bc_trainer.train(n_epochs=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bc_expert_rewards, _ = evaluate_policy(\n", + " expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n", + ")\n", + "bc_not_expert_rewards, _ = evaluate_policy(\n", + " not_expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n", + ")\n", + "significant = is_significant_reward_improvement(\n", + " bc_not_expert_rewards, bc_expert_rewards, 0.05\n", + ")\n", + "print(f\"Cloned expert rewards: {bc_expert_rewards}\")\n", + "print(f\"Cloned not-expert rewards: {bc_not_expert_rewards}\")\n", + "\n", + "print(\n", + " f\"Cloned expert is {'NOT ' if not significant else ''}significantly better than the cloned not-expert.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How about comparing the expert clone to the expert itself?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bc_clone_rewards, _ = evaluate_policy(\n", + " expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n", + ")\n", + "\n", + "expert_rewards, _ = evaluate_policy(expert, env, 10, return_episode_rewards=True)\n", + "\n", + "significant = is_significant_reward_improvement(bc_clone_rewards, expert_rewards, 0.05)\n", + "\n", + "print(f\"Cloned expert rewards: {bc_clone_rewards}\")\n", + "print(f\"Expert rewards: {expert_rewards}\")\n", + "\n", + "print(\n", + " f\"Expert is {'NOT ' if not significant else ''}significantly better than the cloned expert.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Turns out the expert is significantly better than the clone -- again, in this case. Note, however, that this is not proof that the clone is as good as the expert -- there's a subtle difference between the two claims in the context of hypothesis testing.\n", + "\n", + "Note: if you changed the duration of the training at the beginning of this tutorial, you might get different results. While this might break the narrative in this tutorial, it's a good learning opportunity.\n", + "\n", + "When comparing the performance of two agents, algorithms, hyperparameter sets, always remember the scope of what you're testing. In this tutorial, we have one instance of an expert -- but RL training is famously unstable, so another training run with another random seed would likely produce a slightly different result. So ideally, we would like to repeat this procedure several times, training the same agent with different random seeds, and then compare the average performance of the two agents.\n", + "\n", + "Even then, this is just on one environment, with one algorithm. So be wary of generalizing your results too much." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use the same method to compare different algorithms. While CartPole is pretty easy, we can make it more difficult by decreasing the number of episodes in our dataset, and generating them with a suboptimal policy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollouts = rollout.rollout(\n", + " expert,\n", + " DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n", + " rollout.make_sample_until(min_timesteps=None, min_episodes=1),\n", + " rng=rng,\n", + ")\n", + "transitions = rollout.flatten_trajectories(rollouts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try training a behavior cloning algorithm on this dataset.\n", + "\n", + "Note that for DAgger, we have to cheat a little bit -- it's allowed to use the expert policy to generate additional data.\n", + "For the purposes of this tutorial, we'll stick with this to avoid spending hours training an expert for a more complex environment.\n", + "\n", + "So while this little experiment isn't definitive proof that DAgger is better than BC, you can use the same method to compare any two algorithms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.algorithms.dagger import SimpleDAggerTrainer\n", + "import tempfile\n", + "\n", + "bc_trainer = bc.BC(\n", + " observation_space=env.observation_space,\n", + " action_space=env.action_space,\n", + " demonstrations=transitions,\n", + " rng=rng,\n", + ")\n", + "\n", + "bc_trainer.train(n_epochs=1)\n", + "\n", + "\n", + "with tempfile.TemporaryDirectory(prefix=\"dagger_example_\") as tmpdir:\n", + " print(tmpdir)\n", + " dagger_bc_trainer = bc.BC(\n", + " observation_space=env.observation_space,\n", + " action_space=env.action_space,\n", + " rng=np.random.default_rng(),\n", + " )\n", + " dagger_trainer = SimpleDAggerTrainer(\n", + " venv=DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n", + " scratch_dir=tmpdir,\n", + " expert_policy=expert,\n", + " bc_trainer=dagger_bc_trainer,\n", + " rng=np.random.default_rng(),\n", + " )\n", + "\n", + " dagger_trainer.train(5000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training both BC and DAgger, let's compare their performances again! We expect DAgger to be better -- after all, it's a more advanced algorithm. But is it significantly better?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bc_rewards, _ = evaluate_policy(bc_trainer.policy, env, 10, return_episode_rewards=True)\n", + "dagger_rewards, _ = evaluate_policy(\n", + " dagger_trainer.policy, env, 10, return_episode_rewards=True\n", + ")\n", + "\n", + "significant = is_significant_reward_improvement(bc_rewards, dagger_rewards, 0.05)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"BC rewards: {bc_rewards}\")\n", + "print(f\"DAgger rewards: {dagger_rewards}\")\n", + "\n", + "print(\n", + " f\"Our DAgger agent is {'NOT ' if not significant else ''}significantly better than BC.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you increased the number of training iterations for the expert (in the first cell of the tutorial), you should see that DAgger indeed performs better than BC. If you didn't, you likely see the opposite result. Yet another reason to be careful when interpreting results!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's take a moment, to remember the limitations of this experiment. We're comparing two algorithms on one environment, with one dataset. We're also using a suboptimal expert policy, which might not be the best choice for BC. If you want to convince yourself that DAgger is better than BC, you should pick out a more complex environment, you should run this experiment several times, with different random seeds and perform some hyperparameter optimization to make sure we're not just using unlucky hyperparameters. At the end, we would also need to run the same hypothesis test across average returns of several independent runs.\n", + "\n", + "But now you have all the pieces of the puzzle to do that!" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f" + }, + "kernelspec": { + "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/conftest.py b/tests/conftest.py index 10630fe45..6f278499e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,4 +45,4 @@ def custom_logger(tmpdir: str) -> logger.HierarchicalLogger: @pytest.fixture() def rng() -> np.random.Generator: - return np.random.default_rng() + return np.random.default_rng(seed=0) From 1b5338bb97d626201db36e7566b45df4055f3f5a Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 17:23:25 +0200 Subject: [PATCH 13/57] Some documentation updates (not complete) --- README.md | 1 + docs/algorithms/sqil.rst | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 docs/algorithms/sqil.rst diff --git a/README.md b/README.md index 6915814b1..679f37c2f 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Currently, we have implementations of the algorithms below. 'Discrete' and 'Cont | [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) | [`algoritms.airl`](https://imitation.readthedocs.io/en/latest/algorithms/airl.html) | ✅ | ✅ | | [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) | [`algorithms.gail`](https://imitation.readthedocs.io/en/latest/algorithms/gail.html) | ✅ | ✅ | | [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) | [`algorithms.preference_comparisons`](https://imitation.readthedocs.io/en/latest/algorithms/preference_comparisons.html) | ✅ | ✅ | +| [Soft Q Imitation Learning](https://arxiv.org/abs/1905.11108) | [`algorithms.sqil`](https://imitation.readthedocs.io/en/latest/algorithms/sqil.html) | ✅ | ❌ | You can find [the documentation here](https://imitation.readthedocs.io/en/latest/). diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst new file mode 100644 index 000000000..b86e340a6 --- /dev/null +++ b/docs/algorithms/sqil.rst @@ -0,0 +1,59 @@ +.. _soft q imitation learning docs: + +======================= +Soft Q Imitation Learning (SQIL) +======================= + + + +Example +======= + +Detailed example notebook: :doc:`../tutorials/10_train_sqil` + +.. testcode:: + :skipif: skip_doctests + + import numpy as np + import gym + from stable_baselines3 import PPO + from stable_baselines3.common.evaluation import evaluate_policy + from stable_baselines3.common.vec_env import DummyVecEnv + from stable_baselines3.ppo import MlpPolicy + + from imitation.algorithms import sqil + from imitation.data import rollout + from imitation.data.wrappers import RolloutInfoWrapper + + rng = np.random.default_rng(0) + env = gym.make("CartPole-v1") + expert = PPO(policy=MlpPolicy, env=env) + expert.learn(1000) + + rollouts = rollout.rollout( + expert, + DummyVecEnv([lambda: RolloutInfoWrapper(env)]), + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + transitions = rollout.flatten_trajectories(rollouts) + + sqil_trainer = sqil.SQIL( + venv=DummyVecEnv([lambda: env]), + demonstrations=transitions, + policy="MlpPolicy", + ) + sqil_trainer.train(n_epochs=1) + reward, _ = evaluate_policy(sqil_trainer.policy, env, 10) + print("Reward:", reward) + +.. testoutput:: + :hide: + + ... + +API +=== +.. autoclass:: imitation.algorithms.sqil.SQIL + :members: + :noindex: From 3c78336efc08588b6d6531dbb95b7fe55e8fb63c Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 17:24:15 +0200 Subject: [PATCH 14/57] Add a SQIL tutorial --- docs/tutorials/10_train_sqil.ipynb | 207 +++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 docs/tutorials/10_train_sqil.ipynb diff --git a/docs/tutorials/10_train_sqil.ipynb b/docs/tutorials/10_train_sqil.ipynb new file mode 100644 index 000000000..af39a491b --- /dev/null +++ b/docs/tutorials/10_train_sqil.ipynb @@ -0,0 +1,207 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/10_train_sqil.ipynb)\n", + "# Train an Agent using Soft Q Imitation Learning\n", + "\n", + "Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n", + "It's fundamentally a modification of the DQN algorithm. At each training step, whenever we sample a batch of data from the replay buffer,\n", + "we also sample a batch of expert data. Expert demonstrations are assigned a reward of 1, while the agent's own transitions are assigned a reward of 0.\n", + "This approach encourages the agent to imitate the expert's behavior, but also to avoid unfamiliar states.\n", + "\n", + "In this tutorial we will use the `imitation` library to train an agent using SQIL." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we need an expert in CartPole-v1 so that we can sample expert trajectories.\n", + "Let's train one using stable-baselines3.\n", + "\n", + "Note that you can use other environments, but the action space must be discrete for this algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.ppo import MlpPolicy\n", + "\n", + "env = gym.make(\"CartPole-v1\")\n", + "expert = PPO(\n", + " policy=MlpPolicy,\n", + " env=env,\n", + " seed=0,\n", + " batch_size=64,\n", + " ent_coef=0.0,\n", + " learning_rate=0.0003,\n", + " n_epochs=10,\n", + " n_steps=64,\n", + ")\n", + "expert.learn(100_000) # Note: set to 100000 to train a proficient expert" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's quickly check if the expert is any good.\n", + "We usually should be able to reach a reward of 500, which is the maximum achievable value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", + "reward, _ = evaluate_policy(expert, env, 10)\n", + "print(reward)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use the expert to sample some trajectories.\n", + "We flatten them right away since we only need individual transitions.\n", + "`imitation` comes with a number of helper functions that makes collecting those transitions really easy. First we collect 50 episode rollouts, then we flatten them to just the transitions that we need for training.\n", + "Note that the rollout function requires a vectorized environment and needs the `RolloutInfoWrapper` around each of the environments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.data import rollout\n", + "from imitation.data.wrappers import RolloutInfoWrapper\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", + "\n", + "venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)])\n", + "rng = np.random.default_rng()\n", + "rollouts = rollout.rollout(\n", + " expert,\n", + " venv,\n", + " rollout.make_sample_until(min_timesteps=None, min_episodes=100),\n", + " rng=rng,\n", + ")\n", + "transitions = rollout.flatten_trajectories(rollouts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's have a quick look at what we just generated using those library functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f\"\"\"The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.\n", + "After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.\n", + "The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}.\"\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After we collected our transitions, it's time to set up our behavior cloning algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from imitation.algorithms import sqil\n", + "\n", + "sqil_trainer = sqil.SQIL(\n", + " venv=venv,\n", + " demonstrations=transitions,\n", + " policy=\"MlpPolicy\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see the untrained policy only gets poor rewards:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", + "print(f\"Reward before training: {reward_before_training}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training, we can match the rewards of the expert (500):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sqil_trainer.train(total_timesteps=1_000_000) # Note: set to 1_000_000 to obtain good results\n", + "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", + "print(f\"Reward after training: {reward_after_training}\")" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f" + }, + "kernelspec": { + "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From c303af14995c49127f0783902319259fec2113de Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 17:29:05 +0200 Subject: [PATCH 15/57] Reduce tutorial runtime --- docs/tutorials/10_train_sqil.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/10_train_sqil.ipynb b/docs/tutorials/10_train_sqil.ipynb index af39a491b..ff980e373 100644 --- a/docs/tutorials/10_train_sqil.ipynb +++ b/docs/tutorials/10_train_sqil.ipynb @@ -46,7 +46,7 @@ " n_epochs=10,\n", " n_steps=64,\n", ")\n", - "expert.learn(100_000) # Note: set to 100000 to train a proficient expert" + "expert.learn(1_000) # Note: set to 100_000 to train a proficient expert" ] }, { @@ -174,7 +174,7 @@ "metadata": {}, "outputs": [], "source": [ - "sqil_trainer.train(total_timesteps=1_000_000) # Note: set to 1_000_000 to obtain good results\n", + "sqil_trainer.train(total_timesteps=1_000) # Note: set to 1_000_000 to obtain good results\n", "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", "print(f\"Reward after training: {reward_after_training}\")" ] From bf81940b49d8b63e63396815f08762aec1d0e34f Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 18:48:45 +0200 Subject: [PATCH 16/57] Add SQIL description in docs, try to add it to the right places --- docs/algorithms/sqil.rst | 12 +++++++++--- docs/index.rst | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index b86e340a6..d680ba3af 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -1,10 +1,16 @@ .. _soft q imitation learning docs: -======================= +================================ Soft Q Imitation Learning (SQIL) -======================= +================================ - +Soft Q Imitation learning learns to imitate a policy from demonstrations by +using the DQN algorithm with modified rewards. During each policy update, half +of the batch is sampled from the demonstrations and half is sampled from the +environment. Expert demonstrations are assigned a reward of 1, and the +environment is assigned a reward of 0. This encourages the policy to imitate +the demonstrations, and to simultaneously avoid states not seen in the +demonstrations. Example ======= diff --git a/docs/index.rst b/docs/index.rst index 0c516c58b..204836d61 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,6 +60,7 @@ If you use ``imitation`` in your research project, please cite our paper to help algorithms/density algorithms/mce_irl algorithms/preference_comparisons + algorithms/sqil .. toctree:: :maxdepth: 2 @@ -76,6 +77,7 @@ If you use ``imitation`` in your research project, please cite our paper to help tutorials/7_train_density tutorials/8_train_custom_env tutorials/9_compare_baselines + tutorials/10_train_sqil tutorials/trajectories .. toctree:: From 5da56f3b541823cefc2469ff19ed1e99635dc404 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 18:55:04 +0200 Subject: [PATCH 17/57] Fix docs --- docs/algorithms/sqil.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index d680ba3af..ecd9bea6a 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -49,7 +49,7 @@ Detailed example notebook: :doc:`../tutorials/10_train_sqil` demonstrations=transitions, policy="MlpPolicy", ) - sqil_trainer.train(n_epochs=1) + sqil_trainer.train(total_timesteps=1000) reward, _ = evaluate_policy(sqil_trainer.policy, env, 10) print("Reward:", reward) From d8f3c30f1f7825c9ce6625b98f37284f0c2fbb98 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Thu, 6 Jul 2023 19:20:15 +0200 Subject: [PATCH 18/57] Blacken a tutorial --- docs/tutorials/10_train_sqil.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/tutorials/10_train_sqil.ipynb b/docs/tutorials/10_train_sqil.ipynb index ff980e373..ddbbeb98a 100644 --- a/docs/tutorials/10_train_sqil.ipynb +++ b/docs/tutorials/10_train_sqil.ipynb @@ -174,7 +174,9 @@ "metadata": {}, "outputs": [], "source": [ - "sqil_trainer.train(total_timesteps=1_000) # Note: set to 1_000_000 to obtain good results\n", + "sqil_trainer.train(\n", + " total_timesteps=1_000\n", + ") # Note: set to 1_000_000 to obtain good results\n", "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", "print(f\"Reward after training: {reward_after_training}\")" ] From ae43a75bc574c77466ce25778afeae856420b2ca Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 12:51:37 +0200 Subject: [PATCH 19/57] Reorder things in docs --- docs/algorithms/sqil.rst | 2 +- docs/index.rst | 4 ++-- .../{8_train_custom_env.ipynb => 10_train_custom_env.ipynb} | 2 +- docs/tutorials/{10_train_sqil.ipynb => 8_train_sqil.ipynb} | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename docs/tutorials/{8_train_custom_env.ipynb => 10_train_custom_env.ipynb} (99%) rename docs/tutorials/{10_train_sqil.ipynb => 8_train_sqil.ipynb} (98%) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index ecd9bea6a..1c794d515 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -15,7 +15,7 @@ demonstrations. Example ======= -Detailed example notebook: :doc:`../tutorials/10_train_sqil` +Detailed example notebook: :doc:`../tutorials/8_train_sqil` .. testcode:: :skipif: skip_doctests diff --git a/docs/index.rst b/docs/index.rst index 204836d61..f3d8f4033 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -75,9 +75,9 @@ If you use ``imitation`` in your research project, please cite our paper to help tutorials/5a_train_preference_comparisons_with_cnn tutorials/6_train_mce tutorials/7_train_density - tutorials/8_train_custom_env + tutorials/8_train_sqil tutorials/9_compare_baselines - tutorials/10_train_sqil + tutorials/10_train_custom_env tutorials/trajectories .. toctree:: diff --git a/docs/tutorials/8_train_custom_env.ipynb b/docs/tutorials/10_train_custom_env.ipynb similarity index 99% rename from docs/tutorials/8_train_custom_env.ipynb rename to docs/tutorials/10_train_custom_env.ipynb index 6c9e28726..e6a39cd87 100644 --- a/docs/tutorials/8_train_custom_env.ipynb +++ b/docs/tutorials/10_train_custom_env.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_custom_env.ipynb)\n", + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/10_train_custom_env.ipynb)\n", "# Train Behavior Cloning in a Custom Environment\n", "\n", "You can use `imitation` to train a policy (and, for many imitation learning algorithm, learn rewards) in a custom environment.\n", diff --git a/docs/tutorials/10_train_sqil.ipynb b/docs/tutorials/8_train_sqil.ipynb similarity index 98% rename from docs/tutorials/10_train_sqil.ipynb rename to docs/tutorials/8_train_sqil.ipynb index ddbbeb98a..1ada90a69 100644 --- a/docs/tutorials/10_train_sqil.ipynb +++ b/docs/tutorials/8_train_sqil.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/10_train_sqil.ipynb)\n", + "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_sqil.ipynb)\n", "# Train an Agent using Soft Q Imitation Learning\n", "\n", "Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n", From 5b23f8421321cc6b2e3cbe8fc39d40575be37137 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 13:35:12 +0200 Subject: [PATCH 20/57] Change the SQIL structure to instead subclass the replay buffer, new test --- src/imitation/algorithms/sqil.py | 260 ++++++++++++++----------------- src/imitation/util/util.py | 15 ++ tests/algorithms/test_sqil.py | 59 +++++-- 3 files changed, 178 insertions(+), 156 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 456682cbd..8959ae18a 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -7,7 +7,7 @@ import numpy as np import torch as th -import torch.nn.functional as F +from gym import spaces from stable_baselines3 import dqn from stable_baselines3.common import buffers, policies, type_aliases, vec_env from stable_baselines3.dqn.policies import DQNPolicy @@ -107,8 +107,6 @@ def __init__( """ self.venv = venv - super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) - self.dqn = dqn.DQN( policy=policy, env=venv, @@ -120,8 +118,8 @@ def __init__( gamma=gamma, train_freq=train_freq, gradient_steps=gradient_steps, - replay_buffer_class=replay_buffer_class, - replay_buffer_kwargs=replay_buffer_kwargs, + replay_buffer_class=SQILReplayBuffer, + replay_buffer_kwargs={"demonstrations": demonstrations}, optimize_memory_usage=optimize_memory_usage, target_update_interval=target_update_interval, exploration_fraction=exploration_fraction, @@ -136,7 +134,91 @@ def __init__( _init_setup_model=_init_setup_model, ) + super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) + def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: + assert isinstance(self.dqn.replay_buffer, SQILReplayBuffer) + self.dqn.replay_buffer.set_demonstrations(demonstrations) + + def train(self, *, total_timesteps: int): + self.dqn.learn(total_timesteps=total_timesteps) + + @property + def policy(self) -> policies.BasePolicy: + assert isinstance(self.dqn.policy, policies.BasePolicy) + return self.dqn.policy + + +class SQILReplayBuffer(buffers.ReplayBuffer): + """Replay buffer used in off-policy algorithms like SAC/TD3. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param n_envs: Number of parallel environments + :param optimize_memory_usage: Enable a memory efficient variant + of the replay buffer which reduces by almost a factor two the memory used, + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + Cannot be used in combination with handle_timeout_termination. + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. + https://github.com/DLR-RM/stable-baselines3/issues/284 + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + demonstrations: algo_base.AnyTransitions, + device: Union[th.device, str] = "auto", + n_envs: int = 1, + optimize_memory_usage: bool = False, + ): + """A modification of the SB3 ReplayBuffer. + + This buffer is fundamentally the same as ReplayBuffer, + but it includes an expert demonstration internal buffer. + When sampling a batch of data, it will be 50/50 expert and collected data. + + Args: + buffer_size: Max number of element in the buffer + observation_space: Observation space + action_space: Action space + demonstrations: Expert demonstrations. + device: PyTorch device. + n_envs: Number of parallel environments. Defaults to 1. + optimize_memory_usage: Enable a memory efficient variant + of the replay buffer which reduces by almost a factor two + the memory used, at a cost of more complexity. + """ + super().__init__( + buffer_size, + observation_space, + action_space, + device, + n_envs, + optimize_memory_usage, + handle_timeout_termination=False, + ) + + self.expert_buffer = self.set_demonstrations(demonstrations) + + def set_demonstrations( + self, + demonstrations: algo_base.AnyTransitions, + ) -> buffers.ReplayBuffer: + """Set the demonstrations to be used in the buffer. + + Args: + demonstrations (algo_base.AnyTransitions): Expert demonstrations. + + Returns: + buffers.ReplayBuffer: The buffer with demonstrations added + """ # If demonstrations is a list of trajectories, # flatten it into a list of transitions if isinstance(demonstrations, Iterable): @@ -152,15 +234,15 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: ) n_samples = len(demonstrations) # type: ignore[arg-type] - self.expert_buffer = buffers.ReplayBuffer( + expert_buffer = buffers.ReplayBuffer( n_samples, - self.venv.observation_space, - self.venv.action_space, + self.observation_space, + self.action_space, handle_timeout_termination=False, ) for transition in demonstrations: - self.expert_buffer.add( + expert_buffer.add( obs=np.array(transition["obs"]), # type: ignore[index] next_obs=np.array(transition["next_obs"]), # type: ignore[index] action=np.array(transition["acts"]), # type: ignore[index] @@ -169,146 +251,40 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: infos=[{}], ) - def train(self, *, total_timesteps: int): - self.learn_dqn(total_timesteps=total_timesteps) + return expert_buffer - @property - def policy(self) -> policies.BasePolicy: - assert isinstance(self.dqn.policy, policies.BasePolicy) - return self.dqn.policy - - def train_dqn(self, gradient_steps: int, batch_size: int = 100) -> None: - - # Needed to make mypy happy, because SB3 typing is shoddy - assert isinstance(self.dqn.policy, policies.BasePolicy) - - # Switch to train mode (this affects batch norm / dropout) - self.dqn.policy.set_training_mode(True) - # Update learning rate according to type_aliases.Schedule - self.dqn._update_learning_rate(self.dqn.policy.optimizer) - - losses = [] - for _ in range(gradient_steps): - # Sample replay buffer - new_data = self.dqn.replay_buffer.sample( # type: ignore[union-attr] - batch_size // 2, - env=self.dqn._vec_normalize_env, - ) - new_data.rewards.zero_() # Zero out the rewards - - expert_data = self.expert_buffer.sample( - batch_size // 2, - env=self.dqn._vec_normalize_env, - ) - - expert_data.rewards.fill_(1) # Fill the rewards with 1 + def sample( + self, + batch_size: int, + env: Optional[vec_env.VecNormalize] = None, + ) -> buffers.ReplayBufferSamples: + """Sample a batch of data. - # Concatenate the two batches of data - replay_data = type_aliases.ReplayBufferSamples( - *( - th.cat((getattr(new_data, name), getattr(expert_data, name))) - for name in new_data._fields - ), - ) + Half of the batch will be from the expert buffer, + and the other half will be from the collected data. - with th.no_grad(): - # Compute the next Q-values using the target network - next_q_values = self.dqn.q_net_target(replay_data.next_observations) - # Follow greedy policy: use the one with the highest value - next_q_values, _ = next_q_values.max(dim=1) - # Avoid potential broadcast issue - next_q_values = next_q_values.reshape(-1, 1) - # 1-step TD target - target_q_values = ( - replay_data.rewards - + (1 - replay_data.dones) * self.dqn.gamma * next_q_values - ) - - # Get current Q-values estimates - current_q_values = self.dqn.q_net(replay_data.observations) + Args: + batch_size: Number of element to sample in total + env: associated gym VecEnv to normalize the observations/rewards + when sampling - # Retrieve the q-values for the actions from the replay buffer - current_q_values = th.gather( - current_q_values, - dim=1, - index=replay_data.actions.long(), - ) + Returns: + A batch of samples for DQN - # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q_values, target_q_values) - losses.append(loss.item()) - - # Optimize the policy - self.dqn.policy.optimizer.zero_grad() - loss.backward() - # Clip gradient norm - # For some reason pytype doesn't see nn.utils, so adding a type ignore - th.nn.utils.clip_grad_norm_( # type: ignore[module-attr] - self.dqn.policy.parameters(), - self.dqn.max_grad_norm, - ) - self.dqn.policy.optimizer.step() + """ + new_sample_size, expert_sample_size = util.split_in_half(batch_size) - # Increase update counter - self.dqn._n_updates += gradient_steps + new_sample = super().sample(new_sample_size, env) + new_sample.rewards.fill_(0) - self.dqn.logger.record( - "train/n_updates", - self.dqn._n_updates, - exclude="tensorboard", - ) - self.dqn.logger.record("train/loss", np.mean(losses)) + expert_sample = self.expert_buffer.sample(expert_sample_size, env) + expert_sample.rewards.fill_(1) - def learn_dqn( - self, - total_timesteps: int, - callback: type_aliases.MaybeCallback = None, - log_interval: int = 4, - tb_log_name: str = "run", - reset_num_timesteps: bool = True, - progress_bar: bool = False, - ) -> None: - - total_timesteps, callback = self.dqn._setup_learn( - total_timesteps, - callback, - reset_num_timesteps, - tb_log_name, - progress_bar, + replay_data = type_aliases.ReplayBufferSamples( + *( + th.cat((getattr(new_sample, name), getattr(expert_sample, name))) + for name in new_sample._fields + ), ) - callback.on_training_start(locals(), globals()) - - while self.dqn.num_timesteps < total_timesteps: - rollout = self.dqn.collect_rollouts( - self.dqn.env, # type: ignore[arg-type] # This is from SB3 code - train_freq=self.dqn.train_freq, # type: ignore[arg-type] # SB3 - action_noise=self.dqn.action_noise, - callback=callback, - learning_starts=self.dqn.learning_starts, - replay_buffer=self.dqn.replay_buffer, # type: ignore[arg-type] # SB3 - log_interval=log_interval, - ) - - if rollout.continue_training is False: - break - - if ( - self.dqn.num_timesteps > 0 - and self.dqn.num_timesteps > self.dqn.learning_starts - ): - # If no `gradient_steps` is specified, - # do as many gradients steps as steps performed during the rollout - gradient_steps = ( - self.dqn.gradient_steps - if self.dqn.gradient_steps >= 0 - else rollout.episode_timesteps - ) - # Special case when the user passes `gradient_steps=0` - if gradient_steps > 0: - self.train_dqn( - batch_size=self.dqn.batch_size, - gradient_steps=gradient_steps, - ) - - callback.on_training_end() + return replay_data diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index f2e737b8e..83696028d 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -445,3 +445,18 @@ def parse_optional_path( return None else: return parse_path(path, allow_relative, base_directory) + + +def split_in_half(x: int) -> Tuple[int, int]: + """Split an integer in half, rounding up. + + This is to ensure that the two halves sum to the original integer. + + Args: + x: The integer to split. + + Returns: + A tuple containing the two halves of `x`. + """ + half = x // 2 + return half, x - half diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 457a3b73e..ad696ee45 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -4,6 +4,7 @@ import numpy as np import stable_baselines3.common.buffers as buffers import stable_baselines3.common.vec_env as vec_env +import stable_baselines3.dqn as dqn from imitation.algorithms import sqil from imitation.data import rollout, wrappers @@ -14,9 +15,8 @@ def test_sqil_demonstration_buffer(rng): venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) policy = "MlpPolicy" - sampling_agent = sqil.SQIL( - venv=venv, - demonstrations=None, + sampling_agent = dqn.DQN( + env=env, policy=policy, ) @@ -34,12 +34,15 @@ def test_sqil_demonstration_buffer(rng): policy=policy, ) + assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) + expert_buffer = model.dqn.replay_buffer.expert_buffer + # Check that demonstrations are stored in the replay buffer correctly for i in range(len(demonstrations)): - obs = model.expert_buffer.observations[i] - act = model.expert_buffer.actions[i] - next_obs = model.expert_buffer.next_observations[i] - done = model.expert_buffer.dones[i] + obs = expert_buffer.observations[i] + act = expert_buffer.actions[i] + next_obs = expert_buffer.next_observations[i] + done = expert_buffer.dones[i] np.testing.assert_array_equal(obs[0], demonstrations.obs[i]) np.testing.assert_array_equal(act[0], demonstrations.acts[i]) @@ -52,9 +55,8 @@ def test_sqil_demonstration_without_flatten(rng): venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) policy = "MlpPolicy" - sampling_agent = sqil.SQIL( - venv=venv, - demonstrations=None, + sampling_agent = dqn.DQN( + env=env, policy=policy, ) @@ -71,7 +73,8 @@ def test_sqil_demonstration_without_flatten(rng): policy=policy, ) - assert isinstance(model.expert_buffer, buffers.ReplayBuffer) + assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) + assert isinstance(model.dqn.replay_buffer.expert_buffer, buffers.ReplayBuffer) def test_sqil_cartpole_no_crash(rng): @@ -79,9 +82,8 @@ def test_sqil_cartpole_no_crash(rng): venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) policy = "MlpPolicy" - sampling_agent = sqil.SQIL( - venv=venv, - demonstrations=None, + sampling_agent = dqn.DQN( + env=env, policy=policy, ) @@ -99,3 +101,32 @@ def test_sqil_cartpole_no_crash(rng): learning_starts=1000, ) model.train(total_timesteps=10_000) + + +def test_sqil_cartpole_few_demonstrations(rng): + env = gym.make("CartPole-v1") + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) + + policy = "MlpPolicy" + sampling_agent = dqn.DQN( + env=env, + policy=policy, + ) + + rollouts = rollout.rollout( + sampling_agent.policy, + venv, + rollout.make_sample_until(min_timesteps=None, min_episodes=1), + rng=rng, + ) + + demonstrations = rollout.flatten_trajectories(rollouts) + demonstrations = demonstrations[:5] + + model = sqil.SQIL( + venv=venv, + demonstrations=demonstrations, + policy=policy, + learning_starts=10, + ) + model.train(total_timesteps=1_000) From bc8152b0bd0612621d37ec03fb1803ac514efd36 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 13:48:31 +0200 Subject: [PATCH 21/57] Add an empty line --- src/imitation/algorithms/sqil.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 8959ae18a..977decd11 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -3,6 +3,7 @@ Trains a policy via DQN-style Q-learning, replacing half the buffer with expert demonstrations and adjusting the rewards. """ + from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union import numpy as np From 7d56e6a7a20901a04806c16bd69fa37cc7c1c679 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 13:57:25 +0200 Subject: [PATCH 22/57] Simplify the arguments --- src/imitation/algorithms/sqil.py | 106 ++----------------------------- tests/algorithms/test_sqil.py | 4 +- 2 files changed, 7 insertions(+), 103 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 977decd11..9bf5a247c 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -4,7 +4,7 @@ replacing half the buffer with expert demonstrations and adjusting the rewards. """ -from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, Optional, Type, Union import numpy as np import torch as th @@ -34,28 +34,7 @@ def __init__( demonstrations: Optional[algo_base.AnyTransitions], policy: Union[str, Type[DQNPolicy]], custom_logger: Optional[logger.HierarchicalLogger] = None, - learning_rate: Union[float, type_aliases.Schedule] = 1e-4, - buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 50000, - batch_size: int = 32, - tau: float = 1.0, - gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 4, - gradient_steps: int = 1, - replay_buffer_class: Optional[Type[buffers.ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, - optimize_memory_usage: bool = False, - target_update_interval: int = 10000, - exploration_fraction: float = 0.1, - exploration_initial_eps: float = 1.0, - exploration_final_eps: float = 0.05, - max_grad_norm: float = 10, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = "auto", - _init_setup_model: bool = True, + dqn_kwargs: Optional[Dict[str, Any]] = None, ): """Builds SQIL. @@ -64,75 +43,16 @@ def __init__( demonstrations: Demonstrations to use for training. policy: The policy model to use (SB3). custom_logger: Where to log to; if None (default), creates a new logger. - learning_rate: The learning rate, it can be a function - of the current progress remaining (from 1 to 0). - buffer_size: Size of the replay buffer. - learning_starts: How many steps of the model to collect transitions for - before learning starts. - batch_size: Minibatch size for each gradient update. - tau: The soft update coefficient ("Polyak update", between 0 and 1), - default 1 for hard update. - gamma: The discount factor. - train_freq: Update the model every ``train_freq`` steps. Alternatively - pass a tuple of frequency and unit - like ``(5, "step")`` or ``(2, "episode")``. - gradient_steps: How many gradient steps to do after each - rollout (see ``train_freq``). - Set to ``-1`` means to do as many gradient steps as steps done - in the environment during the rollout. - replay_buffer_class: Replay buffer class to use - (for instance ``HerReplayBuffer``). - If ``None``, it will be automatically selected. - replay_buffer_kwargs: Keyword arguments to pass - to the replay buffer on creation. - optimize_memory_usage: Enable a memory efficient variant of the - replay buffer at a cost of more complexity. - target_update_interval: Update the target network every - ``target_update_interval`` environment steps. - exploration_fraction: Fraction of entire training period over - which the exploration rate is reduced. - exploration_initial_eps: Initial value of random action probability. - exploration_final_eps: Final value of random action probability. - max_grad_norm: The maximum value for the gradient clipping. - tensorboard_log: The log location for tensorboard (if None, no logging). - policy_kwargs: Additional arguments to be passed to the policy on creation. - verbose: Verbosity level: 0 for no output, 1 for info messages - (such as device or wrappers used), 2 for debug messages. - seed: Seed for the pseudo random generators. - device: Device (cpu, cuda, ...) on which the code should be run. - Setting it to auto, the code will be run on the GPU if possible. - _init_setup_model: Whether or not to build the network - at the creation of the instance. - - + dqn_kwargs: Keyword arguments to pass to the DQN constructor. """ self.venv = venv self.dqn = dqn.DQN( policy=policy, env=venv, - learning_rate=learning_rate, - buffer_size=buffer_size, - learning_starts=learning_starts, - batch_size=batch_size, - tau=tau, - gamma=gamma, - train_freq=train_freq, - gradient_steps=gradient_steps, replay_buffer_class=SQILReplayBuffer, replay_buffer_kwargs={"demonstrations": demonstrations}, - optimize_memory_usage=optimize_memory_usage, - target_update_interval=target_update_interval, - exploration_fraction=exploration_fraction, - exploration_initial_eps=exploration_initial_eps, - exploration_final_eps=exploration_final_eps, - max_grad_norm=max_grad_norm, - tensorboard_log=tensorboard_log, - policy_kwargs=policy_kwargs, - verbose=verbose, - seed=seed, - device=device, - _init_setup_model=_init_setup_model, + **(dqn_kwargs or {}), ) super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) @@ -151,23 +71,7 @@ def policy(self) -> policies.BasePolicy: class SQILReplayBuffer(buffers.ReplayBuffer): - """Replay buffer used in off-policy algorithms like SAC/TD3. - - :param buffer_size: Max number of element in the buffer - :param observation_space: Observation space - :param action_space: Action space - :param device: PyTorch device - :param n_envs: Number of parallel environments - :param optimize_memory_usage: Enable a memory efficient variant - of the replay buffer which reduces by almost a factor two the memory used, - at a cost of more complexity. - See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 - Cannot be used in combination with handle_timeout_termination. - :param handle_timeout_termination: Handle timeout termination (due to timelimit) - separately and treat the task as infinite horizon task. - https://github.com/DLR-RM/stable-baselines3/issues/284 - """ + """Replay buffer used in off-policy algorithms like SAC/TD3, modified for SQIL.""" def __init__( self, diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index ad696ee45..7d7f82d21 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -98,7 +98,7 @@ def test_sqil_cartpole_no_crash(rng): venv=venv, demonstrations=demonstrations, policy=policy, - learning_starts=1000, + dqn_kwargs=dict(learning_starts=1000), ) model.train(total_timesteps=10_000) @@ -127,6 +127,6 @@ def test_sqil_cartpole_few_demonstrations(rng): venv=venv, demonstrations=demonstrations, policy=policy, - learning_starts=10, + dqn_kwargs=dict(learning_starts=10), ) model.train(total_timesteps=1_000) From 4e3f156a50d19defd794a3f3c55a39ad9723ba5c Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 15:52:35 +0200 Subject: [PATCH 23/57] Cover another edge case, another test, fixes --- src/imitation/algorithms/sqil.py | 20 ++++++++++------ src/imitation/data/rollout.py | 29 ++++++++++++++++++++++++ tests/algorithms/test_sqil.py | 39 ++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 9bf5a247c..ddf7169cc 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -4,7 +4,7 @@ replacing half the buffer with expert demonstrations and adjusting the rewards. """ -from typing import Any, Dict, Iterable, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union import numpy as np import torch as th @@ -126,7 +126,7 @@ def set_demonstrations( """ # If demonstrations is a list of trajectories, # flatten it into a list of transitions - if isinstance(demonstrations, Iterable): + if not isinstance(demonstrations, types.TransitionsMinimal): ( item, demonstrations, @@ -137,8 +137,14 @@ def set_demonstrations( demonstrations = rollout.flatten_trajectories( demonstrations, # type: ignore[arg-type] ) + else: # item is a TransitionMapping + demonstrations = rollout.flatten_transition_mappings( + demonstrations, # type: ignore[arg-type] + ) + + assert isinstance(demonstrations, types.Transitions) - n_samples = len(demonstrations) # type: ignore[arg-type] + n_samples = len(demonstrations) expert_buffer = buffers.ReplayBuffer( n_samples, self.observation_space, @@ -148,10 +154,10 @@ def set_demonstrations( for transition in demonstrations: expert_buffer.add( - obs=np.array(transition["obs"]), # type: ignore[index] - next_obs=np.array(transition["next_obs"]), # type: ignore[index] - action=np.array(transition["acts"]), # type: ignore[index] - done=np.array(transition["dones"]), # type: ignore[index] + obs=np.array(transition["obs"]), + next_obs=np.array(transition["next_obs"]), + action=np.array(transition["acts"]), + done=np.array(transition["dones"]), reward=np.array(1), infos=[{}], ) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index add281a65..f18c22661 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -23,6 +23,7 @@ from stable_baselines3.common.utils import check_for_correct_spaces from stable_baselines3.common.vec_env import VecEnv +from imitation.algorithms import base as algo_base from imitation.data import types @@ -565,6 +566,34 @@ def flatten_trajectories( return types.Transitions(**cat_parts) +def flatten_transition_mappings( + trajectories: Iterable[algo_base.TransitionMapping], +) -> types.Transitions: + """Flatten a series of transition mappings (e.g. a dataloader) into arrays. + + Args: + trajectories: list of trajectories. + + Returns: + The trajectories flattened into a single batch of Transitions. + """ + keys = ["obs", "next_obs", "acts", "dones", "infos"] + parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys} + for data in trajectories: + num_steps = len(data["obs"]) + for i in range(num_steps): + parts["obs"].append(data["obs"][i].detach().cpu().numpy()) + parts["next_obs"].append(data["next_obs"][i].detach().cpu().numpy()) + parts["acts"].append(data["acts"][i].detach().cpu().numpy()) + parts["dones"].append(data["dones"][i].detach().cpu().numpy()) + parts["infos"].append(data["infos"][i]) # type: ignore[arg-type] + + cat_parts = {key: np.stack(part_list, axis=0) for key, part_list in parts.items()} + lengths = set(map(len, cat_parts.values())) + assert len(lengths) == 1, f"expected one length, got {lengths}" + return types.Transitions(**cat_parts) + + def flatten_trajectories_with_rew( trajectories: Sequence[types.TrajectoryWithRew], ) -> types.TransitionsWithRew: diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 7d7f82d21..295b5f031 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -6,6 +6,7 @@ import stable_baselines3.common.vec_env as vec_env import stable_baselines3.dqn as dqn +from imitation.algorithms import base as algo_base from imitation.algorithms import sqil from imitation.data import rollout, wrappers @@ -67,6 +68,9 @@ def test_sqil_demonstration_without_flatten(rng): rng=rng, ) + flat_rollouts = rollout.flatten_trajectories(rollouts) + n_samples = len(flat_rollouts) + model = sqil.SQIL( venv=venv, demonstrations=rollouts, @@ -76,6 +80,41 @@ def test_sqil_demonstration_without_flatten(rng): assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) assert isinstance(model.dqn.replay_buffer.expert_buffer, buffers.ReplayBuffer) + assert len(model.dqn.replay_buffer.expert_buffer.observations) == n_samples + + +def test_sqil_demonstration_data_loader(rng): + env = gym.make("CartPole-v1") + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) + policy = "MlpPolicy" + + sampling_agent = dqn.DQN( + env=env, + policy=policy, + ) + + rollouts = rollout.rollout( + sampling_agent.policy, + venv, + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + + transition_mappings = algo_base.make_data_loader(rollouts, batch_size=4) + + model = sqil.SQIL( + venv=venv, + demonstrations=transition_mappings, + policy=policy, + ) + + assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) + assert isinstance(model.dqn.replay_buffer.expert_buffer, buffers.ReplayBuffer) + + assert len(model.dqn.replay_buffer.expert_buffer.observations) == sum( + len(traj["obs"]) for traj in transition_mappings + ) + def test_sqil_cartpole_no_crash(rng): env = gym.make("CartPole-v1") From d018cbd00371b3d6bdfdecf7e0388fd3fd4e9cb1 Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 16:13:46 +0200 Subject: [PATCH 24/57] Fix a circular import issue --- src/imitation/data/rollout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index f18c22661..7b50aefca 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -18,12 +18,12 @@ ) import numpy as np +import torch as th from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.utils import check_for_correct_spaces from stable_baselines3.common.vec_env import VecEnv -from imitation.algorithms import base as algo_base from imitation.data import types @@ -567,7 +567,7 @@ def flatten_trajectories( def flatten_transition_mappings( - trajectories: Iterable[algo_base.TransitionMapping], + trajectories: Iterable[Mapping[str, Union[np.ndarray, th.Tensor]]], ) -> types.Transitions: """Flatten a series of transition mappings (e.g. a dataloader) into arrays. From 29cdbfaa999f09ef2324ba33568b6aa1445d930d Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 17:02:55 +0200 Subject: [PATCH 25/57] Add a performance test - might be slow? --- tests/algorithms/test_sqil.py | 61 +++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 295b5f031..a07fa9181 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -5,10 +5,13 @@ import stable_baselines3.common.buffers as buffers import stable_baselines3.common.vec_env as vec_env import stable_baselines3.dqn as dqn +from stable_baselines3 import ppo +from stable_baselines3.common.evaluation import evaluate_policy from imitation.algorithms import base as algo_base from imitation.algorithms import sqil from imitation.data import rollout, wrappers +from imitation.testing import reward_improvement def test_sqil_demonstration_buffer(rng): @@ -169,3 +172,61 @@ def test_sqil_cartpole_few_demonstrations(rng): dqn_kwargs=dict(learning_starts=10), ) model.train(total_timesteps=1_000) + + +def test_sqil_performance(rng): + env = gym.make("CartPole-v1") + venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) + + expert = ppo.PPO( + policy=ppo.MlpPolicy, + env=env, + seed=0, + batch_size=64, + ent_coef=0.0, + learning_rate=0.0003, + n_epochs=10, + n_steps=64, + ) + expert.learn(10_000) + + expert_reward, _ = evaluate_policy(expert, env, 10) + print(expert_reward) + + rollouts = rollout.rollout( + expert.policy, + venv, + rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + ) + + demonstrations = rollout.flatten_trajectories(rollouts) + demonstrations = demonstrations[:5] + + model = sqil.SQIL( + venv=venv, + demonstrations=demonstrations, + policy="MlpPolicy", + dqn_kwargs=dict(learning_starts=1000), + ) + + rewards_before, _ = evaluate_policy( + model.policy, + env, + 10, + return_episode_rewards=True, + ) + + model.train(total_timesteps=10_000) + + rewards_after, _ = evaluate_policy( + model.policy, + env, + 10, + return_episode_rewards=True, + ) + + assert reward_improvement.is_significant_reward_improvement( + rewards_before, + rewards_after, + ) From 551fa7e361953c3194fecfb6e28bae9b773529fe Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 7 Jul 2023 17:40:55 +0200 Subject: [PATCH 26/57] Fix coverage --- tests/algorithms/test_sqil.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index a07fa9181..ec46fef46 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -6,6 +6,7 @@ import stable_baselines3.common.vec_env as vec_env import stable_baselines3.dqn as dqn from stable_baselines3 import ppo +from stable_baselines3.common import policies from stable_baselines3.common.evaluation import evaluate_policy from imitation.algorithms import base as algo_base @@ -38,6 +39,8 @@ def test_sqil_demonstration_buffer(rng): policy=policy, ) + assert isinstance(model.policy, policies.BasePolicy) + assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) expert_buffer = model.dqn.replay_buffer.expert_buffer From fcd94b95a021c6191e6fa9fe64bb8e68e223c10c Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Fri, 7 Jul 2023 18:17:50 -0700 Subject: [PATCH 27/57] Improve input validation --- src/imitation/algorithms/sqil.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index ddf7169cc..33ea30bf5 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -47,12 +47,28 @@ def __init__( """ self.venv = venv + if dqn_kwargs is None: + dqn_kwargs = {} + # SOMEDAY(adam): we could support users specifying their own replay buffer + # if we made SQILReplayBuffer a more flexible wrapper. Does not seem worth + # the added complexity until we have a concrete use case, however. + if "replay_buffer_class" in dqn_kwargs: + raise ValueError( + "SQIL uses a custom replay buffer: " + "'replay_buffer_class' not allowed." + ) + if "replay_buffer_kwargs" in dqn_kwargs: + raise ValueError( + "SQIL uses a custom replay buffer: " + "'replay_buffer_kwargs' not allowed." + ) + self.dqn = dqn.DQN( policy=policy, env=venv, replay_buffer_class=SQILReplayBuffer, replay_buffer_kwargs={"demonstrations": demonstrations}, - **(dqn_kwargs or {}), + **dqn_kwargs, ) super().__init__(demonstrations=demonstrations, custom_logger=custom_logger) From 34ddf82400718527bdb0c78f2616f27b34eaf26e Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Fri, 7 Jul 2023 18:18:09 -0700 Subject: [PATCH 28/57] Bugfix: have set_demonstrations set rather than return --- src/imitation/algorithms/sqil.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 33ea30bf5..2215a8ae0 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -126,12 +126,12 @@ def __init__( handle_timeout_termination=False, ) - self.expert_buffer = self.set_demonstrations(demonstrations) + self.set_demonstrations(demonstrations) def set_demonstrations( self, demonstrations: algo_base.AnyTransitions, - ) -> buffers.ReplayBuffer: + ) -> None: """Set the demonstrations to be used in the buffer. Args: @@ -161,7 +161,7 @@ def set_demonstrations( assert isinstance(demonstrations, types.Transitions) n_samples = len(demonstrations) - expert_buffer = buffers.ReplayBuffer( + self.expert_buffer = buffers.ReplayBuffer( n_samples, self.observation_space, self.action_space, @@ -178,8 +178,6 @@ def set_demonstrations( infos=[{}], ) - return expert_buffer - def sample( self, batch_size: int, From cf20fbbd2efe8d543f1fe86bbd038d9afab2ad1b Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 13:44:16 -0700 Subject: [PATCH 29/57] Move TransitionMapping from algorithms.base to data.types --- src/imitation/algorithms/adversarial/common.py | 4 ++-- src/imitation/algorithms/base.py | 9 ++++----- src/imitation/algorithms/bc.py | 12 ++++++------ src/imitation/algorithms/density.py | 2 +- src/imitation/algorithms/sqil.py | 1 - src/imitation/data/rollout.py | 2 +- src/imitation/data/types.py | 1 + 7 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 62b459a0d..48129fa67 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -98,8 +98,8 @@ class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]): If `debug_use_ground_truth=True` was passed into the initializer then `self.venv_train` is the same as `self.venv`.""" - _demo_data_loader: Optional[Iterable[base.TransitionMapping]] - _endless_expert_iterator: Optional[Iterator[base.TransitionMapping]] + _demo_data_loader: Optional[Iterable[types.TransitionMapping]] + _endless_expert_iterator: Optional[Iterator[types.TransitionMapping]] venv_wrapped: vec_env.VecEnvWrapper diff --git a/src/imitation/algorithms/base.py b/src/imitation/algorithms/base.py index 38d003ed0..3c6422ebf 100644 --- a/src/imitation/algorithms/base.py +++ b/src/imitation/algorithms/base.py @@ -123,11 +123,10 @@ def __setstate__(self, state): self.logger = state.get("_logger") or imit_logger.configure() -TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]] TransitionKind = TypeVar("TransitionKind", bound=types.TransitionsMinimal) AnyTransitions = Union[ Iterable[types.Trajectory], - Iterable[TransitionMapping], + Iterable[types.TransitionMapping], types.TransitionsMinimal, ] @@ -190,7 +189,7 @@ class _WrappedDataLoader: def __init__( self, - data_loader: Iterable[TransitionMapping], + data_loader: Iterable[types.TransitionMapping], expected_batch_size: int, ): """Builds _WrappedDataLoader. @@ -202,7 +201,7 @@ def __init__( self.data_loader = data_loader self.expected_batch_size = expected_batch_size - def __iter__(self) -> Iterator[TransitionMapping]: + def __iter__(self) -> Iterator[types.TransitionMapping]: """Yields data from `self.data_loader`, checking `self.expected_batch_size`. Yields: @@ -230,7 +229,7 @@ def make_data_loader( transitions: AnyTransitions, batch_size: int, data_loader_kwargs: Optional[Mapping[str, Any]] = None, -) -> Iterable[TransitionMapping]: +) -> Iterable[types.TransitionMapping]: """Converts demonstration data to Torch data loader. Args: diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 5e3ead089..a940d9cd9 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -38,7 +38,7 @@ class BatchIteratorWithEpochEndCallback: Will throw an exception when an epoch contains no batches. """ - batch_loader: Iterable[algo_base.TransitionMapping] + batch_loader: Iterable[types.TransitionMapping] n_epochs: Optional[int] n_batches: Optional[int] on_epoch_end: Optional[Callable[[int], None]] @@ -55,8 +55,8 @@ def __post_init__(self) -> None: "Must provide exactly one of `n_epochs` and `n_batches` arguments.", ) - def __iter__(self) -> Iterator[algo_base.TransitionMapping]: - def batch_iterator() -> Iterator[algo_base.TransitionMapping]: + def __iter__(self) -> Iterator[types.TransitionMapping]: + def batch_iterator() -> Iterator[types.TransitionMapping]: # Note: the islice here ensures we do not exceed self.n_epochs for epoch_num in itertools.islice(itertools.count(), self.n_epochs): @@ -143,8 +143,8 @@ def __call__( def enumerate_batches( - batch_it: Iterable[algo_base.TransitionMapping], -) -> Iterable[Tuple[Tuple[int, int, int], algo_base.TransitionMapping]]: + batch_it: Iterable[types.TransitionMapping], +) -> Iterable[Tuple[Tuple[int, int, int], types.TransitionMapping]]: """Prepends batch stats before the batches of a batch iterator.""" num_samples_so_far = 0 for num_batches, batch in enumerate(batch_it): @@ -308,7 +308,7 @@ def __init__( parameter `l2_weight` instead), or if the batch size is not a multiple of the minibatch size. """ - self._demo_data_loader: Optional[Iterable[algo_base.TransitionMapping]] = None + self._demo_data_loader: Optional[Iterable[types.TransitionMapping]] = None self.batch_size = batch_size self.minibatch_size = minibatch_size or batch_size if self.batch_size % self.minibatch_size != 0: diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index abbf155be..fcc5e5ac9 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -198,7 +198,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: transitions.setdefault(i, []).append(flat_trans) elif isinstance(first_item, Mapping): # analogous to cast above. - demonstrations = cast(Iterable[base.TransitionMapping], demonstrations) + demonstrations = cast(Iterable[types.TransitionMapping], demonstrations) for batch in demonstrations: transitions.update( diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 2215a8ae0..df5d558c6 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -195,7 +195,6 @@ def sample( Returns: A batch of samples for DQN - """ new_sample_size, expert_sample_size = util.split_in_half(batch_size) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 7b50aefca..75f312c64 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -567,7 +567,7 @@ def flatten_trajectories( def flatten_transition_mappings( - trajectories: Iterable[Mapping[str, Union[np.ndarray, th.Tensor]]], + trajectories: Iterable[types.TransitionMapping], ) -> types.Transitions: """Flatten a series of transition mappings (e.g. a dataloader) into arrays. diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 26b6bf15c..97d1b950b 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -22,6 +22,7 @@ T = TypeVar("T") AnyPath = Union[str, bytes, os.PathLike] +TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]] def dataclass_quick_asdict(obj) -> Dict[str, Any]: From ee1681825041332b483284c7af4529bc7bbaec09 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 13:48:22 -0700 Subject: [PATCH 30/57] Fix typo: expert_buffer->self.expert_buffer --- src/imitation/algorithms/sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index df5d558c6..750989e78 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -169,7 +169,7 @@ def set_demonstrations( ) for transition in demonstrations: - expert_buffer.add( + self.expert_buffer.add( obs=np.array(transition["obs"]), next_obs=np.array(transition["next_obs"]), action=np.array(transition["acts"]), From 87876aaf384a395c975f35f643ca57acb2e21b83 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 13:51:26 -0700 Subject: [PATCH 31/57] Bugfix: use safe_to_numpy rather than assuming th.Tensor --- src/imitation/data/rollout.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 75f312c64..6e9694cf7 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -25,6 +25,7 @@ from stable_baselines3.common.vec_env import VecEnv from imitation.data import types +from imitation.util import util def unwrap_traj(traj: types.TrajectoryWithRew) -> types.TrajectoryWithRew: @@ -582,10 +583,10 @@ def flatten_transition_mappings( for data in trajectories: num_steps = len(data["obs"]) for i in range(num_steps): - parts["obs"].append(data["obs"][i].detach().cpu().numpy()) - parts["next_obs"].append(data["next_obs"][i].detach().cpu().numpy()) - parts["acts"].append(data["acts"][i].detach().cpu().numpy()) - parts["dones"].append(data["dones"][i].detach().cpu().numpy()) + parts["obs"].append(util.safe_to_numpy(data["obs"][i])) + parts["next_obs"].append(util.safe_to_numpy(data["next_obs"][i])) + parts["acts"].append(util.safe_to_numpy(data["acts"][i])) + parts["dones"].append(util.safe_to_numpy(data["dones"][i])) parts["infos"].append(data["infos"][i]) # type: ignore[arg-type] cat_parts = {key: np.stack(part_list, axis=0) for key, part_list in parts.items()} From 12e30b16e98831518389f6da4fbe43349771b058 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 13:54:42 -0700 Subject: [PATCH 32/57] Fix lint --- src/imitation/algorithms/sqil.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 750989e78..59ffe042f 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -44,6 +44,10 @@ def __init__( policy: The policy model to use (SB3). custom_logger: Where to log to; if None (default), creates a new logger. dqn_kwargs: Keyword arguments to pass to the DQN constructor. + + Raises: + ValueError: if `dqn_kwargs` includes a key + `replay_buffer_class` or `replay_buffer_kwargs`. """ self.venv = venv @@ -55,12 +59,12 @@ def __init__( if "replay_buffer_class" in dqn_kwargs: raise ValueError( "SQIL uses a custom replay buffer: " - "'replay_buffer_class' not allowed." + "'replay_buffer_class' not allowed.", ) if "replay_buffer_kwargs" in dqn_kwargs: raise ValueError( "SQIL uses a custom replay buffer: " - "'replay_buffer_kwargs' not allowed." + "'replay_buffer_kwargs' not allowed.", ) self.dqn = dqn.DQN( @@ -136,9 +140,6 @@ def set_demonstrations( Args: demonstrations (algo_base.AnyTransitions): Expert demonstrations. - - Returns: - buffers.ReplayBuffer: The buffer with demonstrations added """ # If demonstrations is a list of trajectories, # flatten it into a list of transitions From 90a3a798eca42396a0d1e7ec2df4f9689e88b0d9 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 14:18:19 -0700 Subject: [PATCH 33/57] Fix unused imports --- src/imitation/algorithms/base.py | 2 -- src/imitation/data/rollout.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/imitation/algorithms/base.py b/src/imitation/algorithms/base.py index 3c6422ebf..fd33c5f40 100644 --- a/src/imitation/algorithms/base.py +++ b/src/imitation/algorithms/base.py @@ -13,8 +13,6 @@ cast, ) -import numpy as np -import torch as th import torch.utils.data as th_data from stable_baselines3.common import policies diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 6e9694cf7..32376cf1a 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -18,7 +18,6 @@ ) import numpy as np -import torch as th from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.utils import check_for_correct_spaces From ef0fd269f710007ac568b810bafc9698f9af48dc Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 15:04:48 -0700 Subject: [PATCH 34/57] Refactor tests --- src/imitation/algorithms/sqil.py | 2 +- src/imitation/testing/expert_trajectories.py | 12 +- tests/algorithms/test_sqil.py | 217 +++++-------------- 3 files changed, 66 insertions(+), 165 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 59ffe042f..4bbfcd871 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -18,7 +18,7 @@ from imitation.util import logger, util -class SQIL(algo_base.DemonstrationAlgorithm): +class SQIL(algo_base.DemonstrationAlgorithm[types.Transitions]): """Soft Q Imitation Learning (SQIL). Trains a policy via DQN-style Q-learning, diff --git a/src/imitation/testing/expert_trajectories.py b/src/imitation/testing/expert_trajectories.py index dc640b8c1..220f830e4 100644 --- a/src/imitation/testing/expert_trajectories.py +++ b/src/imitation/testing/expert_trajectories.py @@ -119,17 +119,19 @@ def make_expert_transition_loader( env_name: str, rng: np.random.Generator, num_trajectories: int = 1, + shuffle: bool = True, ): """Creates different kinds of PyTorch data loaders for expert transitions. Args: cache_dir: The directory to use for caching the expert trajectories. batch_size: The batch size to use for the data loader. - expert_data_type: The type of expert data to use. Can be one of "data_loader", - "ducktyped_data_loader", "transitions". + expert_data_type: The type of expert data to use. Can be one of "trajectories", + "data_loader", "ducktyped_data_loader", "transitions". env_name: The environment to generate trajectories for. rng: The random number generator to use. num_trajectories: The number of trajectories to generate. + shuffle: Whether to shuffle the dataset when creating a data loader. Raises: ValueError: If `expert_data_type` is not one of the supported types. @@ -143,6 +145,10 @@ def make_expert_transition_loader( num_trajectories, rng, ) + + if expert_data_type == "trajectories": + return trajectories + transitions = rollout.flatten_trajectories(trajectories) if len(transitions) < batch_size: # pragma: no cover @@ -163,7 +169,7 @@ def make_expert_transition_loader( return th_data.DataLoader( transitions, batch_size=batch_size, - shuffle=True, + shuffle=shuffle, drop_last=True, collate_fn=types.transitions_collate_fn, ) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index ec46fef46..82374cfa0 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,53 +1,56 @@ """Tests `imitation.algorithms.sqil`.""" -import gym import numpy as np -import stable_baselines3.common.buffers as buffers -import stable_baselines3.common.vec_env as vec_env -import stable_baselines3.dqn as dqn -from stable_baselines3 import ppo -from stable_baselines3.common import policies +import pytest +from stable_baselines3.common import policies, vec_env from stable_baselines3.common.evaluation import evaluate_policy -from imitation.algorithms import base as algo_base from imitation.algorithms import sqil -from imitation.data import rollout, wrappers -from imitation.testing import reward_improvement +from imitation.testing import expert_trajectories, reward_improvement +EXPERT_DATA_TYPES = ["trajectories", "data_loader", "transitions"] -def test_sqil_demonstration_buffer(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - policy = "MlpPolicy" - sampling_agent = dqn.DQN( - env=env, - policy=policy, - ) - - rollouts = rollout.rollout( - sampling_agent.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=50), +def get_demos(rng: np.random.Generator, pytestconfig: pytest.Config, data_type): + cache = pytestconfig.cache + assert cache is not None + return expert_trajectories.make_expert_transition_loader( + cache_dir=cache.mkdir("experts"), + batch_size=4, + expert_data_type=data_type, + env_name="seals/CartPole-v0", rng=rng, + num_trajectories=60, + shuffle=False, ) - demonstrations = rollout.flatten_trajectories(rollouts) + +@pytest.mark.parametrize("expert_data_type", EXPERT_DATA_TYPES) +def test_sqil_demonstration_buffer( + rng: np.random.Generator, + pytestconfig: pytest.Config, + cartpole_venv: vec_env.VecEnv, + expert_data_type: str, +): + policy = "MlpPolicy" model = sqil.SQIL( - venv=venv, - demonstrations=demonstrations, + venv=cartpole_venv, + demonstrations=get_demos(rng, pytestconfig, expert_data_type), policy=policy, ) assert isinstance(model.policy, policies.BasePolicy) - assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) expert_buffer = model.dqn.replay_buffer.expert_buffer # Check that demonstrations are stored in the replay buffer correctly - for i in range(len(demonstrations)): + demonstrations = get_demos(rng, pytestconfig, "transitions") + n_samples = len(demonstrations) + assert len(model.dqn.replay_buffer.expert_buffer.observations) == n_samples + for i in range(n_samples): obs = expert_buffer.observations[i] act = expert_buffer.actions[i] + assert expert_buffer.next_observations is not None next_obs = expert_buffer.next_observations[i] done = expert_buffer.dones[i] @@ -57,119 +60,32 @@ def test_sqil_demonstration_buffer(rng): np.testing.assert_array_equal(done, demonstrations.dones[i]) -def test_sqil_demonstration_without_flatten(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - policy = "MlpPolicy" - - sampling_agent = dqn.DQN( - env=env, - policy=policy, - ) - - rollouts = rollout.rollout( - sampling_agent.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - ) - - flat_rollouts = rollout.flatten_trajectories(rollouts) - n_samples = len(flat_rollouts) - - model = sqil.SQIL( - venv=venv, - demonstrations=rollouts, - policy=policy, - ) - - assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) - assert isinstance(model.dqn.replay_buffer.expert_buffer, buffers.ReplayBuffer) - - assert len(model.dqn.replay_buffer.expert_buffer.observations) == n_samples - - -def test_sqil_demonstration_data_loader(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - policy = "MlpPolicy" - - sampling_agent = dqn.DQN( - env=env, - policy=policy, - ) - - rollouts = rollout.rollout( - sampling_agent.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - ) - - transition_mappings = algo_base.make_data_loader(rollouts, batch_size=4) - - model = sqil.SQIL( - venv=venv, - demonstrations=transition_mappings, - policy=policy, - ) - - assert isinstance(model.dqn.replay_buffer, sqil.SQILReplayBuffer) - assert isinstance(model.dqn.replay_buffer.expert_buffer, buffers.ReplayBuffer) - - assert len(model.dqn.replay_buffer.expert_buffer.observations) == sum( - len(traj["obs"]) for traj in transition_mappings - ) - - -def test_sqil_cartpole_no_crash(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - +def test_sqil_cartpole_no_crash( + rng: np.random.Generator, + pytestconfig: pytest.Config, + cartpole_venv: vec_env.VecEnv, +): policy = "MlpPolicy" - sampling_agent = dqn.DQN( - env=env, - policy=policy, - ) - - rollouts = rollout.rollout( - sampling_agent.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - ) - demonstrations = rollout.flatten_trajectories(rollouts) model = sqil.SQIL( - venv=venv, - demonstrations=demonstrations, + venv=cartpole_venv, + demonstrations=get_demos(rng, pytestconfig, "transitions"), policy=policy, dqn_kwargs=dict(learning_starts=1000), ) model.train(total_timesteps=10_000) -def test_sqil_cartpole_few_demonstrations(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - - policy = "MlpPolicy" - sampling_agent = dqn.DQN( - env=env, - policy=policy, - ) - - rollouts = rollout.rollout( - sampling_agent.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=1), - rng=rng, - ) - - demonstrations = rollout.flatten_trajectories(rollouts) +def test_sqil_cartpole_few_demonstrations( + rng: np.random.Generator, + pytestconfig: pytest.Config, + cartpole_venv: vec_env.VecEnv, +): + demonstrations = get_demos(rng, pytestconfig, "transitions") demonstrations = demonstrations[:5] + policy = "MlpPolicy" model = sqil.SQIL( - venv=venv, + venv=cartpole_venv, demonstrations=demonstrations, policy=policy, dqn_kwargs=dict(learning_starts=10), @@ -177,37 +93,16 @@ def test_sqil_cartpole_few_demonstrations(rng): model.train(total_timesteps=1_000) -def test_sqil_performance(rng): - env = gym.make("CartPole-v1") - venv = vec_env.DummyVecEnv([lambda: wrappers.RolloutInfoWrapper(env)]) - - expert = ppo.PPO( - policy=ppo.MlpPolicy, - env=env, - seed=0, - batch_size=64, - ent_coef=0.0, - learning_rate=0.0003, - n_epochs=10, - n_steps=64, - ) - expert.learn(10_000) - - expert_reward, _ = evaluate_policy(expert, env, 10) - print(expert_reward) - - rollouts = rollout.rollout( - expert.policy, - venv, - rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - ) - - demonstrations = rollout.flatten_trajectories(rollouts) +def test_sqil_performance( + rng: np.random.Generator, + pytestconfig: pytest.Config, + cartpole_venv: vec_env.VecEnv, +): + demonstrations = get_demos(rng, pytestconfig, "transitions") demonstrations = demonstrations[:5] model = sqil.SQIL( - venv=venv, + venv=cartpole_venv, demonstrations=demonstrations, policy="MlpPolicy", dqn_kwargs=dict(learning_starts=1000), @@ -215,7 +110,7 @@ def test_sqil_performance(rng): rewards_before, _ = evaluate_policy( model.policy, - env, + cartpole_venv, 10, return_episode_rewards=True, ) @@ -224,12 +119,12 @@ def test_sqil_performance(rng): rewards_after, _ = evaluate_policy( model.policy, - env, + cartpole_venv, 10, return_episode_rewards=True, ) assert reward_improvement.is_significant_reward_improvement( - rewards_before, - rewards_after, + rewards_before, # type:ignore[arg-type] + rewards_after, # type:ignore[arg-type] ) From 34241b2aabe78b961abc309024c6021c293d2266 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Sat, 8 Jul 2023 18:01:12 -0700 Subject: [PATCH 35/57] Bump # of rollouts to try to fix MacOS flakiness --- tests/algorithms/test_sqil.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 82374cfa0..81bca0259 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -111,7 +111,7 @@ def test_sqil_performance( rewards_before, _ = evaluate_policy( model.policy, cartpole_venv, - 10, + 20, return_episode_rewards=True, ) @@ -120,7 +120,7 @@ def test_sqil_performance( rewards_after, _ = evaluate_policy( model.policy, cartpole_venv, - 10, + 20, return_episode_rewards=True, ) From c8e9df8ee9d7904d8ca2417ff8600c46e92a63f7 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 10:23:34 +0200 Subject: [PATCH 36/57] Simplify SQIL example and tutorial by 1. downloading expert trajectories instead of training an expert and sampling from the expert and 2. passing trajectories instead of transitions to SQIL. --- docs/algorithms/sqil.rst | 30 +++------ docs/tutorials/8_train_sqil.ipynb | 106 ++++++++---------------------- 2 files changed, 35 insertions(+), 101 deletions(-) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index 1c794d515..0f526f745 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -20,37 +20,25 @@ Detailed example notebook: :doc:`../tutorials/8_train_sqil` .. testcode:: :skipif: skip_doctests - import numpy as np + import datasets import gym - from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv - from stable_baselines3.ppo import MlpPolicy from imitation.algorithms import sqil - from imitation.data import rollout - from imitation.data.wrappers import RolloutInfoWrapper + from imitation.data import huggingface_utils - rng = np.random.default_rng(0) - env = gym.make("CartPole-v1") - expert = PPO(policy=MlpPolicy, env=env) - expert.learn(1000) - - rollouts = rollout.rollout( - expert, - DummyVecEnv([lambda: RolloutInfoWrapper(env)]), - rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - ) - transitions = rollout.flatten_trajectories(rollouts) + # Download some expert trajectories from the HuggingFace Datasets Hub. + dataset = datasets.load_dataset("HumanCompatibleAI/ppo-seals-CartPole-v0") + rollouts = huggingface_utils.TrajectoryDatasetSequence(dataset["train"]) sqil_trainer = sqil.SQIL( - venv=DummyVecEnv([lambda: env]), - demonstrations=transitions, + venv=DummyVecEnv([lambda: gym.make("seals:seals/CartPole-v0")]), + demonstrations=rollouts, policy="MlpPolicy", ) - sqil_trainer.train(total_timesteps=1000) - reward, _ = evaluate_policy(sqil_trainer.policy, env, 10) + sqil_trainer.train(total_timesteps=500000) + reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, 10) print("Reward:", reward) .. testoutput:: diff --git a/docs/tutorials/8_train_sqil.ipynb b/docs/tutorials/8_train_sqil.ipynb index 1ada90a69..e00389c9e 100644 --- a/docs/tutorials/8_train_sqil.ipynb +++ b/docs/tutorials/8_train_sqil.ipynb @@ -19,9 +19,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, we need an expert in CartPole-v1 so that we can sample expert trajectories.\n", - "Let's train one using stable-baselines3.\n", - "\n", + "First, we need some expert trajectories in our environment (`seals/CartPole-v0`).\n", "Note that you can use other environments, but the action space must be discrete for this algorithm." ] }, @@ -31,53 +29,27 @@ "metadata": {}, "outputs": [], "source": [ - "import gym\n", - "from stable_baselines3 import PPO\n", - "from stable_baselines3.ppo import MlpPolicy\n", + "import datasets\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", "\n", - "env = gym.make(\"CartPole-v1\")\n", - "expert = PPO(\n", - " policy=MlpPolicy,\n", - " env=env,\n", - " seed=0,\n", - " batch_size=64,\n", - " ent_coef=0.0,\n", - " learning_rate=0.0003,\n", - " n_epochs=10,\n", - " n_steps=64,\n", - ")\n", - "expert.learn(1_000) # Note: set to 100_000 to train a proficient expert" + "from imitation.data import huggingface_utils\n", + "\n", + "# Download some expert trajectories from the HuggingFace Datasets Hub.\n", + "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-CartPole-v0\")\n", + "\n", + "# Convert the dataset to a format usable by the imitation library.\n", + "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])" ] }, { "cell_type": "markdown", - "metadata": {}, "source": [ "Let's quickly check if the expert is any good.\n", "We usually should be able to reach a reward of 500, which is the maximum achievable value." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from stable_baselines3.common.evaluation import evaluate_policy\n", - "\n", - "reward, _ = evaluate_policy(expert, env, 10)\n", - "print(reward)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can use the expert to sample some trajectories.\n", - "We flatten them right away since we only need individual transitions.\n", - "`imitation` comes with a number of helper functions that makes collecting those transitions really easy. First we collect 50 episode rollouts, then we flatten them to just the transitions that we need for training.\n", - "Note that the rollout function requires a vectorized environment and needs the `RolloutInfoWrapper` around each of the environments." - ] + ], + "metadata": { + "collapsed": false + } }, { "cell_type": "code", @@ -86,47 +58,18 @@ "outputs": [], "source": [ "from imitation.data import rollout\n", - "from imitation.data.wrappers import RolloutInfoWrapper\n", - "from stable_baselines3.common.vec_env import DummyVecEnv\n", - "import numpy as np\n", + "trajectory_stats = rollout.rollout_stats(expert_trajectories)\n", "\n", - "venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)])\n", - "rng = np.random.default_rng()\n", - "rollouts = rollout.rollout(\n", - " expert,\n", - " venv,\n", - " rollout.make_sample_until(min_timesteps=None, min_episodes=100),\n", - " rng=rng,\n", - ")\n", - "transitions = rollout.flatten_trajectories(rollouts)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's have a quick look at what we just generated using those library functions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\n", - " f\"\"\"The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.\n", - "After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.\n", - "The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}.\"\n", - "\"\"\"\n", - ")" + "print(f\"We have {trajectory_stats['n_traj']} trajectories.\"\n", + " f\"The average length of each trajectory is {trajectory_stats['len_mean']}.\"\n", + " f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "After we collected our transitions, it's time to set up our behavior cloning algorithm." + "After we collected our expert trajectories, it's time to set up our behavior cloning algorithm." ] }, { @@ -136,10 +79,12 @@ "outputs": [], "source": [ "from imitation.algorithms import sqil\n", + "import gym\n", "\n", + "venv = DummyVecEnv([lambda: gym.make(\"seals:seals/CartPole-v0\")])\n", "sqil_trainer = sqil.SQIL(\n", " venv=venv,\n", - " demonstrations=transitions,\n", + " demonstrations=expert_trajectories,\n", " policy=\"MlpPolicy\",\n", ")" ] @@ -157,7 +102,8 @@ "metadata": {}, "outputs": [], "source": [ - "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", + "from stable_baselines3.common.evaluation import evaluate_policy\n", + "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n", "print(f\"Reward before training: {reward_before_training}\")" ] }, @@ -175,9 +121,9 @@ "outputs": [], "source": [ "sqil_trainer.train(\n", - " total_timesteps=1_000\n", + " total_timesteps=1_000,\n", ") # Note: set to 1_000_000 to obtain good results\n", - "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, env, 10)\n", + "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n", "print(f\"Reward after training: {reward_after_training}\")" ] } From e4e5d9fb79c356459558e891f27bdac2eaefbbff Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 10:42:51 +0200 Subject: [PATCH 37/57] Improve docstring of SQILReplayBuffer. --- src/imitation/algorithms/sqil.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 4bbfcd871..3f4d40aa6 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -91,7 +91,16 @@ def policy(self) -> policies.BasePolicy: class SQILReplayBuffer(buffers.ReplayBuffer): - """Replay buffer used in off-policy algorithms like SAC/TD3, modified for SQIL.""" + """A replay buffer that injects 50% expert demonstrations when sampling. + + This buffer is fundamentally the same as ReplayBuffer, + but it includes an expert demonstration internal buffer. + When sampling a batch of data, it will be 50/50 expert and collected data. + + It can be used in off-policy algorithms like DQN/SAC/TD3. + + Here it is used as part of SQIL, where it is used to train a DQN. + """ def __init__( self, @@ -103,14 +112,10 @@ def __init__( n_envs: int = 1, optimize_memory_usage: bool = False, ): - """A modification of the SB3 ReplayBuffer. - - This buffer is fundamentally the same as ReplayBuffer, - but it includes an expert demonstration internal buffer. - When sampling a batch of data, it will be 50/50 expert and collected data. + """Create a SQILReplayBuffer instance. Args: - buffer_size: Max number of element in the buffer + buffer_size: Max number of elements in the buffer observation_space: Observation space action_space: Action space demonstrations: Expert demonstrations. @@ -136,7 +141,7 @@ def set_demonstrations( self, demonstrations: algo_base.AnyTransitions, ) -> None: - """Set the demonstrations to be used in the buffer. + """Set the expert demonstrations to be injected when sampling from the buffer. Args: demonstrations (algo_base.AnyTransitions): Expert demonstrations. From b89e5d8061a65f941e21b8c39c822578023226c3 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 15:42:00 +0200 Subject: [PATCH 38/57] Set the expert_buffer in the constructor. --- src/imitation/algorithms/sqil.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 3f4d40aa6..4f3ed5a16 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -135,6 +135,7 @@ def __init__( handle_timeout_termination=False, ) + self.expert_buffer = buffers.ReplayBuffer(0, observation_space, action_space) self.set_demonstrations(demonstrations) def set_demonstrations( From c7723e5b148e88d70b5b4fe6276dfef2f32511e6 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 15:45:12 +0200 Subject: [PATCH 39/57] Consistently set expert transition reward to 1 and learner transition reward to 0 when adding them to the SQILReplayBuffer instead of modifying them on-the-fly when sampling. --- src/imitation/algorithms/sqil.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 4f3ed5a16..f55c272d6 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -4,7 +4,7 @@ replacing half the buffer with expert demonstrations and adjusting the rewards. """ -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union, List import numpy as np import torch as th @@ -177,14 +177,25 @@ def set_demonstrations( for transition in demonstrations: self.expert_buffer.add( - obs=np.array(transition["obs"]), - next_obs=np.array(transition["next_obs"]), - action=np.array(transition["acts"]), - done=np.array(transition["dones"]), - reward=np.array(1), + obs=transition["obs"], + next_obs=transition["next_obs"], + action=transition["acts"], + done=transition["dones"], + reward=1, infos=[{}], ) + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + super().add(obs, next_obs, action, 0, done, infos) + def sample( self, batch_size: int, @@ -204,18 +215,12 @@ def sample( A batch of samples for DQN """ new_sample_size, expert_sample_size = util.split_in_half(batch_size) - new_sample = super().sample(new_sample_size, env) - new_sample.rewards.fill_(0) - expert_sample = self.expert_buffer.sample(expert_sample_size, env) - expert_sample.rewards.fill_(1) - replay_data = type_aliases.ReplayBufferSamples( + return type_aliases.ReplayBufferSamples( *( th.cat((getattr(new_sample, name), getattr(expert_sample, name))) for name in new_sample._fields ), ) - - return replay_data From e0bc16d6b5c07fbe70b97b86ba5c40776fa1f2c6 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 15:46:28 +0200 Subject: [PATCH 40/57] Fix docstring of SQILReplayBuffer.sample() --- src/imitation/algorithms/sqil.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index f55c272d6..87f7fb62e 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -203,16 +203,16 @@ def sample( ) -> buffers.ReplayBufferSamples: """Sample a batch of data. - Half of the batch will be from the expert buffer, - and the other half will be from the collected data. + Half of the batch will be from expert transitions, + and the other half will be from the learner transitions. Args: - batch_size: Number of element to sample in total + batch_size: Number of elements to sample in total env: associated gym VecEnv to normalize the observations/rewards when sampling Returns: - A batch of samples for DQN + A mix of transitions from the expert and from the learner. """ new_sample_size, expert_sample_size = util.split_in_half(batch_size) new_sample = super().sample(new_sample_size, env) From 203c89feefcfa196041fe2afc6ac5e4460ad38fb Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 16:52:00 +0200 Subject: [PATCH 41/57] Switch back to the CartPole-v1 environment in the SQIL examples --- docs/algorithms/sqil.rst | 4 ++-- docs/tutorials/8_train_sqil.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index 0f526f745..a17fe6527 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -29,11 +29,11 @@ Detailed example notebook: :doc:`../tutorials/8_train_sqil` from imitation.data import huggingface_utils # Download some expert trajectories from the HuggingFace Datasets Hub. - dataset = datasets.load_dataset("HumanCompatibleAI/ppo-seals-CartPole-v0") + dataset = datasets.load_dataset("HumanCompatibleAI/ppo-CartPole-v1") rollouts = huggingface_utils.TrajectoryDatasetSequence(dataset["train"]) sqil_trainer = sqil.SQIL( - venv=DummyVecEnv([lambda: gym.make("seals:seals/CartPole-v0")]), + venv=DummyVecEnv([lambda: gym.make("CartPole-v1")]), demonstrations=rollouts, policy="MlpPolicy", ) diff --git a/docs/tutorials/8_train_sqil.ipynb b/docs/tutorials/8_train_sqil.ipynb index e00389c9e..8d97b6371 100644 --- a/docs/tutorials/8_train_sqil.ipynb +++ b/docs/tutorials/8_train_sqil.ipynb @@ -35,7 +35,7 @@ "from imitation.data import huggingface_utils\n", "\n", "# Download some expert trajectories from the HuggingFace Datasets Hub.\n", - "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-seals-CartPole-v0\")\n", + "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-CartPole-v1\")\n", "\n", "# Convert the dataset to a format usable by the imitation library.\n", "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])" @@ -81,7 +81,7 @@ "from imitation.algorithms import sqil\n", "import gym\n", "\n", - "venv = DummyVecEnv([lambda: gym.make(\"seals:seals/CartPole-v0\")])\n", + "venv = DummyVecEnv([lambda: gym.make(\"CartPole-v1\")])\n", "sqil_trainer = sqil.SQIL(\n", " venv=venv,\n", " demonstrations=expert_trajectories,\n", From c14938535bbce879e573af231a43536197d82e7f Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 16:52:28 +0200 Subject: [PATCH 42/57] Only train for 1k steps in the SQIL example so the doctests don't run for too long. --- docs/algorithms/sqil.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index a17fe6527..f28098506 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -37,7 +37,8 @@ Detailed example notebook: :doc:`../tutorials/8_train_sqil` demonstrations=rollouts, policy="MlpPolicy", ) - sqil_trainer.train(total_timesteps=500000) + # Hint: set to 1_000_000 to match the expert performance. + sqil_trainer.train(total_timesteps=1_000) reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, 10) print("Reward:", reward) From 18a66227182b59380ff28a90dfd0fb9c157037fe Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 17:04:58 +0200 Subject: [PATCH 43/57] Fix cell metadata for tutorial notebook. --- docs/tutorials/8_train_sqil.ipynb | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/tutorials/8_train_sqil.ipynb b/docs/tutorials/8_train_sqil.ipynb index 8d97b6371..8dbb34b5e 100644 --- a/docs/tutorials/8_train_sqil.ipynb +++ b/docs/tutorials/8_train_sqil.ipynb @@ -43,13 +43,11 @@ }, { "cell_type": "markdown", + "metadata": {}, "source": [ "Let's quickly check if the expert is any good.\n", "We usually should be able to reach a reward of 500, which is the maximum achievable value." - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", @@ -133,7 +131,7 @@ "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f" }, "kernelspec": { - "display_name": "Python 3.8.10 64-bit ('venv': venv)", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, From 9c5b91ca47a8363678d4fd7e32a1e769e80ca953 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 17:15:16 +0200 Subject: [PATCH 44/57] Notebook formatting fixes. --- docs/tutorials/8_train_sqil.ipynb | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/8_train_sqil.ipynb b/docs/tutorials/8_train_sqil.ipynb index 8dbb34b5e..22f0899f3 100644 --- a/docs/tutorials/8_train_sqil.ipynb +++ b/docs/tutorials/8_train_sqil.ipynb @@ -56,11 +56,14 @@ "outputs": [], "source": [ "from imitation.data import rollout\n", + "\n", "trajectory_stats = rollout.rollout_stats(expert_trajectories)\n", "\n", - "print(f\"We have {trajectory_stats['n_traj']} trajectories.\"\n", - " f\"The average length of each trajectory is {trajectory_stats['len_mean']}.\"\n", - " f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\")" + "print(\n", + " f\"We have {trajectory_stats['n_traj']} trajectories.\"\n", + " f\"The average length of each trajectory is {trajectory_stats['len_mean']}.\"\n", + " f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\"\n", + ")" ] }, { @@ -101,6 +104,7 @@ "outputs": [], "source": [ "from stable_baselines3.common.evaluation import evaluate_policy\n", + "\n", "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n", "print(f\"Reward before training: {reward_before_training}\")" ] From f8584c3aec81676ea728a27a85ea2e57accb2bea Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 17:17:23 +0200 Subject: [PATCH 45/57] Fix typing error in SQIL implementation. --- src/imitation/algorithms/sqil.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 87f7fb62e..712d9f3f5 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -181,7 +181,7 @@ def set_demonstrations( next_obs=transition["next_obs"], action=transition["acts"], done=transition["dones"], - reward=1, + reward=np.array(1.0), infos=[{}], ) @@ -194,7 +194,7 @@ def add( done: np.ndarray, infos: List[Dict[str, Any]], ) -> None: - super().add(obs, next_obs, action, 0, done, infos) + super().add(obs, next_obs, action, np.array(0.0), done, infos) def sample( self, From 02f3191df40058292abac1f5ee38fc611f3e437f Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Jul 2023 19:08:32 +0200 Subject: [PATCH 46/57] Fix isort issue. --- src/imitation/algorithms/sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 712d9f3f5..ac037b113 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -4,7 +4,7 @@ replacing half the buffer with expert demonstrations and adjusting the rewards. """ -from typing import Any, Dict, Optional, Type, Union, List +from typing import Any, Dict, List, Optional, Type, Union import numpy as np import torch as th From 649de468ebe86fc98939440dbf287d0a97a1527f Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Jul 2023 16:10:59 +0200 Subject: [PATCH 47/57] Clarify that our variant of the SQIL implementation is not really "soft". --- docs/algorithms/sqil.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/algorithms/sqil.rst b/docs/algorithms/sqil.rst index f28098506..bde587ee5 100644 --- a/docs/algorithms/sqil.rst +++ b/docs/algorithms/sqil.rst @@ -12,6 +12,13 @@ environment is assigned a reward of 0. This encourages the policy to imitate the demonstrations, and to simultaneously avoid states not seen in the demonstrations. +.. note:: + + This implementation is based on the DQN implementation in Stable Baselines 3, + which does not implement the soft Q-learning and therefore does not support + continuous actions. Therefore, this implementation only supports discrete actions + and the name "soft" Q-learning could be misleading. + Example ======= From c72b088f6683f8d9755ef438bedf16abe571d512 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Jul 2023 17:11:53 +0200 Subject: [PATCH 48/57] Fix link in experts documentation. --- docs/main-concepts/experts.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/main-concepts/experts.rst b/docs/main-concepts/experts.rst index aa9caeb92..cda38cd3f 100644 --- a/docs/main-concepts/experts.rst +++ b/docs/main-concepts/experts.rst @@ -12,7 +12,7 @@ learning library. For example, BC and DAgger can learn from an expert policy and the command line interface of AIRL/GAIL allows one to specify an expert to sample demonstrations from. -In the :doc:`../getting-started/first-steps` tutorial, we first train an expert policy +In the :doc:`../getting-started/first_steps` tutorial, we first train an expert policy using the stable-baselines3 library and then imitate it's behavior using :doc:`../algorithms/bc`. In practice, you may want to load a pre-trained policy for performance reasons. From 8277a5c3ef6bddaebeb6fe3e17fc4d7af277fe29 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Jul 2023 17:37:36 +0200 Subject: [PATCH 49/57] Remove support for transition mappings. --- src/imitation/algorithms/sqil.py | 15 +++++++++------ src/imitation/data/rollout.py | 29 ----------------------------- 2 files changed, 9 insertions(+), 35 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index ac037b113..a334de894 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -146,10 +146,14 @@ def set_demonstrations( Args: demonstrations (algo_base.AnyTransitions): Expert demonstrations. + + Raises: + NotImplementedError: If `demonstrations` is not a transitions object + or a list of trajectories. """ # If demonstrations is a list of trajectories, # flatten it into a list of transitions - if not isinstance(demonstrations, types.TransitionsMinimal): + if not isinstance(demonstrations, types.Transitions): ( item, demonstrations, @@ -160,12 +164,11 @@ def set_demonstrations( demonstrations = rollout.flatten_trajectories( demonstrations, # type: ignore[arg-type] ) - else: # item is a TransitionMapping - demonstrations = rollout.flatten_transition_mappings( - demonstrations, # type: ignore[arg-type] - ) - assert isinstance(demonstrations, types.Transitions) + if not isinstance(demonstrations, types.Transitions): + raise NotImplementedError( + f"Unsupported demonstrations type: {demonstrations}", + ) n_samples = len(demonstrations) self.expert_buffer = buffers.ReplayBuffer( diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 32376cf1a..add281a65 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -24,7 +24,6 @@ from stable_baselines3.common.vec_env import VecEnv from imitation.data import types -from imitation.util import util def unwrap_traj(traj: types.TrajectoryWithRew) -> types.TrajectoryWithRew: @@ -566,34 +565,6 @@ def flatten_trajectories( return types.Transitions(**cat_parts) -def flatten_transition_mappings( - trajectories: Iterable[types.TransitionMapping], -) -> types.Transitions: - """Flatten a series of transition mappings (e.g. a dataloader) into arrays. - - Args: - trajectories: list of trajectories. - - Returns: - The trajectories flattened into a single batch of Transitions. - """ - keys = ["obs", "next_obs", "acts", "dones", "infos"] - parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys} - for data in trajectories: - num_steps = len(data["obs"]) - for i in range(num_steps): - parts["obs"].append(util.safe_to_numpy(data["obs"][i])) - parts["next_obs"].append(util.safe_to_numpy(data["next_obs"][i])) - parts["acts"].append(util.safe_to_numpy(data["acts"][i])) - parts["dones"].append(util.safe_to_numpy(data["dones"][i])) - parts["infos"].append(data["infos"][i]) # type: ignore[arg-type] - - cat_parts = {key: np.stack(part_list, axis=0) for key, part_list in parts.items()} - lengths = set(map(len, cat_parts.values())) - assert len(lengths) == 1, f"expected one length, got {lengths}" - return types.Transitions(**cat_parts) - - def flatten_trajectories_with_rew( trajectories: Sequence[types.TrajectoryWithRew], ) -> types.TransitionsWithRew: From a0af5c5dbaf788ec87f94def79bccc6e3e8842b9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 20 Jul 2023 12:30:57 +0200 Subject: [PATCH 50/57] Remove data_loader from SQIL test cases. --- tests/algorithms/test_sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 81bca0259..fd652b337 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -8,7 +8,7 @@ from imitation.algorithms import sqil from imitation.testing import expert_trajectories, reward_improvement -EXPERT_DATA_TYPES = ["trajectories", "data_loader", "transitions"] +EXPERT_DATA_TYPES = ["trajectories", "transitions"] def get_demos(rng: np.random.Generator, pytestconfig: pytest.Config, data_type): From 4ccea30095d3d0fdde7e4a59c7166bc2fa757afe Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 21 Jul 2023 14:50:20 +0200 Subject: [PATCH 51/57] Bump number of demonstrations in SQIL performance test to reduce flakiness. --- tests/algorithms/test_sqil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index fd652b337..6d8c248a7 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -99,7 +99,7 @@ def test_sqil_performance( cartpole_venv: vec_env.VecEnv, ): demonstrations = get_demos(rng, pytestconfig, "transitions") - demonstrations = demonstrations[:5] + demonstrations = demonstrations[:20] model = sqil.SQIL( venv=cartpole_venv, From 68cbce8c68a97d333271ea5fb07c368020c767f4 Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Tue, 8 Aug 2023 15:32:25 +0200 Subject: [PATCH 52/57] Adapt hyperparameters in test_sqil_performance to reduce flakiness --- tests/algorithms/test_sqil.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 6d8c248a7..7a2237777 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,5 +1,4 @@ """Tests `imitation.algorithms.sqil`.""" - import numpy as np import pytest from stable_baselines3.common import policies, vec_env @@ -99,13 +98,15 @@ def test_sqil_performance( cartpole_venv: vec_env.VecEnv, ): demonstrations = get_demos(rng, pytestconfig, "transitions") - demonstrations = demonstrations[:20] - model = sqil.SQIL( venv=cartpole_venv, demonstrations=demonstrations, policy="MlpPolicy", - dqn_kwargs=dict(learning_starts=1000), + dqn_kwargs=dict( + learning_starts=500, + learning_rate=0.002, + batch_size=220, + ), ) rewards_before, _ = evaluate_policy( From 2bf467dbc469354398f8bd0b39980a39dbf29f5d Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Tue, 8 Aug 2023 15:51:54 +0200 Subject: [PATCH 53/57] Fix seeds for flaky test_sqil_performance --- tests/algorithms/test_sqil.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 7a2237777..26417344b 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -97,6 +97,7 @@ def test_sqil_performance( pytestconfig: pytest.Config, cartpole_venv: vec_env.VecEnv, ): + SEED = 42 demonstrations = get_demos(rng, pytestconfig, "transitions") model = sqil.SQIL( venv=cartpole_venv, @@ -106,9 +107,11 @@ def test_sqil_performance( learning_starts=500, learning_rate=0.002, batch_size=220, + seed=SEED, ), ) + cartpole_venv.seed(SEED) rewards_before, _ = evaluate_policy( model.policy, cartpole_venv, @@ -118,6 +121,7 @@ def test_sqil_performance( model.train(total_timesteps=10_000) + cartpole_venv.seed(SEED) rewards_after, _ = evaluate_policy( model.policy, cartpole_venv, From ccda686f640f0d4812833d0adb73985f27501ba5 Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Tue, 8 Aug 2023 16:47:47 +0200 Subject: [PATCH 54/57] Increase coverage in test_sqil.py --- tests/algorithms/test_sqil.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 26417344b..70248a39b 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,4 +1,6 @@ """Tests `imitation.algorithms.sqil`.""" +from unittest.mock import MagicMock + import numpy as np import pytest from stable_baselines3.common import policies, vec_env @@ -133,3 +135,14 @@ def test_sqil_performance( rewards_before, # type:ignore[arg-type] rewards_after, # type:ignore[arg-type] ) + + +@pytest.mark.parametrize("illegal_kw", ["replay_buffer_class", "replay_buffer_kwargs"]) +def test_sqil_constructor_raises(illegal_kw: str): + with pytest.raises(ValueError): + sqil.SQIL( + venv=MagicMock(spec=vec_env.VecEnv), + demonstrations=None, + policy="MlpPolicy", + dqn_kwargs={illegal_kw: None}, + ) From 91b226a073127fed89e5a77122994e29e6ec45ec Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Wed, 9 Aug 2023 09:28:49 +0200 Subject: [PATCH 55/57] Pass kwargs to SQIL.train to DQN.learn - also set default tb_log_name to "SQIL" --- src/imitation/algorithms/sqil.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index a334de894..fc4668a44 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -81,8 +81,12 @@ def set_demonstrations(self, demonstrations: algo_base.AnyTransitions) -> None: assert isinstance(self.dqn.replay_buffer, SQILReplayBuffer) self.dqn.replay_buffer.set_demonstrations(demonstrations) - def train(self, *, total_timesteps: int): - self.dqn.learn(total_timesteps=total_timesteps) + def train(self, *, total_timesteps: int, tb_log_name: str = "SQIL", **kwargs: Any): + self.dqn.learn( + total_timesteps=total_timesteps, + tb_log_name=tb_log_name, + **kwargs, + ) @property def policy(self) -> policies.BasePolicy: From 5cbb6b23cf8f47c811cd39289124f843d5a537e0 Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Wed, 9 Aug 2023 09:34:53 +0200 Subject: [PATCH 56/57] Pass parameters as kwargs for multi-ary methods in sqil.py --- src/imitation/algorithms/sqil.py | 33 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index fc4668a44..85c71d890 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -130,16 +130,20 @@ def __init__( the memory used, at a cost of more complexity. """ super().__init__( - buffer_size, - observation_space, - action_space, - device, - n_envs, - optimize_memory_usage, + buffer_size=buffer_size, + observation_space=observation_space, + action_space=action_space, + device=device, + n_envs=n_envs, + optimize_memory_usage=optimize_memory_usage, handle_timeout_termination=False, ) - self.expert_buffer = buffers.ReplayBuffer(0, observation_space, action_space) + self.expert_buffer = buffers.ReplayBuffer( + buffer_size=0, + observation_space=observation_space, + action_space=action_space, + ) self.set_demonstrations(demonstrations) def set_demonstrations( @@ -176,9 +180,9 @@ def set_demonstrations( n_samples = len(demonstrations) self.expert_buffer = buffers.ReplayBuffer( - n_samples, - self.observation_space, - self.action_space, + buffer_size=n_samples, + observation_space=self.observation_space, + action_space=self.action_space, handle_timeout_termination=False, ) @@ -201,7 +205,14 @@ def add( done: np.ndarray, infos: List[Dict[str, Any]], ) -> None: - super().add(obs, next_obs, action, np.array(0.0), done, infos) + super().add( + obs=obs, + next_obs=next_obs, + action=action, + reward=np.array(0.0), + done=done, + infos=infos, + ) def sample( self, From d2124a2a6050c85e568b3b72c2d7f71016a2e342 Mon Sep 17 00:00:00 2001 From: Jason Hoelscher-Obermaier Date: Wed, 9 Aug 2023 09:45:04 +0200 Subject: [PATCH 57/57] Make test for exceptions raised by SQIL constructor more specific - also: adjust imports to conform with style guide --- tests/algorithms/test_sqil.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/algorithms/test_sqil.py b/tests/algorithms/test_sqil.py index 70248a39b..7879e344e 100644 --- a/tests/algorithms/test_sqil.py +++ b/tests/algorithms/test_sqil.py @@ -1,5 +1,5 @@ """Tests `imitation.algorithms.sqil`.""" -from unittest.mock import MagicMock +from unittest import mock import numpy as np import pytest @@ -139,9 +139,9 @@ def test_sqil_performance( @pytest.mark.parametrize("illegal_kw", ["replay_buffer_class", "replay_buffer_kwargs"]) def test_sqil_constructor_raises(illegal_kw: str): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=".*SQIL uses a custom replay buffer.*"): sqil.SQIL( - venv=MagicMock(spec=vec_env.VecEnv), + venv=mock.MagicMock(spec=vec_env.VecEnv), demonstrations=None, policy="MlpPolicy", dqn_kwargs={illegal_kw: None},