Skip to content

Commit

Permalink
Modify .rst docs for GAIL and AIRL to match tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
michalzajac-ml committed Sep 5, 2023
1 parent 22b391a commit 4a0678e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
18 changes: 11 additions & 7 deletions docs/algorithms/airl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
learner = PPO(
env=env,
policy=MlpPolicy,
batch_size=16,
learning_rate=0.0001,
n_epochs=2,
batch_size=64,
ent_coef=0.0,
learning_rate=0.0005,
gamma=0.95,
clip_range=0.1,
vf_coef=0.1,
n_epochs=5,
seed=SEED,
)
reward_net = BasicShapedRewardNet(
Expand All @@ -72,9 +76,9 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
)
airl_trainer = AIRL(
demonstrations=rollouts,
demo_batch_size=1024,
gen_replay_buffer_capacity=2048,
n_disc_updates_per_round=4,
demo_batch_size=2048,
gen_replay_buffer_capacity=512,
n_disc_updates_per_round=16,
venv=env,
gen_algo=learner,
reward_net=reward_net,
Expand All @@ -84,7 +88,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
learner_rewards_before_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
)
airl_trainer.train(20000)
airl_trainer.train(20000) # Train for 2_000_000 steps to match expert.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
Expand Down
15 changes: 8 additions & 7 deletions docs/algorithms/gail.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

Expand Down Expand Up @@ -60,20 +60,21 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
policy=MlpPolicy,
batch_size=64,
ent_coef=0.0,
learning_rate=0.00001,
n_epochs=1,
learning_rate=0.0004,
gamma=0.95,
n_epochs=5,
seed=SEED,
)
reward_net = BasicShapedRewardNet(
reward_net = BasicRewardNet(
observation_space=env.observation_space,
action_space=env.action_space,
normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
demonstrations=rollouts,
demo_batch_size=1024,
gen_replay_buffer_capacity=2048,
n_disc_updates_per_round=4,
gen_replay_buffer_capacity=512,
n_disc_updates_per_round=8,
venv=env,
gen_algo=learner,
reward_net=reward_net,
Expand All @@ -86,7 +87,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
)

# train the learner and evaluate again
gail_trainer.train(20000)
gail_trainer.train(20000) # Train for 800_000 steps to match expert.
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
Expand Down

0 comments on commit 4a0678e

Please sign in to comment.