Download experts from hf inside tutorials and docs (#766)
* Download expert policies from HF hub instead of training in docs/tutorials (see the sketch after this list)

* Download expert policies from HF hub instead of training in docs/algorithms

* Download expert policies from HF hub by default in quickstart.py

* Fix broken import in 3_train_gail.ipynb

* Fix broken call to load_policy in experts.rst docs

* Consistently use VecEnv environments in tutorial notebooks

* Use VecEnv environments in quickstart.py

* Suppress unused import warning for "seals" package in notebooks

* Consistently use VecEnv environments in docs/algorithms

* Fix missing imports in some algorithm docs

* Adapt hyperparameters in GAIL and AIRL notebooks and seed everywhere

* Fix imports in 3_train_gail.ipynb

* Revert 9_compare_baselines.ipynb to (almost) the version on master

* Increase expert training steps in quickstart.py

* Adapt hyperparameters in GAIL and AIRL documentation and seed everywhere

* Reuse existing VecEnv in code examples in docs/algorithms

* Reuse existing VecEnv in tutorial notebooks for GAIL and AIRL
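The pattern these commits introduce shows up repeatedly in the diffs below: instead of training a throwaway PPO expert inside every example, the docs now download a pre-trained policy from the HuggingFace hub and reuse one seeded VecEnv for rollouts, training, and evaluation. A condensed sketch of that pattern, mirroring the docs/algorithms diffs rather than reproducing any single file exactly:

    import numpy as np
    import seals  # noqa: F401  # needed to load "seals/" environments

    from imitation.data import rollout
    from imitation.data.wrappers import RolloutInfoWrapper
    from imitation.policies.serialize import load_policy
    from imitation.util.util import make_vec_env

    SEED = 42

    # One seeded VecEnv, wrapped so that complete rollouts can be recovered.
    env = make_vec_env(
        "seals/CartPole-v0",
        rng=np.random.default_rng(SEED),
        n_envs=8,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
    )

    # Download the expert from the HuggingFace hub instead of training one here.
    expert = load_policy(
        "ppo-huggingface",
        organization="HumanCompatibleAI",
        env_name="seals-CartPole-v0",
        venv=env,
    )

    # Collect demonstrations with the downloaded expert, reusing the same VecEnv.
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_episodes=60),
        rng=np.random.default_rng(SEED),
    )

The RolloutInfoWrapper post-wrapper is what lets rollout.rollout reconstruct full episode trajectories from the vectorized environment.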
jas-ho authored Aug 9, 2023
1 parent 19c7f35 commit 60d8686
Showing 11 changed files with 293 additions and 229 deletions.
63 changes: 41 additions & 22 deletions docs/algorithms/airl.rst
@@ -23,56 +23,75 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
     :skipif: skip_doctests

     import numpy as np
-    import gym
     import seals  # noqa: F401  # needed to load "seals/" environments
     from stable_baselines3 import PPO
     from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv
     from stable_baselines3.ppo import MlpPolicy

     from imitation.algorithms.adversarial.airl import AIRL
     from imitation.data import rollout
     from imitation.data.wrappers import RolloutInfoWrapper
+    from imitation.policies.serialize import load_policy
     from imitation.rewards.reward_nets import BasicShapedRewardNet
     from imitation.util.networks import RunningNorm
     from imitation.util.util import make_vec_env

-    rng = np.random.default_rng(0)
-
-    env = gym.make("seals/CartPole-v0")
-    expert = PPO(policy=MlpPolicy, env=env)
-    expert.learn(1000)
+    SEED = 42

+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=np.random.default_rng(SEED),
+        n_envs=8,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
+    )
+    expert = load_policy(
+        "ppo-huggingface",
+        organization="HumanCompatibleAI",
+        env_name="seals-CartPole-v0",
+        venv=env,
+    )
     rollouts = rollout.rollout(
         expert,
-        make_vec_env(
-            "seals/CartPole-v0",
-            rng=rng,
-            n_envs=5,
-            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
-        ),
-        rollout.make_sample_until(min_timesteps=None, min_episodes=60),
-        rng=rng,
+        env,
+        rollout.make_sample_until(min_episodes=60),
+        rng=np.random.default_rng(SEED),
     )

-    venv = make_vec_env("seals/CartPole-v0", rng=rng, n_envs=8)
-    learner = PPO(env=venv, policy=MlpPolicy)
+    learner = PPO(
+        env=env,
+        policy=MlpPolicy,
+        batch_size=16,
+        learning_rate=0.0001,
+        n_epochs=2,
+        seed=SEED,
+    )
     reward_net = BasicShapedRewardNet(
-        venv.observation_space,
-        venv.action_space,
+        observation_space=env.observation_space,
+        action_space=env.action_space,
         normalize_input_layer=RunningNorm,
     )
     airl_trainer = AIRL(
         demonstrations=rollouts,
         demo_batch_size=1024,
         gen_replay_buffer_capacity=2048,
         n_disc_updates_per_round=4,
-        venv=venv,
+        venv=env,
         gen_algo=learner,
         reward_net=reward_net,
     )

+    env.seed(SEED)
+    learner_rewards_before_training, _ = evaluate_policy(
+        learner, env, 100, return_episode_rewards=True,
+    )
     airl_trainer.train(20000)
-    rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
-    print("Rewards:", rewards)
+    env.seed(SEED)
+    learner_rewards_after_training, _ = evaluate_policy(
+        learner, env, 100, return_episode_rewards=True,
+    )
+
+    print("mean reward after training:", np.mean(learner_rewards_after_training))
+    print("mean reward before training:", np.mean(learner_rewards_before_training))

 .. testoutput::
     :hide:
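The evaluation pattern added above seeds the environment before each call to evaluate_policy so the before/after numbers are computed on the same sequence of episodes. The same idea written as a small hypothetical helper (not part of the diff), assuming evaluate_policy and np are imported as in the example:

    def mean_reward(policy, env, seed, n_eval_episodes=100):
        # Re-seed so every evaluation sees the same episode sequence.
        env.seed(seed)
        episode_rewards, _ = evaluate_policy(
            policy, env, n_eval_episodes, return_episode_rewards=True
        )
        return np.mean(episode_rewards)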
25 changes: 16 additions & 9 deletions docs/algorithms/bc.rst
@@ -21,24 +21,31 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc`
     :skipif: skip_doctests

     import numpy as np
-    import gym
     from stable_baselines3 import PPO
+    import seals  # noqa: F401  # needed to load "seals/" environments
     from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv
     from stable_baselines3.ppo import MlpPolicy

     from imitation.algorithms import bc
     from imitation.data import rollout
     from imitation.data.wrappers import RolloutInfoWrapper
+    from imitation.policies.serialize import load_policy
+    from imitation.util.util import make_vec_env

     rng = np.random.default_rng(0)
-    env = gym.make("CartPole-v1")
-    expert = PPO(policy=MlpPolicy, env=env)
-    expert.learn(1000)

+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+        n_envs=1,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # for computing rollouts
+    )
+    expert = load_policy(
+        "ppo-huggingface",
+        organization="HumanCompatibleAI",
+        env_name="seals-CartPole-v0",
+        venv=env,
+    )
     rollouts = rollout.rollout(
         expert,
-        DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
+        env,
         rollout.make_sample_until(min_timesteps=None, min_episodes=50),
         rng=rng,
     )
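The hunk is cut off before the demonstrations reach the learner. The collapsed remainder of bc.rst presumably continues along these lines (a sketch based on the imitation BC API, not lines shown in this diff):

    transitions = rollout.flatten_trajectories(rollouts)
    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        demonstrations=transitions,
        rng=rng,
    )
    bc_trainer.train(n_epochs=1)
    reward, _ = evaluate_policy(bc_trainer.policy, env, 10)
    print("Reward:", reward)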
26 changes: 16 additions & 10 deletions docs/algorithms/dagger.rst
@@ -24,21 +24,27 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`
     :skipif: skip_doctests

     import tempfile

     import numpy as np
-    import gym
-    from stable_baselines3 import PPO
+    import seals  # noqa: F401  # needed to load "seals/" environments
     from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv
-    from stable_baselines3.ppo import MlpPolicy

     from imitation.algorithms import bc
     from imitation.algorithms.dagger import SimpleDAggerTrainer
+    from imitation.policies.serialize import load_policy
+    from imitation.util.util import make_vec_env

     rng = np.random.default_rng(0)
-    env = gym.make("CartPole-v1")
-    expert = PPO(policy=MlpPolicy, env=env)
-    expert.learn(1000)
-    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=rng,
+    )
+    expert = load_policy(
+        "ppo-huggingface",
+        organization="HumanCompatibleAI",
+        env_name="seals-CartPole-v0",
+        venv=env,
+    )

     bc_trainer = bc.BC(
         observation_space=env.observation_space,
@@ -48,13 +54,13 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger`
     with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
         print(tmpdir)
         dagger_trainer = SimpleDAggerTrainer(
-            venv=venv,
+            venv=env,
             scratch_dir=tmpdir,
             expert_policy=expert,
             bc_trainer=bc_trainer,
             rng=rng,
         )
-        dagger_trainer.train(2000)
+        dagger_trainer.train(8_000)

     reward, _ = evaluate_policy(dagger_trainer.policy, env, 10)
     print("Reward:", reward)
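The lines hidden between the two hunks presumably complete the BC trainer that DAgger wraps around. Based on the bc.BC constructor used elsewhere in these docs, the collapsed region looks roughly like this (a sketch, not lines shown in the diff):

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        rng=rng,
    )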
67 changes: 45 additions & 22 deletions docs/algorithms/gail.rst
@@ -20,57 +20,80 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail`
     :skipif: skip_doctests

     import numpy as np
-    import gym
     import seals  # noqa: F401  # needed to load "seals/" environments
     from stable_baselines3 import PPO
     from stable_baselines3.common.evaluation import evaluate_policy
-    from stable_baselines3.common.vec_env import DummyVecEnv
     from stable_baselines3.ppo import MlpPolicy

     from imitation.algorithms.adversarial.gail import GAIL
     from imitation.data import rollout
     from imitation.data.wrappers import RolloutInfoWrapper
-    from imitation.rewards.reward_nets import BasicRewardNet
+    from imitation.policies.serialize import load_policy
+    from imitation.rewards.reward_nets import BasicShapedRewardNet
     from imitation.util.networks import RunningNorm
     from imitation.util.util import make_vec_env

-    rng = np.random.default_rng(0)
+    SEED = 42

-    env = gym.make("seals/CartPole-v0")
-    expert = PPO(policy=MlpPolicy, env=env, n_steps=64)
-    expert.learn(1000)
+    env = make_vec_env(
+        "seals/CartPole-v0",
+        rng=np.random.default_rng(SEED),
+        n_envs=8,
+        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],  # to compute rollouts
+    )
+    expert = load_policy(
+        "ppo-huggingface",
+        organization="HumanCompatibleAI",
+        env_name="seals-CartPole-v0",
+        venv=env,
+    )

     rollouts = rollout.rollout(
         expert,
-        make_vec_env(
-            "seals/CartPole-v0",
-            n_envs=5,
-            post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
-            rng=rng,
-        ),
+        env,
         rollout.make_sample_until(min_timesteps=None, min_episodes=60),
-        rng=rng,
+        rng=np.random.default_rng(SEED),
     )

-    venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng)
-    learner = PPO(env=venv, policy=MlpPolicy)
-    reward_net = BasicRewardNet(
-        venv.observation_space,
-        venv.action_space,
+    learner = PPO(
+        env=env,
+        policy=MlpPolicy,
+        batch_size=64,
+        ent_coef=0.0,
+        learning_rate=0.00001,
+        n_epochs=1,
+        seed=SEED,
+    )
+    reward_net = BasicShapedRewardNet(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
         normalize_input_layer=RunningNorm,
     )
     gail_trainer = GAIL(
         demonstrations=rollouts,
         demo_batch_size=1024,
         gen_replay_buffer_capacity=2048,
         n_disc_updates_per_round=4,
-        venv=venv,
+        venv=env,
         gen_algo=learner,
         reward_net=reward_net,
     )

+    # evaluate the learner before training
+    env.seed(SEED)
+    learner_rewards_before_training, _ = evaluate_policy(
+        learner, env, 100, return_episode_rewards=True,
+    )
+
+    # train the learner and evaluate again
     gail_trainer.train(20000)
-    rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
-    print("Rewards:", rewards)
+    env.seed(SEED)
+    learner_rewards_after_training, _ = evaluate_policy(
+        learner, env, 100, return_episode_rewards=True,
+    )
+
+    print("mean reward after training:", np.mean(learner_rewards_after_training))
+    print("mean reward before training:", np.mean(learner_rewards_before_training))

 .. testoutput::
     :hide:
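Besides switching to a HuggingFace expert, this diff also moves gail.rst from BasicRewardNet to BasicShapedRewardNet. If one wanted the unshaped variant the file used before, the removed lines suggest it accepts the same arguments here (a sketch, assuming the pre-existing BasicRewardNet API):

    from imitation.rewards.reward_nets import BasicRewardNet

    # Unshaped reward network; arguments mirror the removed lines above.
    reward_net = BasicRewardNet(
        observation_space=env.observation_space,
        action_space=env.action_space,
        normalize_input_layer=RunningNorm,
    )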
3 changes: 2 additions & 1 deletion docs/main-concepts/experts.rst
@@ -70,7 +70,8 @@ When using the Python API, you also have to specify the environment name as `env
     remote_policy = load_policy(
         "ppo-huggingface",
         organization="your-org",
-        env_name="your-env"
+        env_name="your-env",
+        venv=venv,
     )
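The updated snippet passes a venv that this hunk does not show being built. One way to construct it with the project's own helper (a sketch; "your-env" and "your-org" are the docs' placeholders, and the surrounding code is not part of this diff):

    import numpy as np

    from imitation.policies.serialize import load_policy
    from imitation.util.util import make_vec_env

    # Any registered environment name works here; it must match the expert's env.
    venv = make_vec_env("your-env", rng=np.random.default_rng(0), n_envs=1)
    remote_policy = load_policy(
        "ppo-huggingface",
        organization="your-org",
        env_name="your-env",
        venv=venv,
    )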
