Mceirl train #805

Draft
wants to merge 64 commits into base: master

64 commits
b4210c1
Merge py file changes from benchmark-algs
taufeeque9 Jan 4, 2023
97bc063
Clean parallel script
taufeeque9 Jan 10, 2023
9291225
Undo the changes from #653 to the dagger benchmark config files.
ernestum Jan 26, 2023
276d863
Improve readability and interpretability of benchmarking tests.
ernestum Jan 25, 2023
37eb914
Add pxponential beta scheduler for dagger
taufeeque9 Mar 1, 2023
877383b
Ignore coverage for unknown algorithms.
ernestum Feb 2, 2023
c8e55cb
Cleanup and extend tests for beta schedules in dagger.
ernestum Feb 2, 2023
6b9b306
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 6, 2023
8576465
Fix test cases
taufeeque9 Feb 8, 2023
d81eb68
Add optuna to dependencies
taufeeque9 Feb 8, 2023
27467d3
Fix test case
taufeeque9 Feb 8, 2023
b59a768
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 8, 2023
1a3b6b8
Clean up the scripts
taufeeque9 Feb 9, 2023
7a438da
Remove reporter(done) since mean_return is reported by the runs
taufeeque9 Feb 9, 2023
5bc5835
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 20, 2023
2e56de8
Add beta_schedule parameter to dagger script
taufeeque9 Feb 23, 2023
84e854a
Merge branch 'master' into benchmark-pr
taufeeque9 Mar 16, 2023
73d8576
Update config policy kwargs
taufeeque9 Mar 16, 2023
9fdf878
Changes from review
taufeeque9 May 16, 2023
1c1dbc4
Fix errors with some configs
taufeeque9 May 16, 2023
3467af2
Merge branch 'master' into benchmark-pr
taufeeque9 May 16, 2023
44c4e97
Updates based on review
taufeeque9 Jun 14, 2023
4d493ae
Merge branch 'master' into benchmark-pr
taufeeque9 Jun 14, 2023
ab01269
Change metric everywhere
taufeeque9 Jun 14, 2023
f64580e
Merge branch 'master' into benchmark-pr
taufeeque9 Jul 11, 2023
e896d7d
Separate tuning code from parallel.py
taufeeque9 Jul 11, 2023
64c3a8d
Fix docstring
taufeeque9 Jul 11, 2023
8fba0d3
Removing resume option as it is getting tricky to correctly implement
taufeeque9 Jul 11, 2023
12ab31c
Minor fixes
taufeeque9 Jul 11, 2023
19b0f2c
Updates from review
taufeeque9 Jul 16, 2023
046b8d9
fix lint error
taufeeque9 Jul 16, 2023
8eee082
Add documentation for using the tuning script
taufeeque9 Jul 16, 2023
5ce7658
Fix lint error
taufeeque9 Jul 17, 2023
a8be331
Updates from the review
taufeeque9 Jul 18, 2023
4ff006d
Fix file name test errors
taufeeque9 Jul 18, 2023
6933afa
Add tune_run_kwargs in parallel script
taufeeque9 Jul 19, 2023
77f9d9b
Fix test errors
taufeeque9 Jul 19, 2023
54eb8a6
Fix test
taufeeque9 Jul 19, 2023
d50238f
Fix lint
taufeeque9 Jul 19, 2023
3fe22d4
Updates from review
taufeeque9 Jul 19, 2023
c50aa20
Simplify few lines of code
taufeeque9 Jul 20, 2023
000af61
Updates from review
taufeeque9 Aug 4, 2023
8b55134
Fix test
taufeeque9 Aug 4, 2023
f3ba2b5
Revert "Fix test"
taufeeque9 Aug 4, 2023
f8251c7
Fix test
taufeeque9 Aug 4, 2023
664fc37
Convert Dict to Mapping in input argument
taufeeque9 Aug 7, 2023
8690e1d
Ignore coverage in script configurations.
ernestum Aug 30, 2023
dd9eb6a
Pin huggingface_sb3 version.
ernestum Aug 30, 2023
b3930f4
Merge branch 'master' into benchmark-pr
ernestum Sep 26, 2023
40d87ef
Update to the newest seals environment versions.
ernestum Sep 26, 2023
71f6c92
Push gymnasium dependency to 0.29 to ensure mujoco envs work.
ernestum Sep 27, 2023
747ad32
Incorporate review comments
taufeeque9 Oct 4, 2023
691e759
Fix test errors
taufeeque9 Oct 4, 2023
2038e60
Move benchmarking/ to scripts/ and add named configs for tuned hyperp…
taufeeque9 Oct 4, 2023
35c7265
Bump cache version & remove unnecessary files
taufeeque9 Oct 5, 2023
fdf4f49
Include tuned hyperparam json files in package data
taufeeque9 Oct 5, 2023
5f9a4e6
Update storage hash
taufeeque9 Oct 5, 2023
91bb785
Update search space of bc
taufeeque9 Oct 5, 2023
3d93c84
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
f59fea2
update benchmark and hyper parameter tuning readme
ZiyueWang25 Oct 5, 2023
95110dc
Update README.md
taufeeque9 Oct 5, 2023
10ec8a2
mce_irl_train
ZiyueWang25 Oct 6, 2023
7436784
add train_mce_irl script
ZiyueWang25 Oct 6, 2023
1ac7848
small fix
ZiyueWang25 Oct 6, 2023
33 changes: 28 additions & 5 deletions benchmarking/README.md
@@ -1,19 +1,42 @@
# Benchmarking imitation

This directory contains sacred configuration files for benchmarking imitation's algorithms. For v0.3.2, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://www.rocamonde.com/publication/gleave-imitation-2022/).
The `src/imitation/scripts/config/tuned_hps` directory provides the tuned hyperparameter configs for benchmarking imitation. For v0.4.0, these correspond to the hyperparameters used in the paper [imitation: Clean Imitation Learning Implementations](https://www.rocamonde.com/publication/gleave-imitation-2022/).

Configuration files can be loaded either from the CLI or from the Python API. The examples below assume that your current working directory is the root of the `imitation` repository. This is not necessarily the case and you should adjust your paths accordingly.
Configuration files can be loaded either from the CLI or from the Python API.

## CLI

```bash
python -m imitation.scripts.<train_script> <algo> with benchmarking/<config_name>.json
python -m imitation.scripts.<train_script> <algo> with <algo>_<env>
```
`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`.
`train_script` can be either 1) `train_imitation` with `algo` as `bc` or `dagger` or 2) `train_adversarial` with `algo` as `gail` or `airl`. The `env` can be any of `seals_ant`, `seals_half_cheetah`, `seals_hopper`, `seals_swimmer`, or `seals_walker`. The hyperparameters for other environments are not tuned yet; you can either start from the hyperparameters tuned for one of these environments or tune the hyperparameters yourself using the `tuning` script. See the example below.
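
For example, training DAgger with the tuned hyperparameters for `seals_half_cheetah` could look like the following (an illustrative invocation of the pattern above; the `dagger_seals_half_cheetah` named config is assumed to follow the `<algo>_<env>` naming convention):

```bash
python -m imitation.scripts.train_imitation dagger with dagger_seals_half_cheetah
```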

## Python

```python
...
ex.add_config('benchmarking/<config_name>.json')
from imitation.scripts.<train_script> import <train_ex>
<train_ex>.run(command_name="<algo>", named_configs=["<algo>_<env>"])
```
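
As a concrete sketch of the pattern above (assuming the Sacred experiment object for the adversarial script is exposed as `train_adversarial_ex`, matching the `<train_ex>` placeholder):

```python
from imitation.scripts.train_adversarial import train_adversarial_ex

# Run AIRL with the tuned hyperparameters for seals_walker
# (named config assumed to follow the <algo>_<env> convention).
run = train_adversarial_ex.run(
    command_name="airl",
    named_configs=["airl_seals_walker"],
)
print(run.result)
```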

# Tuning Hyperparameters

The hyperparameters of any algorithm in imitation can be tuned using `scripts/tuning.py`.
The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
the search space defined in `scripts/config/tuning.py`.

The tuning script proceeds in two phases:
1. Tune the hyperparameters using the search space provided.
2. Re-evaluate the best hyperparameter config from the first phase (selected by maximum mean return) on a separate set of seeds, and report the mean and standard deviation of these trials.

To use it with the default search space:
```bash
python -m imitation.scripts.tuning with <algo> 'parallel_run_config.base_named_configs=["<env>"]'
```

In this command:
- `<algo>` provides the default search space and settings for the specific algorithm, as defined in `scripts/config/tuning.py`
- `<env>` sets the environment to tune the algorithm in. The environment named configs are defined in the algorithm-specific `scripts/config/train_[adversarial/imitation/preference_comparisons/rl].py` files. For the already-tuned environments, use the `<algo>_<env>` named configs here (see the example below).
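
For instance, re-tuning GAIL on the already-tuned `seals_half_cheetah` environment would follow the same pattern (an illustrative command; the `gail_seals_half_cheetah` named config is assumed to follow the `<algo>_<env>` convention above):

```bash
python -m imitation.scripts.tuning with gail 'parallel_run_config.base_named_configs=["gail_seals_half_cheetah"]'
```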

See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for many other arguments that can be
provided through the command line to change the tuning behavior.
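
As an illustration of such arguments (assuming the `parallel_run_config` keys mirror the fields of the parallel experiment config in `scripts/config/parallel.py`, such as `repeat` and `resources_per_trial`), a run could be customized like this:

```bash
python -m imitation.scripts.tuning with bc \
    'parallel_run_config.base_named_configs=["seals_ant"]' \
    'parallel_run_config.repeat=3' \
    'parallel_run_config.resources_per_trial={"cpu": 4}'
```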
2 changes: 1 addition & 1 deletion benchmarking/util.py
@@ -79,7 +79,7 @@ def clean_config_file(file: pathlib.Path, write_path: pathlib.Path, /) -> None:

remove_empty_dicts(config)
# files are of the format
# /path/to/file/example_<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
# /path/to/file/<algo>_<env>_best_hp_eval/<other_info>/sacred/1/config.json
# we want to write to /<write_path>/<algo>_<env>.json
with open(write_path / f"{file.parents[3].name}.json", "w") as f:
json.dump(config, f, indent=4)
29 changes: 16 additions & 13 deletions experiments/commands.py
@@ -12,23 +12,24 @@

For example, we can run:

TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
python commands.py \
--name=run0 \
--cfg_pattern=../benchmarking/*ai*_seals_walker_*.json \
--cfg_pattern=$TUNED_HPS_DIR/*ai*_seals_walker_*.json \
--output_dir=output

And get the following commands printed out:

python -m imitation.scripts.train_adversarial airl \
--capture=sys --name=run0 \
--file_storage=output/sacred/$USER-cmd-run0-airl-0-a3531726 \
with ../benchmarking/example_airl_seals_walker_best_hp_eval.json \
with ../src/imitation/scripts/config/tuned_hps/airl_seals_walker_best_hp_eval.json \
seed=0 logging.log_root=output

python -m imitation.scripts.train_adversarial gail \
--capture=sys --name=run0 \
--file_storage=output/sacred/$USER-cmd-run0-gail-0-a1ec171b \
with ../benchmarking/example_gail_seals_walker_best_hp_eval.json \
with $TUNED_HPS_DIR/gail_seals_walker_best_hp_eval.json \
seed=0 logging.log_root=output

We can execute commands in parallel by piping them to GNU parallel:
@@ -40,9 +41,10 @@

For example, we can run:

TUNED_HPS_DIR=../src/imitation/scripts/config/tuned_hps
python commands.py \
--name=run0 \
--cfg_pattern=../benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
--cfg_pattern=$TUNED_HPS_DIR/bc_seals_half_cheetah_best_hp_eval.json \
--output_dir=/data/output \
--remote

@@ -51,8 +53,9 @@
ctl job run --name $USER-cmd-run0-bc-0-72cb1df3 \
--command "python -m imitation.scripts.train_imitation bc \
--capture=sys --name=run0 \
--file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 \
with /data/imitation/benchmarking/example_bc_seals_half_cheetah_best_hp_eval.json \
--file_storage=/data/output/sacred/$USER-cmd-run0-bc-0-72cb1df3 with \
/data/imitation/src/imitation/scripts/config/tuned_hps/
bc_seals_half_cheetah_best_hp_eval.json \
seed=0 logging.log_root=/data/output" \
--container hacobe/devbox:imitation \
--login --force-pull --never-restart --gpu 0 --shared-host-dir-mount /data
@@ -85,7 +88,7 @@ def _get_algo_name(cfg_file: str) -> str:
"""Get the algorithm name from the given config filename."""
algo_names = set()
for key in _ALGO_NAME_TO_SCRIPT_NAME:
if cfg_file.find("_" + key + "_") != -1:
if cfg_file.find(key + "_") != -1:
algo_names.add(key)

if len(algo_names) == 0:
@@ -177,19 +180,19 @@ def parse() -> argparse.Namespace:
parser.add_argument(
"--cfg_pattern",
type=str,
default="example_bc_seals_half_cheetah_best_hp_eval.json",
default="bc_seals_half_cheetah_best_hp_eval.json",
help="""Generate a command for every file that matches this glob pattern. \
Each matching file should be a config file that has its algorithm name \
(bc, dagger, airl or gail) bookended by underscores in the filename. \
If the --remote flag is enabled, then generate a command for every file in the \
--remote_cfg_dir directory that has the same filename as a file that matches this \
glob pattern. E.g., suppose the current, local working directory is 'foo' and \
the subdirectory 'foo/bar' contains the config files 'example_bc_best.json' and \
'example_dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
will return ['bar/example_bc_best.json', 'bar/example_dagger_best.json']. \
the subdirectory 'foo/bar' contains the config files 'bc_best.json' and \
'dagger_best.json'. If the pattern 'bar/*.json' is supplied, then globbing \
will return ['bar/bc_best.json', 'bar/dagger_best.json']. \
If the --remote flag is enabled, 'bar' will be replaced with `remote_cfg_dir` and \
commands will be created for the following configs: \
[`remote_cfg_dir`/example_bc_best.json, `remote_cfg_dir`/example_dagger_best.json] \
[`remote_cfg_dir`/bc_best.json, `remote_cfg_dir`/dagger_best.json] \
Why not just supply the pattern '`remote_cfg_dir`/*.json' directly? \
Because the `remote_cfg_dir` directory may not exist on the local machine.""",
)
@@ -220,7 +223,7 @@ def parse() -> argparse.Namespace:
parser.add_argument(
"--remote_cfg_dir",
type=str,
default="/data/imitation/benchmarking",
default="/data/imitation/src/imitation/scripts/config/tuned_hps",
help="""Path to a directory storing config files \
accessible from each container. """,
)
2 changes: 2 additions & 0 deletions setup.cfg
@@ -45,6 +45,8 @@ source = imitation
include=
src/*
tests/*
omit =
src/imitation/scripts/config/*

[coverage:report]
exclude_lines =
7 changes: 4 additions & 3 deletions setup.py
@@ -182,13 +182,13 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
python_requires=">=3.8.0",
packages=find_packages("src"),
package_dir={"": "src"},
package_data={"imitation": ["py.typed", "envs/examples/airl_envs/assets/*.xml"]},
package_data={"imitation": ["py.typed", "scripts/config/tuned_hps/*.json"]},
# Note: while we are strict with our test and doc requirement versions, we try to
# impose as little restrictions on the install requirements as possible. Try to
# encode only known incompatibilities here. This prevents nasty dependency issues
# for our users.
install_requires=[
"gymnasium[classic-control]~=0.28.1",
"gymnasium[classic-control]~=0.29",
"matplotlib",
"numpy>=1.15",
"torch>=1.4.0",
@@ -200,6 +200,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"sacred>=0.8.4",
"tensorboard>=1.14",
"huggingface_sb3~=3.0",
"optuna>=3.0.1",
"datasets>=2.8.0",
],
tests_require=TESTS_REQUIRE,
@@ -220,7 +221,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"docs": DOCS_REQUIRE,
"parallel": PARALLEL_REQUIRE,
"mujoco": [
"gymnasium[classic-control,mujoco]~=0.28.1",
"gymnasium[classic-control,mujoco]~=0.29",
],
"atari": ATARI_REQUIRE,
},
37 changes: 23 additions & 14 deletions src/imitation/scripts/analyze.py
@@ -262,38 +262,47 @@ def analyze_imitation(
csv_output_path: If provided, then save a CSV output file to this path.
tex_output_path: If provided, then save a LaTeX-format table to this path.
print_table: If True, then print the dataframe to stdout.
table_verbosity: Increasing levels of verbosity, from 0 to 2, increase the
number of columns in the table.
table_verbosity: Increasing levels of verbosity, from 0 to 3, increase the
number of columns in the table. Level 3 prints all of the columns available.

Returns:
The DataFrame generated from the Sacred logs.
"""
table_entry_fns_subset = _get_table_entry_fns_subset(table_verbosity)
# Get the column names for which we compute values using make_entry_fn.
# These are the same across levels 2 and 3. At level 3, we additionally add the
# remaining config columns.
table_entry_fns_subset = _get_table_entry_fns_subset(min(table_verbosity, 2))

rows = []
output_table = pd.DataFrame()
for sd in _gather_sacred_dicts():
row = {}
if table_verbosity == 3:
# gets all config columns
row = pd.json_normalize(sd.config)
else:
# create an empty dataframe with a single row
row = pd.DataFrame(index=[0])

for col_name, make_entry_fn in table_entry_fns_subset.items():
row[col_name] = make_entry_fn(sd)
rows.append(row)

df = pd.DataFrame(rows)
if len(df) > 0:
df.sort_values(by=["algo", "env_name"], inplace=True)
output_table = pd.concat([output_table, row])

if len(output_table) > 0:
output_table.sort_values(by=["algo", "env_name"], inplace=True)

display_options = dict(index=False)
display_options: Mapping[str, Any] = dict(index=False)
if csv_output_path is not None:
df.to_csv(csv_output_path, **display_options)
output_table.to_csv(csv_output_path, **display_options)
print(f"Wrote CSV file to {csv_output_path}")
if tex_output_path is not None:
s: str = df.to_latex(**display_options)
s: str = output_table.to_latex(**display_options)
with open(tex_output_path, "w") as f:
f.write(s)
print(f"Wrote TeX file to {tex_output_path}")

if print_table:
print(df.to_string(**display_options))
return df
print(output_table.to_string(**display_options))
return output_table


def _make_return_summary(stats: dict, prefix="") -> str:
2 changes: 1 addition & 1 deletion src/imitation/scripts/config/analyze.py
@@ -18,7 +18,7 @@ def config():
tex_output_path = None # Write LaTex output to this path
print_table = True # Set to True to print analysis to stdout
split_str = "," # str used to split source_dir_str into multiple source dirs
table_verbosity = 1 # Choose from 0, 1, or 2
table_verbosity = 1 # Choose from 0, 1, 2 or 3
source_dirs = None


79 changes: 13 additions & 66 deletions src/imitation/scripts/config/parallel.py
@@ -5,7 +5,10 @@
`@parallel_ex.named_config` to define a new parallel experiment.

Adding custom named configs is necessary because the CLI interface can't add
search spaces to the config like `"seed": tune.grid_search([0, 1, 2, 3])`.
search spaces to the config like `"seed": tune.choice([0, 1, 2, 3])`.

For tuning hyperparameters of an algorithm on a given environment,
check out the imitation/scripts/tuning.py script.
"""

import numpy as np
@@ -31,19 +34,10 @@ def config():
"config_updates": {},
} # `config` argument to `ray.tune.run(trainable, config)`

local_dir = None # `local_dir` arg for `ray.tune.run`
upload_dir = None # `upload_dir` arg for `ray.tune.run`
n_seeds = 3 # Number of seeds to search over by default


@parallel_ex.config
def seeds(n_seeds):
search_space = {"config_updates": {"seed": tune.grid_search(list(range(n_seeds)))}}


@parallel_ex.named_config
def s3():
upload_dir = "s3://shwang-chai/private"
num_samples = 1 # Number of samples per grid search configuration
repeat = 1 # Number of times to repeat a sampled configuration
experiment_checkpoint_path = "" # Path to checkpoint of experiment
tune_run_kwargs = {} # Additional kwargs to pass to `tune.run`


# Debug named configs
@@ -58,12 +52,12 @@ def generate_test_data():
"""
sacred_ex_name = "train_rl"
run_name = "TEST"
n_seeds = 1
repeat = 1
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
"learning_rate": tune.choice(
[3e-4 * x for x in (1 / 3, 1 / 2)],
),
},
@@ -86,63 +80,16 @@ def example_cartpole_rl():
def example_cartpole_rl():
sacred_ex_name = "train_rl"
run_name = "example-cartpole"
n_seeds = 2
repeat = 2
search_space = {
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
"learning_rate": tune.choice(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.choice([16, 32, 64]),
},
},
},
}
base_named_configs = ["cartpole"]
resources_per_trial = dict(cpu=4)


EASY_ENVS = ["cartpole", "pendulum", "mountain_car"]


@parallel_ex.named_config
def example_rl_easy():
sacred_ex_name = "train_rl"
run_name = "example-rl-easy"
n_seeds = 2
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(np.logspace(3e-6, 1e-1, num=3)),
"nminibatches": tune.grid_search([16, 32, 64]),
},
},
},
}
resources_per_trial = dict(cpu=4)


@parallel_ex.named_config
def example_gail_easy():
sacred_ex_name = "train_adversarial"
run_name = "example-gail-easy"
n_seeds = 1
search_space = {
"named_configs": tune.grid_search([[env] for env in EASY_ENVS]),
"config_updates": {
"init_trainer_kwargs": {
"rl": {
"rl_kwargs": {
"learning_rate": tune.grid_search(
np.logspace(3e-6, 1e-1, num=3),
),
"nminibatches": tune.grid_search([16, 32, 64]),
},
},
},
},
}
search_space = {
"command_name": "gail",
}