Tune hyperparameters in tutorials for GAIL and AIRL #772
Conversation
Thanks for the contribution! I've been meaning to do this for quite some time!
The changes themselves LGTM.
I think the pipeline fails because we get the newest seals version (0.2), which is made for gymnasium. If we change our seals version specifier in setup.py to seals~=0.1.5, this should be fixed.
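For example, the pin could look roughly like this in setup.py (a sketch; the surrounding entries are illustrative, not the actual file contents):

from setuptools import setup

setup(
    name="imitation",
    install_requires=[
        "seals~=0.1.5",  # stay below 0.2, which targets gymnasium rather than gym
        # ... other dependencies ...
    ],
)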
@@ -84,7 +88,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
     learner_rewards_before_training, _ = evaluate_policy(
         learner, env, 100, return_episode_rewards=True,
     )
-    airl_trainer.train(20000)
+    airl_trainer.train(20000)  # Train for 2_000_000 steps to match expert.
2 million timesteps is a lot for something as simple as CartPole; I expect we can do better, but this seems fine for the purposes of this PR, and at least the environment runs quickly.
I'd keep it for now (it's already an improvement) and possibly revisit in another PR.
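For reference, the tutorial flow around the diff above looks roughly like this (a sketch assuming the learner, env, and airl_trainer objects defined earlier in the notebook):

from stable_baselines3.common.evaluation import evaluate_policy

learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)
airl_trainer.train(20000)  # raise to 2_000_000 to match expert performance
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)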
"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n", | ||
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n", | ||
"\n", | ||
"plt.hist(\n", |
Why are you removing the histogram (here and in AIRL)? Fine to remove if it's not informative, but perhaps we should report the SD as well as the means?
Yeah, the reason was that I thought it was not super informative (especially once we reach expert performance). Good suggestion with the SD though, will add!
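For example, something along these lines could replace the histogram cell (a sketch using the reward arrays already defined in the notebook):

import numpy as np

print(
    "mean reward after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)
print(
    "mean reward before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)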
Shameless plug: this would be a nice application for my newly released data-samples-printer:
import data_samples_printer as dsp
dsp.pprint(
before_training=learner_rewards_before_training,
after_training=learner_rewards_after_training
)
prints something like:
▁ ▁ ▁▄ ▄▄▄█▇▄▄▇▄▇█▄█▃▃▇▄▇ ▇▁▃▄▁▃ ▄▃▁ ▁▁ ▁ -0.00 ±1.08 before_training
▂▃▇█▄▄▂▁ -0.01 ±0.20 after_training
@ernestum, thanks for this, the lib looks quite cool! I'll keep it in mind for the future. For this PR I decided not to introduce an additional dependency, though.
4a0678e to 4fc83be
LGTM
…I#772)
* Pin huggingface_sb3 version.
* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.
* Make random_mdp test deterministic by seeding the environment.
* Tune hyperparameters in tutorials for GAIL and AIRL
* Modify .rst docs for GAIL and AIRL to match tutorials
* GAIL and AIRL tutorials: report also std in results
---------
Co-authored-by: Maximilian Ernestus <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>
Description
This PR tunes hyperparameters for the GAIL and AIRL tutorials.
For GAIL, expert performance is reached (~500 on CartPole) with 800K PPO steps (~2 min run time on a MacBook Air M1). For AIRL, the default is the "fast" version, which improves over random but does not reach expert performance (800K steps, ~2 min run time); if we switch off "fast", expert performance is reached (2M steps, ~5 min run time).
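For illustration, a minimal sketch of how such a fast/full toggle could look in the AIRL tutorial (the FAST name and structure are assumptions; the step counts come from the description above):

FAST = True  # hypothetical toggle; set to False to reach expert performance

if FAST:
    total_timesteps = 800_000    # ~2 min on a MacBook Air M1, improves over random
else:
    total_timesteps = 2_000_000  # ~5 min, matches expert on CartPole

airl_trainer.train(total_timesteps)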
The hyperparameters were inspired by the half-cheetah configs from the benchmarking directory, plus a bit of manual tuning. Also, for GAIL I needed to change from BasicShapedRewardNet to BasicRewardNet to make it work (not exactly sure why, but it affected performance a lot!).
Testing
Just ran the notebooks, and also tested with a few different seeds to make sure the results are stable.