Skip to content

Commit

Permalink
Modify .rst docs for density estimation to match tutorials
Browse files: browse the repository at this point in the history
  • Loading branch information
michalzajac-ml committed Sep 5, 2023
1 parent 35f47a8 commit d37c66c
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions docs/algorithms/density.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,24 @@ Detailed example notebook: :doc:`../tutorials/7_train_density`
env = util.make_vec_env("Pendulum-v1", rng=rng, n_envs=2)
rollouts = serialize.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.npz")

imitation_trainer = PPO(ActorCriticPolicy, env)
imitation_trainer = PPO(
ActorCriticPolicy,
env,
learning_rate=3e-4,
gamma=0.95,
ent_coef=1e-4,
n_steps=2048
)
density_trainer = db.DensityAlgorithm(
venv=env,
rng=rng,
demonstrations=rollouts,
rl_algo=imitation_trainer,
rng=rng,
density_type=db.DensityType.STATE_ACTION_DENSITY,
is_stationary=True,
kernel="gaussian",
kernel_bandwidth=0.4,
standardise_inputs=True,
)
density_trainer.train()

Expand All @@ -63,7 +75,7 @@ Detailed example notebook: :doc:`../tutorials/7_train_density`
print("Stats before training:")
print_stats(density_trainer, 1)

density_trainer.train_policy(100)
density_trainer.train_policy(100) # Increase to 1_000_000 steps to approach expert performance.

print("Stats after training:")
print_stats(density_trainer, 1)
Expand Down

0 comments on commit d37c66c

Please sign in to comment.