diff --git a/docs/algorithms/density.rst b/docs/algorithms/density.rst
index 13458534f..accdc6d9a 100644
--- a/docs/algorithms/density.rst
+++ b/docs/algorithms/density.rst
@@ -43,12 +43,24 @@ Detailed example notebook: :doc:`../tutorials/7_train_density`
     env = util.make_vec_env("Pendulum-v1", rng=rng, n_envs=2)
     rollouts = serialize.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.npz")
 
-    imitation_trainer = PPO(ActorCriticPolicy, env)
+    imitation_trainer = PPO(
+        ActorCriticPolicy,
+        env,
+        learning_rate=3e-4,
+        gamma=0.95,
+        ent_coef=1e-4,
+        n_steps=2048
+    )
     density_trainer = db.DensityAlgorithm(
         venv=env,
+        rng=rng,
         demonstrations=rollouts,
         rl_algo=imitation_trainer,
-        rng=rng,
+        density_type=db.DensityType.STATE_ACTION_DENSITY,
+        is_stationary=True,
+        kernel="gaussian",
+        kernel_bandwidth=0.4,
+        standardise_inputs=True,
     )
     density_trainer.train()
 
@@ -63,7 +75,7 @@ Detailed example notebook: :doc:`../tutorials/7_train_density`
     print("Stats before training:")
    print_stats(density_trainer, 1)
 
-    density_trainer.train_policy(100)
+    density_trainer.train_policy(100)  # Train for 1_000_000 steps to approach the expert.
 
     print("Stats after training:")
     print_stats(density_trainer, 1)
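
For reviewers: the two hunks above assemble into the following standalone sketch of the patched example. The imports, the rng seed, and the print_stats helper are assumptions inferred from the parts of density.rst that this diff does not show, not part of the patch itself::

    # Standalone sketch of the patched example. Imports, the rng seed, and
    # the print_stats helper are assumed from the surrounding density.rst,
    # which these hunks do not include.
    import pprint

    import numpy as np
    from stable_baselines3 import PPO
    from stable_baselines3.common.policies import ActorCriticPolicy

    from imitation.algorithms import density as db
    from imitation.data import serialize
    from imitation.util import util

    rng = np.random.default_rng(0)  # assumed seed
    env = util.make_vec_env("Pendulum-v1", rng=rng, n_envs=2)
    rollouts = serialize.load(
        "../tests/testdata/expert_models/pendulum_0/rollouts/final.npz"
    )

    # PPO learner; DensityAlgorithm replaces its environment reward with the
    # density model's log-likelihood reward.
    imitation_trainer = PPO(
        ActorCriticPolicy,
        env,
        learning_rate=3e-4,
        gamma=0.95,
        ent_coef=1e-4,
        n_steps=2048,
    )
    density_trainer = db.DensityAlgorithm(
        venv=env,
        rng=rng,
        demonstrations=rollouts,
        rl_algo=imitation_trainer,
        density_type=db.DensityType.STATE_ACTION_DENSITY,  # model p(s, a), not just p(s)
        is_stationary=True,  # one shared density model for every time step
        kernel="gaussian",
        kernel_bandwidth=0.4,  # main tuning knob for the kernel density estimate
        standardise_inputs=True,  # normalise observations/actions before fitting
    )
    density_trainer.train()  # fit the density model to the demonstrations


    def print_stats(density_trainer, n_trajectories):
        # Assumed helper: reports rollout statistics via test_policy().
        stats = density_trainer.test_policy(n_trajectories=n_trajectories)
        print("True reward function stats:")
        pprint.pprint(stats)


    print("Stats before training:")
    print_stats(density_trainer, 1)

    density_trainer.train_policy(100)  # Train for 1_000_000 steps to approach the expert.
    print("Stats after training:")
    print_stats(density_trainer, 1)

The kernel and kernel_bandwidth arguments are forwarded to the underlying kernel density estimator, so bandwidth is the first value to adjust if the imitation reward turns out too flat or too spiky for the learner.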