Add an option to run SQIL with various off-policy algorithms #778
Conversation
Thanks for the implementation! Overall looks strong, just a few relatively minor changes.
tests/algorithms/test_sqil.py (outdated)
        cache = pytestconfig.cache
        assert cache is not None
        return expert_trajectories.make_expert_transition_loader(
-           cache_dir=cache.mkdir("experts"),
+           cache_dir=cache.mkdir(env_name.replace("/", "_")),
Why do we need the environment name in the cache directory? It should already be included in the environment path in https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/testing/expert_trajectories.py#L74
Yeah, indeed. I was confused by the implementation of this function that uses the cache and wasn't sure whether I needed to make the directory unique. Now I see that the root cache dir can be shared.
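For reference, a minimal sketch of what the reverted fixture could look like. The fixture and argument names other than `cache_dir` (`env_name`, `rng`, `batch_size`, `expert_data_type`) are assumptions for illustration, not copied from the test suite; per the linked source, `make_expert_transition_loader` appends the environment-specific path itself.

```python
import pytest

from imitation.testing import expert_trajectories


@pytest.fixture
def expert_transitions(pytestconfig, env_name, rng):
    # Shared root cache dir; the environment-specific subpath is added
    # inside make_expert_transition_loader (see the linked source above).
    cache = pytestconfig.cache
    assert cache is not None
    return expert_trajectories.make_expert_transition_loader(
        cache_dir=cache.mkdir("experts"),
        batch_size=32,                   # assumed value, for illustration only
        expert_data_type="transitions",  # assumed
        env_name=env_name,
        rng=rng,
    )
```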
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After we collected our expert trajectories, it's time to set up our behavior cloning algorithm." |
I know this was just copied from the original tutorial, but I find the reference to behavior cloning potentially ambiguous: it usually refers to supervised learning on expert trajectories (and we have a BC class that does exactly that). SQIL does something conceptually similar but quite different in the details (RL rather than supervised learning).
I would suggest rephrasing this (and the original tutorial); we could just call it an imitation algorithm rather than a supervised learning algorithm.
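To make the distinction concrete, a hedged sketch of the two setups; the constructor arguments follow the library's public API as I understand it, and `venv` / `expert_transitions` are assumed to come from the surrounding tutorial setup:

```python
import numpy as np
from imitation.algorithms import bc, sqil

rng = np.random.default_rng(0)

# BC: supervised learning -- fit a policy directly to expert state-action pairs.
bc_trainer = bc.BC(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    demonstrations=expert_transitions,
    rng=rng,
)

# SQIL: off-policy RL -- expert transitions are relabeled with reward 1 and
# agent transitions with reward 0, then an off-policy learner is trained on
# the mixture of both.
sqil_trainer = sqil.SQIL(
    venv=venv,
    demonstrations=expert_transitions,
    policy="MlpPolicy",
)
```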
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After training, we can observe that agent is quite improved (> 1000), although it does not reach the expert performance in this case." |
If you have time to do more tuning, great, but not a priority; this is enough to illustrate the algorithm.
Co-authored-by: Adam Gleave <[email protected]>
LGTM
…mpatibleAI#778)
* Pin huggingface_sb3 version.
* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.
* Make random_mdp test deterministic by seeding the environment.
* Add an option to run SQIL with various off-policy algorithms
* Add 8a_train_sqil_sac to toctree
* Fix performance tests for SQIL
* fix
* Update docs/tutorials/8a_train_sqil_sac.ipynb (Co-authored-by: Adam Gleave <[email protected]>)
* minor fixes
* Bring back performance tests for SQIL
---------
Co-authored-by: Maximilian Ernestus <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>
Description
This PR adds the option to combine SQIL with off-policy algorithms other than DQN, such as SAC, TD3, and DDPG, as requested in #767.
A tutorial with SQIL+SAC training on the HalfCheetah environment is also provided. A random policy scores below 0 and the expert demonstrations are at ~3400; SQIL+SAC reaches 1400.7 +/- 254.1 after 300K steps (mean +/- std across 5 runs).
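For illustration, a hedged sketch of the usage this enables. The `rl_algo_class` / `rl_kwargs` argument names are assumptions based on the PR description rather than a verbatim copy of the final API, and `venv` / `expert_transitions` stand in for the tutorial's HalfCheetah setup:

```python
import stable_baselines3 as sb3
from imitation.algorithms import sqil

trainer = sqil.SQIL(
    venv=venv,                        # vectorized HalfCheetah env (assumed setup)
    demonstrations=expert_transitions,
    policy="MlpPolicy",
    rl_algo_class=sb3.SAC,            # SAC instead of the default DQN
    rl_kwargs=dict(seed=0),
)
trainer.train(total_timesteps=300_000)  # the tutorial's 300K-step budget
```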
Testing
pytest tests/algorithms/test_sqil.py
The relevant tests were adapted to work with the new base algorithms. One can also run the provided tutorial.
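A hedged sketch of how the adapted tests might parametrize over the new base algorithms; the fixture names and the actual structure of tests/algorithms/test_sqil.py are assumptions:

```python
import pytest
import stable_baselines3 as sb3

from imitation.algorithms import sqil


@pytest.mark.parametrize("rl_algo_class", [sb3.SAC, sb3.TD3, sb3.DDPG])
def test_sqil_runs_with_off_policy_algo(rl_algo_class, continuous_venv, expert_transitions):
    # Smoke test: SQIL should construct and train end-to-end with each
    # continuous-control learner (DQN would need a discrete-action env instead).
    trainer = sqil.SQIL(
        venv=continuous_venv,
        demonstrations=expert_transitions,
        policy="MlpPolicy",
        rl_algo_class=rl_algo_class,
    )
    trainer.train(total_timesteps=100)
```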