Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune hyperparameters in tutorials for GAIL and AIRL #772

Merged
merged 7 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions docs/algorithms/airl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
learner = PPO(
env=env,
policy=MlpPolicy,
batch_size=16,
learning_rate=0.0001,
n_epochs=2,
batch_size=64,
ent_coef=0.0,
learning_rate=0.0005,
gamma=0.95,
clip_range=0.1,
vf_coef=0.1,
n_epochs=5,
seed=SEED,
)
reward_net = BasicShapedRewardNet(
Expand All @@ -72,9 +76,9 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
)
airl_trainer = AIRL(
demonstrations=rollouts,
demo_batch_size=1024,
gen_replay_buffer_capacity=2048,
n_disc_updates_per_round=4,
demo_batch_size=2048,
gen_replay_buffer_capacity=512,
n_disc_updates_per_round=16,
venv=env,
gen_algo=learner,
reward_net=reward_net,
Expand All @@ -84,7 +88,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
learner_rewards_before_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
)
airl_trainer.train(20000)
airl_trainer.train(20000) # Train for 2_000_000 steps to match expert.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 million timesteps is a lot of timesteps for something as simple as CartPole, I expect we can do better but this seems fine for the purpose of this PR, at least the environment runs quickly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd keep it for now (it's already an improvement) and possibly revisit in another PR.

env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
Expand Down
15 changes: 8 additions & 7 deletions docs/algorithms/gail.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

