diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb
index 11a6f3e94..f729aa002 100644
--- a/docs/tutorials/3_train_gail.ipynb
+++ b/docs/tutorials/3_train_gail.ipynb
@@ -89,7 +89,7 @@
    "outputs": [],
    "source": [
     "from imitation.algorithms.adversarial.gail import GAIL\n",
-    "from imitation.rewards.reward_nets import BasicShapedRewardNet\n",
+    "from imitation.rewards.reward_nets import BasicRewardNet\n",
     "from imitation.util.networks import RunningNorm\n",
     "from stable_baselines3 import PPO\n",
     "from stable_baselines3.ppo import MlpPolicy\n",
@@ -100,11 +100,12 @@
     "    policy=MlpPolicy,\n",
     "    batch_size=64,\n",
     "    ent_coef=0.0,\n",
-    "    learning_rate=0.00001,\n",
-    "    n_epochs=1,\n",
+    "    learning_rate=0.0004,\n",
+    "    gamma=0.95,\n",
+    "    n_epochs=5,\n",
     "    seed=SEED,\n",
     ")\n",
-    "reward_net = BasicShapedRewardNet(\n",
+    "reward_net = BasicRewardNet(\n",
     "    observation_space=env.observation_space,\n",
     "    action_space=env.action_space,\n",
     "    normalize_input_layer=RunningNorm,\n",
@@ -112,8 +113,8 @@
     "gail_trainer = GAIL(\n",
     "    demonstrations=rollouts,\n",
     "    demo_batch_size=1024,\n",
-    "    gen_replay_buffer_capacity=2048,\n",
-    "    n_disc_updates_per_round=4,\n",
+    "    gen_replay_buffer_capacity=512,\n",
+    "    n_disc_updates_per_round=8,\n",
     "    venv=env,\n",
     "    gen_algo=learner,\n",
     "    reward_net=reward_net,\n",
@@ -126,7 +127,7 @@
     ")\n",
     "\n",
     "# train the learner and evaluate again\n",
-    "gail_trainer.train(20000)\n",
+    "gail_trainer.train(800_000)\n",
     "env.seed(SEED)\n",
     "learner_rewards_after_training, _ = evaluate_policy(\n",
     "    learner, env, 100, return_episode_rewards=True\n",
@@ -137,7 +138,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+    "We can see that an untrained policy performs poorly, while GAIL matches expert returns (500):"
    ]
   },
   {
@@ -146,17 +147,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
-    "print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
-    "\n",
-    "plt.hist(\n",
-    "    [learner_rewards_before_training, learner_rewards_after_training],\n",
-    "    label=[\"untrained\", \"trained\"],\n",
-    ")\n",
-    "plt.legend()\n",
-    "plt.show()"
+    "print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
+    "print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))"
    ]
   }
  ],
diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb
index e7bb6bb99..b2f45a327 100644
--- a/docs/tutorials/4_train_airl.ipynb
+++ b/docs/tutorials/4_train_airl.ipynb
@@ -31,6 +31,13 @@
     "\n",
     "SEED = 42\n",
     "\n",
+    "FAST = True\n",
+    "\n",
+    "if FAST:\n",
+    "    N_RL_TRAIN_STEPS = 800_000\n",
+    "else:\n",
+    "    N_RL_TRAIN_STEPS = 2_000_000\n",
+    "\n",
     "env = make_vec_env(\n",
     "    \"seals/CartPole-v0\",\n",
     "    rng=np.random.default_rng(SEED),\n",
@@ -96,10 +103,13 @@
     "learner = PPO(\n",
     "    env=env,\n",
     "    policy=MlpPolicy,\n",
-    "    batch_size=16,\n",
+    "    batch_size=64,\n",
     "    ent_coef=0.0,\n",
-    "    learning_rate=0.0001,\n",
-    "    n_epochs=2,\n",
+    "    learning_rate=0.0005,\n",
+    "    gamma=0.95,\n",
+    "    clip_range=0.1,\n",
+    "    vf_coef=0.1,\n",
+    "    n_epochs=5,\n",
     "    seed=SEED,\n",
     ")\n",
     "reward_net = BasicShapedRewardNet(\n",
@@ -109,9 +119,9 @@
     ")\n",
     "airl_trainer = AIRL(\n",
     "    demonstrations=rollouts,\n",
-    "    demo_batch_size=1024,\n",
-    "    gen_replay_buffer_capacity=2048,\n",
-    "    n_disc_updates_per_round=4,\n",
+    "    demo_batch_size=2048,\n",
+    "    gen_replay_buffer_capacity=512,\n",
+    "    n_disc_updates_per_round=16,\n",
     "    venv=env,\n",
     "    gen_algo=learner,\n",
     "    reward_net=reward_net,\n",
@@ -121,7 +131,7 @@
     "learner_rewards_before_training, _ = evaluate_policy(\n",
     "    learner, env, 100, return_episode_rewards=True\n",
     ")\n",
-    "airl_trainer.train(20000)\n",
+    "airl_trainer.train(N_RL_TRAIN_STEPS)\n",
     "env.seed(SEED)\n",
     "learner_rewards_after_training, _ = evaluate_policy(\n",
     "    learner, env, 100, return_episode_rewards=True\n",
@@ -132,7 +142,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "When we look at the histograms of rewards before and after learning, we can see that the learner is not perfect yet, but it made some progress at least."
+    "We can see that an untrained policy performs poorly, while AIRL brings an improvement. To make it match the expert performance (500), set the flag `FAST` to `False` in the first cell."
    ]
   },
   {
@@ -141,17 +151,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
-    "print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
-    "\n",
-    "plt.hist(\n",
-    "    [learner_rewards_before_training, learner_rewards_after_training],\n",
-    "    label=[\"untrained\", \"trained\"],\n",
-    ")\n",
-    "plt.legend()\n",
-    "plt.show()"
+    "print(\"Mean reward before training:\", np.mean(learner_rewards_before_training))\n",
+    "print(\"Mean reward after training:\", np.mean(learner_rewards_after_training))"
    ]
   }
  ],