Tune hyperparameters in tutorials for GAIL and AIRL #772
Changes from all commits: b8d1616, 09c5f2f, 4872ceb, 91b508c, 4fc83be, ab6e0c3, 154ed62
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "from imitation.algorithms.adversarial.gail import GAIL\n",
-"from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
+"from imitation.rewards.reward_nets import BasicRewardNet\n",
 "from imitation.util.networks import RunningNorm\n",
 "from stable_baselines3 import PPO\n",
 "from stable_baselines3.ppo import MlpPolicy\n",
@@ -100,20 +100,21 @@
 " policy=MlpPolicy,\n",
 " batch_size=64,\n",
 " ent_coef=0.0,\n",
-" learning_rate=0.00001,\n",
-" n_epochs=1,\n",
+" learning_rate=0.0004,\n",
+" gamma=0.95,\n",
+" n_epochs=5,\n",
 " seed=SEED,\n",
 ")\n",
-"reward_net = BasicShapedRewardNet(\n",
+"reward_net = BasicRewardNet(\n",
 " observation_space=env.observation_space,\n",
 " action_space=env.action_space,\n",
 " normalize_input_layer=RunningNorm,\n",
 ")\n",
 "gail_trainer = GAIL(\n",
 " demonstrations=rollouts,\n",
 " demo_batch_size=1024,\n",
-" gen_replay_buffer_capacity=2048,\n",
-" n_disc_updates_per_round=4,\n",
+" gen_replay_buffer_capacity=512,\n",
+" n_disc_updates_per_round=8,\n",
 " venv=env,\n",
 " gen_algo=learner,\n",
 " reward_net=reward_net,\n",
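For readers who do not want to parse the notebook JSON, here is a rough sketch of how the GAIL setup cell reads once this hunk is applied. It is reassembled from the diff above; the enclosing `learner = PPO(...)` call and the `env`, `rollouts`, and `SEED` variables come from earlier cells of the tutorial and are assumed here.

```python
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# Generator policy with the retuned PPO hyperparameters from this PR.
learner = PPO(
    env=env,  # vectorized CartPole env from an earlier tutorial cell (assumed)
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)

# Discriminator reward network; the PR switches from BasicShapedRewardNet to BasicRewardNet.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

gail_trainer = GAIL(
    demonstrations=rollouts,  # expert rollouts collected earlier in the tutorial (assumed)
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)
```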
@@ -126,7 +127,7 @@
 ")\n",
 "\n",
 "# train the learner and evaluate again\n",
-"gail_trainer.train(20000)\n",
+"gail_trainer.train(800_000)\n",
 "env.seed(SEED)\n",
 "learner_rewards_after_training, _ = evaluate_policy(\n",
 " learner, env, 100, return_episode_rewards=True\n",
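The larger training budget slots into the tutorial's evaluate/train/evaluate pattern. A plain-Python sketch of that step, assuming `evaluate_policy` is imported from stable-baselines3 as elsewhere in the tutorial and that the before-training evaluation mirrors the after-training one shown in the hunk:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Returns of the untrained generator policy (mirrors the after-training evaluation).
env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

# Train GAIL for the longer budget introduced by this PR, then evaluate again.
gail_trainer.train(800_000)  # was 20000 timesteps before this change
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)
```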
@@ -137,7 +138,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+"We can see that an untrained policy performs poorly, while GAIL matches expert returns (500):"
 ]
 },
 {
@@ -146,17 +147,18 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import matplotlib.pyplot as plt\n",
-"\n",
-"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
-"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
-"\n",
-"plt.hist(\n",
-" [learner_rewards_before_training, learner_rewards_after_training],\n",
-" label=[\"untrained\", \"trained\"],\n",
+"print(\n",
+" \"Rewards before training:\",\n",
+" np.mean(learner_rewards_before_training),\n",
+" \"+/-\",\n",
+" np.std(learner_rewards_before_training),\n",
 ")\n",
-"plt.legend()\n",
-"plt.show()"
+"print(\n",
+" \"Rewards after training:\",\n",
+" np.mean(learner_rewards_after_training),\n",
+" \"+/-\",\n",
+" np.std(learner_rewards_after_training),\n",
+")"
 ]
 }
 ],

Review thread on this cell:

Reviewer: Why are you removing the histogram (here and in AIRL)? Fine to remove if it's not informative, but perhaps we should report the SD as well as the means?

Author: Yeah, the reason was that I thought it was not super informative (especially in case we reach expert performance). Good suggestion with the SD though, will add!

ernestum: Shameless plug: this would be a nice application for my newly released data-samples-printer:

    import data_samples_printer as dsp

    dsp.pprint(
        before_training=learner_rewards_before_training,
        after_training=learner_rewards_after_training,
    )

which prints something like: [sample output omitted]

Author: @ernestum, thanks for this, the lib looks quite cool! I'll keep it in mind for the future. For this PR I decided not to introduce an additional dependency, though.
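Written out as plain Python rather than notebook JSON, the replacement reporting cell simply prints the mean and standard deviation of the episode returns (the SD the review above asked for); `np` is numpy, already imported earlier in the tutorial.

```python
import numpy as np  # already in scope in the tutorial; repeated so the snippet is self-contained

print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)
```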
Review thread on the training budget:

Reviewer: 2 million timesteps is a lot of timesteps for something as simple as CartPole. I expect we can do better, but this seems fine for the purpose of this PR; at least the environment runs quickly.

Author: I'd keep it for now (it's already an improvement) and possibly revisit in another PR.