Tune hyperparameters in tutorials for GAIL and AIRL
michalzajac-ml committed Sep 6, 2023
1 parent 4872ceb commit 91b508c
Showing 2 changed files with 31 additions and 38 deletions.
30 changes: 11 additions & 19 deletions docs/tutorials/3_train_gail.ipynb
@@ -89,7 +89,7 @@
 "outputs": [],
 "source": [
 "from imitation.algorithms.adversarial.gail import GAIL\n",
-"from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
+"from imitation.rewards.reward_nets import BasicRewardNet\n",
 "from imitation.util.networks import RunningNorm\n",
 "from stable_baselines3 import PPO\n",
 "from stable_baselines3.ppo import MlpPolicy\n",
@@ -100,20 +100,21 @@
 " policy=MlpPolicy,\n",
 " batch_size=64,\n",
 " ent_coef=0.0,\n",
-" learning_rate=0.00001,\n",
-" n_epochs=1,\n",
+" learning_rate=0.0004,\n",
+" gamma=0.95,\n",
+" n_epochs=5,\n",
 " seed=SEED,\n",
 ")\n",
-"reward_net = BasicShapedRewardNet(\n",
+"reward_net = BasicRewardNet(\n",
 " observation_space=env.observation_space,\n",
 " action_space=env.action_space,\n",
 " normalize_input_layer=RunningNorm,\n",
 ")\n",
 "gail_trainer = GAIL(\n",
 " demonstrations=rollouts,\n",
 " demo_batch_size=1024,\n",
-" gen_replay_buffer_capacity=2048,\n",
-" n_disc_updates_per_round=4,\n",
+" gen_replay_buffer_capacity=512,\n",
+" n_disc_updates_per_round=8,\n",
 " venv=env,\n",
 " gen_algo=learner,\n",
 " reward_net=reward_net,\n",
@@ -126,7 +127,7 @@
 ")\n",
 "\n",
 "# train the learner and evaluate again\n",
-"gail_trainer.train(20000)\n",
+"gail_trainer.train(800_000)\n",
 "env.seed(SEED)\n",
 "learner_rewards_after_training, _ = evaluate_policy(\n",
 " learner, env, 100, return_episode_rewards=True\n",
@@ -137,7 +138,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+"We can see that an untrained policy performs poorly, while GAIL matches expert returns (500):"
 ]
 },
 {
@@ -146,17 +147,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import matplotlib.pyplot as plt\n",
-"\n",
-"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
-"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
-"\n",
-"plt.hist(\n",
-" [learner_rewards_before_training, learner_rewards_after_training],\n",
-" label=[\"untrained\", \"trained\"],\n",
-")\n",
-"plt.legend()\n",
-"plt.show()"
+"print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
+"print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))"
 ]
 }
],
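
For reference, below is a rough sketch of what the tuned GAIL training cell looks like after this change, assembled as plain Python from the diff above. It assumes `env` (the seals/CartPole-v0 vectorized environment), `rollouts` (the expert demonstrations), and `SEED` are defined earlier in the notebook; that setup is not touched by this commit.

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# Generator policy: PPO with the retuned hyperparameters from this commit.
learner = PPO(
    env=env,  # assumed: built earlier in the notebook
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,  # was 0.00001
    gamma=0.95,
    n_epochs=5,            # was 1
    seed=SEED,
)

# Discriminator reward network: now a plain BasicRewardNet instead of BasicShapedRewardNet.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

gail_trainer = GAIL(
    demonstrations=rollouts,  # assumed: expert rollouts collected earlier
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,  # was 2048
    n_disc_updates_per_round=8,      # was 4
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

gail_trainer.train(800_000)  # was 20000

The net effect of the retuning is a larger PPO learning rate, more epochs and discriminator updates per round, and substantially more environment steps, which is what lets the trained policy reach the expert return of 500 mentioned in the updated markdown cell.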
39 changes: 20 additions & 19 deletions docs/tutorials/4_train_airl.ipynb
@@ -31,6 +31,13 @@
 "\n",
 "SEED = 42\n",
 "\n",
+"FAST = True\n",
+"\n",
+"if FAST:\n",
+" N_RL_TRAIN_STEPS = 800_000\n",
+"else:\n",
+" N_RL_TRAIN_STEPS = 2_000_000\n",
+"\n",
 "env = make_vec_env(\n",
 " \"seals/CartPole-v0\",\n",
 " rng=np.random.default_rng(SEED),\n",
@@ -96,10 +103,13 @@
 "learner = PPO(\n",
 " env=env,\n",
 " policy=MlpPolicy,\n",
-" batch_size=16,\n",
+" batch_size=64,\n",
 " ent_coef=0.0,\n",
-" learning_rate=0.0001,\n",
-" n_epochs=2,\n",
+" learning_rate=0.0005,\n",
+" gamma=0.95,\n",
+" clip_range=0.1,\n",
+" vf_coef=0.1,\n",
+" n_epochs=5,\n",
 " seed=SEED,\n",
 ")\n",
 "reward_net = BasicShapedRewardNet(\n",
@@ -109,9 +119,9 @@
 ")\n",
 "airl_trainer = AIRL(\n",
 " demonstrations=rollouts,\n",
-" demo_batch_size=1024,\n",
-" gen_replay_buffer_capacity=2048,\n",
-" n_disc_updates_per_round=4,\n",
+" demo_batch_size=2048,\n",
+" gen_replay_buffer_capacity=512,\n",
+" n_disc_updates_per_round=16,\n",
 " venv=env,\n",
 " gen_algo=learner,\n",
 " reward_net=reward_net,\n",
@@ -121,7 +131,7 @@
 "learner_rewards_before_training, _ = evaluate_policy(\n",
 " learner, env, 100, return_episode_rewards=True\n",
 ")\n",
-"airl_trainer.train(20000)\n",
+"airl_trainer.train(N_RL_TRAIN_STEPS)\n",
 "env.seed(SEED)\n",
 "learner_rewards_after_training, _ = evaluate_policy(\n",
 " learner, env, 100, return_episode_rewards=True\n",
@@ -132,7 +142,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+"We can see that an untrained policy performs poorly, while AIRL brings an improvement. To make it match the expert performance (500), set the flag `FAST` to `False` in the first cell."
 ]
 },
 {
@@ -141,17 +151,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import matplotlib.pyplot as plt\n",
-"\n",
-"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
-"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
-"\n",
-"plt.hist(\n",
-" [learner_rewards_before_training, learner_rewards_after_training],\n",
-" label=[\"untrained\", \"trained\"],\n",
-")\n",
-"plt.legend()\n",
-"plt.show()"
+"print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
+"print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))"
 ]
 }
],
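
Similarly, a rough sketch of the tuned AIRL training cell after this change. The hyperparameter values and the `FAST` switch come from the diff above; the imports and the `env`/`rollouts`/`SEED` setup are assumed to match the rest of the AIRL tutorial and are not part of this commit, so treat them as illustrative.

from imitation.algorithms.adversarial.airl import AIRL
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

# New in this commit: a FAST switch in the first cell controls total training steps.
FAST = True

if FAST:
    N_RL_TRAIN_STEPS = 800_000
else:
    N_RL_TRAIN_STEPS = 2_000_000

learner = PPO(
    env=env,  # assumed: seals/CartPole-v0 vec env built earlier in the notebook
    policy=MlpPolicy,
    batch_size=64,         # was 16
    ent_coef=0.0,
    learning_rate=0.0005,  # was 0.0001
    gamma=0.95,
    clip_range=0.1,
    vf_coef=0.1,
    n_epochs=5,            # was 2
    seed=SEED,
)

reward_net = BasicShapedRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

airl_trainer = AIRL(
    demonstrations=rollouts,  # assumed: expert rollouts collected earlier
    demo_batch_size=2048,            # was 1024
    gen_replay_buffer_capacity=512,  # was 2048
    n_disc_updates_per_round=16,     # was 4
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

airl_trainer.train(N_RL_TRAIN_STEPS)  # was a fixed 20000 steps

With `FAST = True` the run stays short enough for documentation builds; `FAST = False` trains long enough to match the expert return of 500, per the updated markdown cell.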
