Implementation of the SQIL algorithm #744
Merged

Commits (60):
1935c99  Initial version of the SQIL implementation (RedTachyon)
2d4151e  Pin SB3 version to 1.7.0 (#738) (#745) (RedTachyon)
993a0d7  Another redundant type warning (RedTachyon)
899a5d8  Correctly set the expert rewards to 1 (RedTachyon)
73064ac  Update typing, add some tests (RedTachyon)
b6c9d26  Update sqil.py (RedTachyon)
42d5468  Style fixes (RedTachyon)
86825d8  Test updates (RedTachyon)
95a2661  Add a test to check the buffer (RedTachyon)
67662b4  Formatting, docstring (RedTachyon)
68f693b  Improve test coverage (RedTachyon)
c4b0521  Update branch to master (#749) (RedTachyon)
1b5338b  Some documentation updates (not complete) (RedTachyon)
3c78336  Add a SQIL tutorial (RedTachyon)
c303af1  Reduce tutorial runtime (RedTachyon)
bf81940  Add SQIL description in docs, try to add it to the right places (RedTachyon)
0f95524  Merge branch 'master' into redtachyon/740-sqil (RedTachyon)
5da56f3  Fix docs (RedTachyon)
e410c39  Merge remote-tracking branch 'HCAI/redtachyon/740-sqil' into redtachy… (RedTachyon)
d8f3c30  Blacken a tutorial (RedTachyon)
ae43a75  Reorder things in docs (RedTachyon)
5b23f84  Change the SQIL structure to instead subclass the replay buffer, new … (RedTachyon)
bc8152b  Add an empty line (RedTachyon)
7d56e6a  Simplify the arguments (RedTachyon)
4e3f156  Cover another edge case, another test, fixes (RedTachyon)
d018cbd  Fix a circular import issue (RedTachyon)
29cdbfa  Add a performance test - might be slow? (RedTachyon)
551fa7e  Fix coverage (RedTachyon)
fcd94b9  Improve input validation (AdamGleave)
34ddf82  Bugfix: have set_demonstrations set rather than return (AdamGleave)
cf20fbb  Move TransitionMapping from algorithms.base to data.types (AdamGleave)
ee16818  Fix typo: expert_buffer->self.expert_buffer (AdamGleave)
87876aa  Bugfix: use safe_to_numpy rather than assuming th.Tensor (AdamGleave)
12e30b1  Fix lint (AdamGleave)
90a3a79  Fix unused imports (AdamGleave)
ef0fd26  Refactor tests (AdamGleave)
34241b2  Bump # of rollouts to try to fix MacOS flakiness (AdamGleave)
ed399d3  Merge branch 'master' into redtachyon/740-sqil (ernestum)
c8e9df8  Simplify SQIL example and tutorial by 1. downloading expert trajector… (ernestum)
e4e5d9f  Improve docstring of SQILReplayBuffer. (ernestum)
b89e5d8  Set the expert_buffer in the constructor. (ernestum)
c7723e5  Consistently set expert transition reward to 1 and learner transition… (ernestum)
e0bc16d  Fix docstring of SQILReplayBuffer.sample() (ernestum)
203c89f  Switch back to the CartPole-v1 environment in the SQIL examples (ernestum)
c149385  Only train for 1k steps in the SQIL example so the doctests don't run… (ernestum)
18a6622  Fix cell metadata for tutorial notebook. (ernestum)
9c5b91c  Notebook formatting fixes. (ernestum)
f8584c3  Fix typing error in SQIL implementation. (ernestum)
02f3191  Fix isort issue. (ernestum)
649de46  Clarify that our variant of the SQIL implementation is not really "so… (ernestum)
c72b088  Fix link in experts documentation. (ernestum)
8277a5c  Remove support for transition mappings. (ernestum)
a0af5c5  Remove data_loader from SQIL test cases. (ernestum)
4ccea30  Bump number of demonstrations in SQIL performance test to reduce flak… (ernestum)
68cbce8  Adapt hyperparameters in test_sqil_performance to reduce flakiness (jas-ho)
2bf467d  Fix seeds for flaky test_sqil_performance (jas-ho)
ccda686  Increase coverage in test_sqil.py (jas-ho)
91b226a  Pass kwargs to SQIL.train to DQN.learn (jas-ho)
5cbb6b2  Pass parameters as kwargs for multi-ary methods in sqil.py (jas-ho)
d2124a2  Make test for exceptions raised by SQIL constructor more specific (jas-ho)
@@ -0,0 +1,61 @@
.. _soft q imitation learning docs:

================================
Soft Q Imitation Learning (SQIL)
================================

Soft Q Imitation Learning learns to imitate a policy from demonstrations by
running the DQN algorithm with modified rewards. During each policy update,
half of the batch is sampled from the demonstrations and half from the
environment. Expert demonstrations are assigned a reward of 1, and environment
transitions are assigned a reward of 0. This encourages the policy to imitate
the demonstrations and, at the same time, to avoid states not seen in them.

.. note::

    This implementation is based on the DQN implementation in Stable Baselines 3,
    which does not implement soft Q-learning and therefore does not support
    continuous actions. Consequently, this implementation only supports discrete
    actions, and the name "soft" Q-learning can be misleading.

Example
=======

Detailed example notebook: :doc:`../tutorials/8_train_sqil`

.. testcode::
    :skipif: skip_doctests

    import datasets
    import gym
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.common.vec_env import DummyVecEnv

    from imitation.algorithms import sqil
    from imitation.data import huggingface_utils

    # Download some expert trajectories from the HuggingFace Datasets Hub.
    dataset = datasets.load_dataset("HumanCompatibleAI/ppo-CartPole-v1")
    rollouts = huggingface_utils.TrajectoryDatasetSequence(dataset["train"])

    sqil_trainer = sqil.SQIL(
        venv=DummyVecEnv([lambda: gym.make("CartPole-v1")]),
        demonstrations=rollouts,
        policy="MlpPolicy",
    )
    # Hint: set to 1_000_000 to match the expert performance.
    sqil_trainer.train(total_timesteps=1_000)
    reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, 10)
    print("Reward:", reward)

.. testoutput::
    :hide:

    ...

API
===
.. autoclass:: imitation.algorithms.sqil.SQIL
    :members:
    :noindex:
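For intuition, here is a minimal sketch of the sampling-and-relabeling mechanism that the description above and the commit history (`SQILReplayBuffer`) point at. It is not the PR's actual class, just an illustration assuming Stable Baselines 3's `ReplayBuffer` interface and an `expert_buffer` pre-filled with the demonstrations:

```python
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples


class SQILStyleBuffer(ReplayBuffer):
    """Hypothetical sketch, not the library's SQILReplayBuffer.

    Every sampled batch is half the agent's own experience (reward 0) and
    half expert transitions (reward 1), as in the SQIL paper.
    """

    expert_buffer: ReplayBuffer  # assumed pre-filled with expert demonstrations

    def sample(self, batch_size: int, env=None) -> ReplayBufferSamples:
        agent = super().sample(batch_size // 2, env)
        expert = self.expert_buffer.sample(batch_size - batch_size // 2, env)

        def cat(a: th.Tensor, b: th.Tensor) -> th.Tensor:
            return th.cat([a, b], dim=0)

        return ReplayBufferSamples(
            observations=cat(agent.observations, expert.observations),
            actions=cat(agent.actions, expert.actions),
            next_observations=cat(agent.next_observations, expert.next_observations),
            dones=cat(agent.dones, expert.dones),
            # The core of SQIL: relabel agent rewards to 0 and expert rewards to 1.
            rewards=cat(th.zeros_like(agent.rewards), th.ones_like(expert.rewards)),
        )
```

DQN trained on batches like these is pushed to assign higher value to expert-like transitions, which is exactly the imitation signal the documentation describes.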
@@ -0,0 +1,157 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/8_train_sqil.ipynb)\n",
        "# Train an Agent using Soft Q Imitation Learning\n",
        "\n",
        "Soft Q Imitation Learning ([SQIL](https://arxiv.org/abs/1905.11108)) is a simple algorithm that can be used to clone expert behavior.\n",
        "It is fundamentally a modification of the DQN algorithm. At each training step, whenever we sample a batch of data from the replay buffer,\n",
        "we also sample a batch of expert data. Expert demonstrations are assigned a reward of 1, while the agent's own transitions are assigned a reward of 0.\n",
        "This encourages the agent to imitate the expert's behavior and to avoid unfamiliar states.\n",
        "\n",
        "In this tutorial we will use the `imitation` library to train an agent using SQIL."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we need some expert trajectories in our environment (`CartPole-v1`).\n",
        "Note that you can use other environments, but the action space must be discrete for this algorithm."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import datasets\n",
        "from stable_baselines3.common.vec_env import DummyVecEnv\n",
        "\n",
        "from imitation.data import huggingface_utils\n",
        "\n",
        "# Download some expert trajectories from the HuggingFace Datasets Hub.\n",
        "dataset = datasets.load_dataset(\"HumanCompatibleAI/ppo-CartPole-v1\")\n",
        "\n",
        "# Convert the dataset to a format usable by the imitation library.\n",
        "expert_trajectories = huggingface_utils.TrajectoryDatasetSequence(dataset[\"train\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's quickly check if the expert is any good.\n",
        "We should usually be able to reach a reward of 500, which is the maximum achievable value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from imitation.data import rollout\n",
        "\n",
        "trajectory_stats = rollout.rollout_stats(expert_trajectories)\n",
        "\n",
        "print(\n",
        "    f\"We have {trajectory_stats['n_traj']} trajectories. \"\n",
        "    f\"The average length of each trajectory is {trajectory_stats['len_mean']}. \"\n",
        "    f\"The average return of each trajectory is {trajectory_stats['return_mean']}.\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now that we have collected our expert trajectories, it's time to set up our behavior cloning algorithm."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import gym\n",
        "\n",
        "from imitation.algorithms import sqil\n",
        "\n",
        "venv = DummyVecEnv([lambda: gym.make(\"CartPole-v1\")])\n",
        "sqil_trainer = sqil.SQIL(\n",
        "    venv=venv,\n",
        "    demonstrations=expert_trajectories,\n",
        "    policy=\"MlpPolicy\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As you can see, the untrained policy only achieves poor rewards:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from stable_baselines3.common.evaluation import evaluate_policy\n",
        "\n",
        "reward_before_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n",
        "print(f\"Reward before training: {reward_before_training}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After training, we can match the expert's reward (500):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "sqil_trainer.train(\n",
        "    total_timesteps=1_000,\n",
        ")  # Note: set to 1_000_000 to obtain good results\n",
        "reward_after_training, _ = evaluate_policy(sqil_trainer.policy, venv, 10)\n",
        "print(f\"Reward after training: {reward_after_training}\")"
      ]
    }
  ],
  "metadata": {
    "interpreter": {
      "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
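One detail from the commit history worth noting alongside the tutorial: `SQIL.train` forwards extra keyword arguments to the underlying `DQN.learn` call. A hedged usage sketch (`log_interval` is a standard Stable Baselines 3 `learn` argument; other kwargs would be forwarded the same way):

```python
# Hypothetical usage: kwargs beyond total_timesteps are passed through to DQN.learn.
sqil_trainer.train(total_timesteps=1_000, log_interval=4)
```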
Review comment: `100_000` was already sufficient to reach expert performance (tried only a couple of times, though).
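A quick way to check that observation against the example's hint (a sketch; results vary by seed, and the docs set the hint at `1_000_000`):

```python
# Reviewer reports 100_000 steps were already enough to reach expert
# performance, though this was only tried a couple of times.
sqil_trainer.train(total_timesteps=100_000)
reward, _ = evaluate_policy(sqil_trainer.policy, sqil_trainer.venv, 10)
print("Reward after 100k steps:", reward)
```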