Implementation of the SQIL algorithm #744

Merged
merged 60 commits into master from redtachyon/740-sqil on Aug 10, 2023
Changes from 56 commits

Commits (60)
1935c99
Initial version of the SQIL implementation
RedTachyon Jul 4, 2023
2d4151e
Pin SB3 version to 1.7.0 (#738) (#745)
RedTachyon Jul 4, 2023
993a0d7
Another redundant type warning
RedTachyon Jul 4, 2023
899a5d8
Correctly set the expert rewards to 1
RedTachyon Jul 5, 2023
73064ac
Update typing, add some tests
RedTachyon Jul 6, 2023
b6c9d26
Update sqil.py
RedTachyon Jul 6, 2023
42d5468
Style fixes
RedTachyon Jul 6, 2023
86825d8
Test updates
RedTachyon Jul 6, 2023
95a2661
Add a test to check the buffer
RedTachyon Jul 6, 2023
67662b4
Formatting, docstring
RedTachyon Jul 6, 2023
68f693b
Improve test coverage
RedTachyon Jul 6, 2023
c4b0521
Update branch to master (#749)
RedTachyon Jul 6, 2023
1b5338b
Some documentation updates (not complete)
RedTachyon Jul 6, 2023
3c78336
Add a SQIL tutorial
RedTachyon Jul 6, 2023
c303af1
Reduce tutorial runtime
RedTachyon Jul 6, 2023
bf81940
Add SQIL description in docs, try to add it to the right places
RedTachyon Jul 6, 2023
0f95524
Merge branch 'master' into redtachyon/740-sqil
RedTachyon Jul 6, 2023
5da56f3
Fix docs
RedTachyon Jul 6, 2023
e410c39
Merge remote-tracking branch 'HCAI/redtachyon/740-sqil' into redtachy…
RedTachyon Jul 6, 2023
d8f3c30
Blacken a tutorial
RedTachyon Jul 6, 2023
ae43a75
Reorder things in docs
RedTachyon Jul 7, 2023
5b23f84
Change the SQIL structure to instead subclass the replay buffer, new …
RedTachyon Jul 7, 2023
bc8152b
Add an empty line
RedTachyon Jul 7, 2023
7d56e6a
Simplify the arguments
RedTachyon Jul 7, 2023
4e3f156
Cover another edge case, another test, fixes
RedTachyon Jul 7, 2023
d018cbd
Fix a circular import issue
RedTachyon Jul 7, 2023
29cdbfa
Add a performance test - might be slow?
RedTachyon Jul 7, 2023
551fa7e
Fix coverage
RedTachyon Jul 7, 2023
fcd94b9
Improve input validation
AdamGleave Jul 8, 2023
34ddf82
Bugfix: have set_demonstrations set rather than return
AdamGleave Jul 8, 2023
cf20fbb
Move TransitionMapping from algorithms.base to data.types
AdamGleave Jul 8, 2023
ee16818
Fix typo: expert_buffer->self.expert_buffer
AdamGleave Jul 8, 2023
87876aa
Bugfix: use safe_to_numpy rather than assuming th.Tensor
AdamGleave Jul 8, 2023
12e30b1
Fix lint
AdamGleave Jul 8, 2023
90a3a79
Fix unused imports
AdamGleave Jul 8, 2023
ef0fd26
Refactor tests
AdamGleave Jul 8, 2023
34241b2
Bump # of rollouts to try to fix MacOS flakiness
AdamGleave Jul 9, 2023
ed399d3
Merge branch 'master' into redtachyon/740-sqil
ernestum Jul 18, 2023
c8e9df8
Simplify SQIL example and tutorial by 1. downloading expert trajector…
ernestum Jul 18, 2023
e4e5d9f
Improve docstring of SQILReplayBuffer.
ernestum Jul 18, 2023
b89e5d8
Set the expert_buffer in the constructor.
ernestum Jul 18, 2023
c7723e5
Consistently set expert transition reward to 1 and learner transition…
ernestum Jul 18, 2023
e0bc16d
Fix docstring of SQILReplayBuffer.sample()
ernestum Jul 18, 2023
203c89f
Switch back to the CartPole-v1 environment in the SQIL examples
ernestum Jul 18, 2023
c149385
Only train for 1k steps in the SQIL example so the doctests don't run…
ernestum Jul 18, 2023
18a6622
Fix cell metadata for tutorial notebook.
ernestum Jul 18, 2023
9c5b91c
Notebook formatting fixes.
ernestum Jul 18, 2023
f8584c3
Fix typing error in SQIL implementation.
ernestum Jul 18, 2023
02f3191
Fix isort issue.
ernestum Jul 18, 2023
649de46
Clarify that our variant of the SQIL implementation is not really "so…
ernestum Jul 19, 2023
c72b088
Fix link in experts documentation.
ernestum Jul 19, 2023
8277a5c
Remove support for transition mappings.
ernestum Jul 19, 2023
a0af5c5
Remove data_loader from SQIL test cases.
ernestum Jul 20, 2023
4ccea30
Bump number of demonstrations in SQIL performance test to reduce flak…
ernestum Jul 21, 2023
68cbce8
Adapt hyperparameters in test_sqil_performance to reduce flakiness
jas-ho Aug 8, 2023
2bf467d
Fix seeds for flaky test_sqil_performance
jas-ho Aug 8, 2023
ccda686
Increase coverage in test_sqil.py
jas-ho Aug 8, 2023
91b226a
Pass kwargs to SQIL.train to DQN.learn
jas-ho Aug 9, 2023
5cbb6b2
Pass parameters as kwargs for multi-ary methods in sqil.py
jas-ho Aug 9, 2023
d2124a2
Make test for exceptions raised by SQIL constructor more specific
jas-ho Aug 9, 2023
1 change: 1 addition & 0 deletions README.md
@@ -17,6 +17,7 @@ Currently, we have implementations of the algorithms below. 'Discrete' and 'Cont
| [Adversarial Inverse Reinforcement Learning](https://arxiv.org/abs/1710.11248) | [`algorithms.airl`](https://imitation.readthedocs.io/en/latest/algorithms/airl.html) | ✅ | ✅ |
| [Generative Adversarial Imitation Learning](https://arxiv.org/abs/1606.03476) | [`algorithms.gail`](https://imitation.readthedocs.io/en/latest/algorithms/gail.html) | ✅ | ✅ |
| [Deep RL from Human Preferences](https://arxiv.org/abs/1706.03741) | [`algorithms.preference_comparisons`](https://imitation.readthedocs.io/en/latest/algorithms/preference_comparisons.html) | ✅ | ✅ |
| [Soft Q Imitation Learning](https://arxiv.org/abs/1905.11108) | [`algorithms.sqil`](https://imitation.readthedocs.io/en/latest/algorithms/sqil.html) | ✅ | ❌ |


You can find [the documentation here](https://imitation.readthedocs.io/en/latest/).
61 changes: 61 additions & 0 deletions docs/algorithms/sqil.rst
@@ -0,0 +1,61 @@
.. _soft q imitation learning docs:

================================
Soft Q Imitation Learning (SQIL)
================================

Soft Q Imitation Learning (SQIL) learns to imitate a policy from demonstrations
by running the DQN algorithm with modified rewards. During each policy update,
half of the batch is sampled from the demonstrations and half from the agent's
own interactions with the environment. Expert transitions are assigned a reward
of 1, while the agent's own transitions are assigned a reward of 0. This
encourages the policy to imitate the demonstrations and, at the same time, to
avoid states not seen in them.
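
The reward relabelling at the heart of SQIL can be sketched in a few lines of
plain Python. The snippet below is only an illustration of the idea, not the
code path used by this PR (where the relabelling lives in the replay-buffer
subclass ``SQILReplayBuffer``); the batch and key names are hypothetical
placeholders.

import numpy as np

def mix_sqil_batch(expert_batch, learner_batch):
    # Both batches are dicts of equally sized numpy arrays with keys such as
    # "obs", "acts", "next_obs", "dones" and "rews" (illustrative names only).
    expert = dict(expert_batch, rews=np.ones_like(expert_batch["rews"]))
    learner = dict(learner_batch, rews=np.zeros_like(learner_batch["rews"]))
    # Half of every DQN update comes from the expert, half from the learner.
    return {k: np.concatenate([expert[k], learner[k]]) for k in expert}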

.. note::

This implementation is based on the DQN implementation in Stable Baselines 3,
which does not implement soft Q-learning and therefore does not support
continuous actions. As a result, this implementation only supports discrete
action spaces, and the name "soft" Q-learning is somewhat misleading.
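
If you are unsure whether a given environment is compatible, a quick sanity
check of the action space (a sketch using standard Gym/Stable Baselines 3
attributes) looks like this:

import gym
from gym import spaces
from stable_baselines3.common.vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
# SQIL is DQN-based here, so it requires a discrete action space.
assert isinstance(venv.action_space, spaces.Discrete)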

Example
=======

Detailed example notebook: :doc:`../tutorials/8_train_sqil`

.. testcode::
:skipif: skip_doctests

import datasets
import gym
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms import sqil
from imitation.data import huggingface_utils

# Download some expert trajectories from the HuggingFace Datasets Hub.
dataset = datasets.load_dataset("HumanCompatibleAI/ppo-CartPole-v1")
rollouts = huggingface_utils.TrajectoryDatasetSequence(dataset["train"])

sqil_trainer = sqil.SQIL(
venv=DummyVecEnv([lambda: gym.make("CartPole-v1")]),
demonstrations=rollouts,
policy="MlpPolicy",
)
# Hint: set to 1_000_000 to match the expert performance.
Reviewer note (Contributor): 100_000 was already sufficient to reach expert performance (tried only a couple of times though)

sqil_trainer.train(total_timesteps=1_000)
reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, 10)
print("Reward:", reward)

.. testoutput::
:hide:

...
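
One detail worth noting from this PR: extra keyword arguments passed to
``SQIL.train`` are forwarded to ``DQN.learn`` (see commit 91b226a), so the
training call in the example above can also carry standard DQN options. A
small sketch (``log_interval`` is a ``DQN.learn`` parameter):

sqil_trainer.train(total_timesteps=1_000, log_interval=4)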

API
===
.. autoclass:: imitation.algorithms.sqil.SQIL
:members:
:noindex:
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -71,6 +71,7 @@ If you use ``imitation`` in your research project, please cite our paper to help
algorithms/density
algorithms/mce_irl
algorithms/preference_comparisons
algorithms/sqil

.. toctree::
:maxdepth: 2
@@ -85,8 +86,9 @@ If you use ``imitation`` in your research project, please cite our paper to help
tutorials/5a_train_preference_comparisons_with_cnn
tutorials/6_train_mce
tutorials/7_train_density
tutorials/8_train_custom_env
tutorials/8_train_sqil
tutorials/9_compare_baselines
tutorials/10_train_custom_env

API Reference
~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion docs/main-concepts/experts.rst
@@ -12,7 +12,7 @@ learning library.
For example, BC and DAgger can learn from an expert policy and the command line
interface of AIRL/GAIL allows one to specify an expert to sample demonstrations from.

In the :doc:`../getting-started/first-steps` tutorial, we first train an expert policy
In the :doc:`../getting-started/first_steps` tutorial, we first train an expert policy
using the stable-baselines3 library and then imitate its behavior using
:doc:`../algorithms/bc`.
In practice, you may want to load a pre-trained policy for performance reasons.
docs/tutorials/10_train_custom_env.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_custom_env.ipynb)\n",
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/10_train_custom_env.ipynb)\n",
"# Train Behavior Cloning in a Custom Environment\n",
"\n",
"You can use `imitation` to train a policy (and, for many imitation learning algorithm, learn rewards) in a custom environment.\n",
157 changes: 157 additions & 0 deletions docs/tutorials/8_train_sqil.ipynb
@@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_sqil.ipynb)\n",
"# Train an Agent using Soft Q Imitation Learning\n",
"\n",
"Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n",
"It's fundamentally a modification of the DQN algorithm. At each training step, whenever we sample a batch of data from the replay buffer,\n",
"we also sample a batch of expert data. Expert demonstrations are assigned a reward of 1, while the agent's own transitions are assigned a reward of 0.\n",
"This approach encourages the agent to imitate the expert's behavior, but also to avoid unfamiliar states.\n",
"\n",
"In this tutorial we will use the `imitation` library to train an agent using SQIL."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we need some expert trajectories in our environment (`seals/CartPole-v0`).\n",
"Note that you can use other environments, but the action space must be discrete for this algorithm."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import datasets\n",
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
"\n",
"from imitation.data import huggingface_utils\n",
"\n",
"# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
"dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-CartPole-v1\")\n",
"\n",
"# Convert the dataset to a format usable by the imitation library.\n",
"expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's quickly check if the expert is any good.\n",
"We usually should be able to reach a reward of 500, which is the maximum achievable value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from imitation.data import rollout\n",
"\n",
"trajectory_stats = rollout.rollout_stats(expert_trajectories)\n",
"\n",
"print(\n",
" f\"We have {trajectory_stats['n_traj']} trajectories.\"\n",
" f\"The average length of each trajectory is {trajectory_stats['len_mean']}.\"\n",
" f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After we collected our expert trajectories, it's time to set up our behavior cloning algorithm."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from imitation.algorithms import sqil\n",
"import gym\n",
"\n",
"venv = DummyVecEnv([lambda: gym.make(\"CartPole-v1\")])\n",
"sqil_trainer = sqil.SQIL(\n",
" venv=venv,\n",
" demonstrations=expert_trajectories,\n",
" policy=\"MlpPolicy\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see the untrained policy only gets poor rewards:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy\n",
"\n",
"reward_before_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n",
"print(f\"Reward before training: {reward_before_training}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After training, we can match the rewards of the expert (500):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sqil_trainer.train(\n",
" total_timesteps=1_000,\n",
") # Note: set to 1_000_000 to obtain good results\n",
"reward_after_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n",
"print(f\"Reward after training: {reward_after_training}\")"
]
}
],
"metadata": {
"interpreter": {
"hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
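
As a possible follow-up to the tutorial above (not part of this diff), the
trained policy can be rolled out manually through the standard Stable
Baselines 3 `predict` API. This is a sketch that assumes the `venv` and
`sqil_trainer` objects defined in the notebook:

# Interact with the environment using the policy trained in the notebook.
obs = venv.reset()
for _ in range(500):
    # `predict` is the standard SB3 inference call; it returns the action
    # and an (unused here) recurrent state.
    action, _state = sqil_trainer.policy.predict(obs, deterministic=True)
    obs, rewards, dones, infos = venv.step(action)
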
4 changes: 2 additions & 2 deletions src/imitation/algorithms/adversarial/common.py
@@ -98,8 +98,8 @@ class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
If `debug_use_ground_truth=True` was passed into the initializer then
`self.venv_train` is the same as `self.venv`."""

_demo_data_loader: Optional[Iterable[base.TransitionMapping]]
_endless_expert_iterator: Optional[Iterator[base.TransitionMapping]]
_demo_data_loader: Optional[Iterable[types.TransitionMapping]]
_endless_expert_iterator: Optional[Iterator[types.TransitionMapping]]

venv_wrapped: vec_env.VecEnvWrapper

11 changes: 4 additions & 7 deletions src/imitation/algorithms/base.py
@@ -13,8 +13,6 @@
cast,
)

import numpy as np
import torch as th
import torch.utils.data as th_data
from stable_baselines3.common import policies

@@ -123,11 +121,10 @@ def __setstate__(self, state):
self.logger = state.get("_logger") or imit_logger.configure()


TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]]
TransitionKind = TypeVar("TransitionKind", bound=types.TransitionsMinimal)
AnyTransitions = Union[
Iterable[types.Trajectory],
Iterable[TransitionMapping],
Iterable[types.TransitionMapping],
types.TransitionsMinimal,
]

@@ -190,7 +187,7 @@ class _WrappedDataLoader:

def __init__(
self,
data_loader: Iterable[TransitionMapping],
data_loader: Iterable[types.TransitionMapping],
expected_batch_size: int,
):
"""Builds _WrappedDataLoader.
@@ -202,7 +199,7 @@ def __init__(
self.data_loader = data_loader
self.expected_batch_size = expected_batch_size

def __iter__(self) -> Iterator[TransitionMapping]:
def __iter__(self) -> Iterator[types.TransitionMapping]:
"""Yields data from `self.data_loader`, checking `self.expected_batch_size`.

Yields:
@@ -230,7 +227,7 @@ def make_data_loader(
transitions: AnyTransitions,
batch_size: int,
data_loader_kwargs: Optional[Mapping[str, Any]] = None,
) -> Iterable[TransitionMapping]:
) -> Iterable[types.TransitionMapping]:
"""Converts demonstration data to Torch data loader.

Args:
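
For context on the alias being moved in this file: per the definition removed
above, a `TransitionMapping` is just a `Mapping[str, Union[np.ndarray,
th.Tensor]]`. A minimal sketch of such a batch (key names and shapes are
illustrative, chosen to resemble a CartPole batch):

import numpy as np

# An illustrative TransitionMapping-style batch of 32 transitions.
batch = {
    "obs": np.zeros((32, 4), dtype=np.float32),       # observations
    "acts": np.zeros((32,), dtype=np.int64),          # discrete actions
    "next_obs": np.zeros((32, 4), dtype=np.float32),  # successor observations
    "dones": np.zeros((32,), dtype=bool),             # episode-termination flags
}
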
12 changes: 6 additions & 6 deletions src/imitation/algorithms/bc.py
@@ -38,7 +38,7 @@ class BatchIteratorWithEpochEndCallback:
Will throw an exception when an epoch contains no batches.
"""

batch_loader: Iterable[algo_base.TransitionMapping]
batch_loader: Iterable[types.TransitionMapping]
n_epochs: Optional[int]
n_batches: Optional[int]
on_epoch_end: Optional[Callable[[int], None]]
@@ -55,8 +55,8 @@ def __post_init__(self) -> None:
"Must provide exactly one of `n_epochs` and `n_batches` arguments.",
)

def __iter__(self) -> Iterator[algo_base.TransitionMapping]:
def batch_iterator() -> Iterator[algo_base.TransitionMapping]:
def __iter__(self) -> Iterator[types.TransitionMapping]:
def batch_iterator() -> Iterator[types.TransitionMapping]:

# Note: the islice here ensures we do not exceed self.n_epochs
for epoch_num in itertools.islice(itertools.count(), self.n_epochs):
@@ -143,8 +143,8 @@ def __call__(


def enumerate_batches(
batch_it: Iterable[algo_base.TransitionMapping],
) -> Iterable[Tuple[Tuple[int, int, int], algo_base.TransitionMapping]]:
batch_it: Iterable[types.TransitionMapping],
) -> Iterable[Tuple[Tuple[int, int, int], types.TransitionMapping]]:
"""Prepends batch stats before the batches of a batch iterator."""
num_samples_so_far = 0
for num_batches, batch in enumerate(batch_it):
@@ -308,7 +308,7 @@ def __init__(
parameter `l2_weight` instead), or if the batch size is not a multiple
of the minibatch size.
"""
self._demo_data_loader: Optional[Iterable[algo_base.TransitionMapping]] = None
self._demo_data_loader: Optional[Iterable[types.TransitionMapping]] = None
self.batch_size = batch_size
self.minibatch_size = minibatch_size or batch_size
if self.batch_size % self.minibatch_size != 0:
2 changes: 1 addition & 1 deletion src/imitation/algorithms/density.py
@@ -198,7 +198,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
transitions.setdefault(i, []).append(flat_trans)
elif isinstance(first_item, Mapping):
# analogous to cast above.
demonstrations = cast(Iterable[base.TransitionMapping], demonstrations)
demonstrations = cast(Iterable[types.TransitionMapping], demonstrations)

for batch in demonstrations:
transitions.update(