Expand Down Expand Up @@ -60,20 +60,21 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
policy=MlpPolicy,
batch_size=64,
ent_coef=0.0,
learning_rate=0.00001,
n_epochs=1,
learning_rate=0.0004,
gamma=0.95,
n_epochs=5,
seed=SEED,
)
reward_net = BasicShapedRewardNet(
reward_net = BasicRewardNet(
observation_space=env.observation_space,
action_space=env.action_space,
normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
demonstrations=rollouts,
demo_batch_size=1024,
gen_replay_buffer_capacity=2048,
n_disc_updates_per_round=4,
gen_replay_buffer_capacity=512,
n_disc_updates_per_round=8,
venv=env,
gen_algo=learner,
reward_net=reward_net,
Expand All @@ -86,7 +87,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
)

# train the learner and evaluate again
gail_trainer.train(20000)
gail_trainer.train(20000) # Train for 800_000 steps to match expert.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
Expand Down
38 changes: 20 additions & 18 deletions docs/tutorials/3_train_gail.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
"outputs": [],
"source": [
"from imitation.algorithms.adversarial.gail import GAIL\n",
"from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
"from imitation.rewards.reward_nets import BasicRewardNet\n",
"from imitation.util.networks import RunningNorm\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.ppo import MlpPolicy\n",
Expand All @@ -100,20 +100,21 @@
" policy=MlpPolicy,\n",
" batch_size=64,\n",
" ent_coef=0.0,\n",
" learning_rate=0.00001,\n",
" n_epochs=1,\n",
" learning_rate=0.0004,\n",
" gamma=0.95,\n",
" n_epochs=5,\n",
" seed=SEED,\n",
")\n",
"reward_net = BasicShapedRewardNet(\n",
"reward_net = BasicRewardNet(\n",
" observation_space=env.observation_space,\n",
" action_space=env.action_space,\n",
" normalize_input_layer=RunningNorm,\n",
")\n",
"gail_trainer = GAIL(\n",
" demonstrations=rollouts,\n",
" demo_batch_size=1024,\n",
" gen_replay_buffer_capacity=2048,\n",
" n_disc_updates_per_round=4,\n",
" gen_replay_buffer_capacity=512,\n",
" n_disc_updates_per_round=8,\n",
" venv=env,\n",
" gen_algo=learner,\n",
" reward_net=reward_net,\n",
Expand All @@ -126,7 +127,7 @@
")\n",
"\n",
"# train the learner and evaluate again\n",
"gail_trainer.train(20000)\n",
"gail_trainer.train(800_000)\n",
"env.seed(SEED)\n",
"learner_rewards_after_training, _ = evaluate_policy(\n",
" learner, env, 100, return_episode_rewards=True\n",
Expand All @@ -137,7 +138,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
"We can see that an untrained policy performs poorly, while GAIL matches expert returns (500):"
]
},
{
Expand All @@ -146,17 +147,18 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
"\n",
"plt.hist(\n",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you removing histogram (here and in AIRL)? Fine to remove if it's not informative. But perhaps we should report the SD as well as the means?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the reason was I thought it was not super informative (especially in case we reach expert perf). Good suggestion with SD though, will add!

Copy link
Collaborator

@ernestum ernestum Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shameless plug: this would be a nice application for my newly release data-samples-printer:

import data_samples_printer as dsp
dsp.pprint(
    before_training=learner_rewards_before_training,
    after_training=learner_rewards_after_training
)

prints something like:

▁  ▁      ▁▄  ▄▄▄█▇▄▄▇▄▇█▄█▃▃▇▄▇ ▇▁▃▄▁▃ ▄▃▁ ▁▁   ▁ -0.00 ±1.08 before_training
                      ▂▃▇█▄▄▂▁                     -0.01 ±0.20 after_training

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ernestum , thanks for this, the lib looks quite cool! I'll remember about it in the future. For this PR I decided to not introduce additional dependency though.

" [learner_rewards_before_training, learner_rewards_after_training],\n",
" label=[\"untrained\", \"trained\"],\n",
"print(\n",
" \"Rewards before training:\",\n",
" np.mean(learner_rewards_before_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_before_training),\n",
")\n",
"plt.legend()\n",
"plt.show()"
"print(\n",
" \"Rewards after training:\",\n",
" np.mean(learner_rewards_after_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_after_training),\n",
")"
]
}
],
Expand Down
47 changes: 29 additions & 18 deletions docs/tutorials/4_train_airl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
"\n",
"SEED = 42\n",
"\n",
"FAST = True\n",
"\n",
"if FAST:\n",
" N_RL_TRAIN_STEPS = 800_000\n",
"else:\n",
" N_RL_TRAIN_STEPS = 2_000_000\n",
"\n",
"env = make_vec_env(\n",
" \"seals/CartPole-v0\",\n",
" rng=np.random.default_rng(SEED),\n",
Expand Down Expand Up @@ -96,10 +103,13 @@
"learner = PPO(\n",
" env=env,\n",
" policy=MlpPolicy,\n",
" batch_size=16,\n",
" batch_size=64,\n",
" ent_coef=0.0,\n",
" learning_rate=0.0001,\n",
" n_epochs=2,\n",
" learning_rate=0.0005,\n",
" gamma=0.95,\n",
" clip_range=0.1,\n",
" vf_coef=0.1,\n",
" n_epochs=5,\n",
" seed=SEED,\n",
")\n",
"reward_net = BasicShapedRewardNet(\n",
Expand All @@ -109,9 +119,9 @@
")\n",
"airl_trainer = AIRL(\n",
" demonstrations=rollouts,\n",
" demo_batch_size=1024,\n",
" gen_replay_buffer_capacity=2048,\n",
" n_disc_updates_per_round=4,\n",
" demo_batch_size=2048,\n",
" gen_replay_buffer_capacity=512,\n",
" n_disc_updates_per_round=16,\n",
" venv=env,\n",
" gen_algo=learner,\n",
" reward_net=reward_net,\n",
Expand All @@ -121,7 +131,7 @@
"learner_rewards_before_training, _ = evaluate_policy(\n",
" learner, env, 100, return_episode_rewards=True\n",
")\n",
"airl_trainer.train(20000)\n",
"airl_trainer.train(N_RL_TRAIN_STEPS)\n",
"env.seed(SEED)\n",
"learner_rewards_after_training, _ = evaluate_policy(\n",
" learner, env, 100, return_episode_rewards=True\n",
Expand All @@ -132,7 +142,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
"We can see that an untrained policy performs poorly, while AIRL brings an improvement. To make it match the expert performance (500), set the flag `FAST` to `False` in the first cell."
]
},
{
Expand All @@ -141,17 +151,18 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
"\n",
"plt.hist(\n",
" [learner_rewards_before_training, learner_rewards_after_training],\n",
" label=[\"untrained\", \"trained\"],\n",
"print(\n",
" \"Rewards before training:\",\n",
" np.mean(learner_rewards_before_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_before_training),\n",
")\n",
"plt.legend()\n",
"plt.show()"
"print(\n",
" \"Rewards after training:\",\n",
" np.mean(learner_rewards_after_training),\n",
" \"+/-\",\n",
" np.std(learner_rewards_after_training),\n",
")"
]
}
],
Expand Down
Loading