Commit

Some formatting fixes.

ernestum committed Feb 29, 2024
1 parent 2e356b7 commit f4bcbdc
Showing 11 changed files with 173 additions and 106 deletions.
benchmarking/README.md (2 changes: 1 addition & 1 deletion)

@@ -185,4 +185,4 @@ where:
 - `algo` is the algorithm you want to compare against
 
 If `your_runs_dir` contains runs for more than one algorithm, you will have to
-disambiguate using the `--algo` option.
\ No newline at end of file
+disambiguate using the `--algo` option.
src/imitation/algorithms/preference_comparisons.py (4 changes: 3 additions & 1 deletion)

@@ -1678,7 +1678,9 @@ def train(
         unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
         probs = unnormalized_probs / np.sum(unnormalized_probs)
         shares = util.oric(probs * total_comparisons)
-        shares[shares <= 0] = 1  # ensure we at least request one comparison per iteration
+        shares[
+            shares <= 0
+        ] = 1  # ensure we at least request one comparison per iteration
 
         schedule = [initial_comparisons] + shares.tolist()
         print(f"Query schedule: {schedule}")
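The wrapped line above is purely cosmetic, but the surrounding logic is worth a closer look: it turns a query schedule into integer per-iteration comparison counts. Below is a minimal, self-contained sketch of that computation; `hyperbolic_schedule` and `round_preserving_sum` are illustrative stand-ins, not imitation's actual `vec_schedule` or `util.oric`.

```python
import numpy as np


def hyperbolic_schedule(t: np.ndarray) -> np.ndarray:
    # Weight early iterations more heavily: f(t) = 1 / (1 + t).
    return 1.0 / (1.0 + t)


def round_preserving_sum(x: np.ndarray) -> np.ndarray:
    # Largest-remainder rounding: integers whose sum matches round(x.sum()).
    floored = np.floor(x).astype(int)
    remainder = int(round(x.sum())) - int(floored.sum())
    order = np.argsort(-(x - floored))  # indices by descending fractional part
    floored[order[:remainder]] += 1
    return floored


num_iterations = 5
total_comparisons = 100
initial_comparisons = 10

unnormalized_probs = hyperbolic_schedule(np.linspace(0, 1, num_iterations))
probs = unnormalized_probs / np.sum(unnormalized_probs)
shares = round_preserving_sum(probs * total_comparisons)
shares[shares <= 0] = 1  # the clamp from the diff: at least one query per iteration
schedule = [initial_comparisons] + shares.tolist()
print(f"Query schedule: {schedule}")  # Query schedule: [10, 28, 23, 19, 16, 14]
```

With a hyperbolic schedule the shares shrink over training, so the clamp matters mainly for long runs where late iterations would otherwise round down to zero queries.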
src/imitation/scripts/config/tuning.py (12 changes: 9 additions & 3 deletions)

@@ -199,9 +199,13 @@ def pc():
"named_configs": ["reward.reward_ensemble"],
"config_updates": {
"active_selection_oversampling": tune.randint(1, 11),
"comparison_queue_size": tune.randint(1, 1001), # upper bound determined by total_comparisons=1000
"comparison_queue_size": tune.randint(
1, 1001
), # upper bound determined by total_comparisons=1000
"exploration_frac": tune.uniform(0.0, 0.5),
"fragment_length": tune.randint(1, 1001), # trajectories are 1000 steps long
"fragment_length": tune.randint(
1, 1001
), # trajectories are 1000 steps long
"gatherer_kwargs": {
"temperature": tune.uniform(0.0, 2.0),
"discount_factor": tune.uniform(0.95, 1.0),
@@ -213,7 +217,9 @@
"noise_prob": tune.uniform(0.0, 0.1),
"discount_factor": tune.uniform(0.95, 1.0),
},
"query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
"query_schedule": tune.choice(
["hyperbolic", "constant", "inverse_quadratic"]
),
"trajectory_generator_kwargs": {
"switch_prob": tune.uniform(0.1, 1),
"random_prob": tune.uniform(0.1, 0.9),
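For context on how this search space is consumed: each `tune.randint` / `tune.uniform` / `tune.choice` call builds a Ray Tune domain object, and the tuner draws one value per domain for every trial. A hedged sketch, assuming Ray Tune's domain objects expose `.sample()` (true in recent Ray versions); the keys and bounds are copied from the `pc` space above:

```python
from ray import tune

# Subset of the `pc` search space above.
search_space = {
    "active_selection_oversampling": tune.randint(1, 11),
    "comparison_queue_size": tune.randint(1, 1001),  # total_comparisons=1000
    "exploration_frac": tune.uniform(0.0, 0.5),
    "fragment_length": tune.randint(1, 1001),  # trajectories are 1000 steps long
    "query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
}

# Draw one candidate configuration, as a tuner would for a single trial.
config = {name: domain.sample() for name, domain in search_space.items()}
print(config)
```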
src/imitation/scripts/ingredients/rl.py (1 change: 0 additions & 1 deletion)

@@ -103,7 +103,6 @@ def dqn():
     rl_cls = sb3.DQN
 
 
-
 def _maybe_add_relabel_buffer(
     rl_kwargs: Dict[str, Any],
     relabel_reward_fn: Optional[RewardFn] = None,
tuning/README.md (8 changes: 4 additions & 4 deletions)

@@ -7,22 +7,22 @@ If you want to specify a custom algorithm and search space, add it to the dict i

 You can tune using multiple workers in parallel by running multiple instances of `tune.py` that all point to the same journal log file (see `tune.py --help` for details).
 To easily launch multiple workers on a SLURM cluster and ensure they don't conflict with each other,
-use the `tune_on_slurm.py` script.
+use the `tune_on_slurm.py` script.
 This script will launch a SLURM job array with the specified number of workers.
 If you want to tune all algorithms on all environments on SLURM, use `tune_all_on_slurm.sh`.
 
 # Legacy Tuning Scripts
 
-Note: There are some legacy tuning scripts that can be used like this:
+Note: There are some legacy tuning scripts that can be used like this:
 
 The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
 The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
 the search space defined in the `scripts/config/tuning.py`.
 
 The tuning script proceeds in two phases:
 1. Tune the hyperparameters using the search space provided.
-2. Re-evaluate the best hyperparameter config found in the first phase
-based on the maximum mean return on a separate set of seeds.
+2. Re-evaluate the best hyperparameter config found in the first phase
+   based on the maximum mean return on a separate set of seeds.
 Report the mean and standard deviation of these trials.
 
 To use it with the default search space:
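The "many workers, one journal file" setup described above is wired up by `tune.py`, but the underlying mechanism is plain Optuna: a journal file storage serializes writes from concurrent processes. A rough sketch under that assumption (the study name follows the `tuning_<algo>_with_<env>` convention seen in the analysis notebook below; the objective is a toy stand-in):

```python
import optuna

# Every worker points at the same journal file; the journal backend
# serializes their writes, so parallel workers do not conflict.
storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("tuning_pc_with_pendulum.log")
)

study = optuna.create_study(
    study_name="tuning_pc_with_pendulum",  # illustrative name
    storage=storage,
    direction="maximize",
    load_if_exists=True,  # later workers attach to the existing study
)


def objective(trial: optuna.Trial) -> float:
    # Toy objective; the real one trains an agent and returns its mean return.
    x = trial.suggest_float("x", -10.0, 10.0)
    return -(x**2)


study.optimize(objective, n_trials=10)
print(study.best_value)
```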
tuning/benchmark_analysis.ipynb (83 changes: 49 additions & 34 deletions)

@@ -40,47 +40,56 @@
"\n",
"for log_file in experiment_log_files:\n",
" d = dict()\n",
" \n",
" d['logfile'] = log_file\n",
" \n",
" study = optuna.load_study(storage=optuna.storages.JournalStorage(\n",
"\n",
" d[\"logfile\"] = log_file\n",
"\n",
" study = optuna.load_study(\n",
" storage=optuna.storages.JournalStorage(\n",
" optuna.storages.JournalFileStorage(str(log_file))\n",
" ),\n",
" # in our case, we have one journal file per study so the study name can be\n",
" # inferred\n",
" study_name=None,\n",
" )\n",
" d['study'] = study\n",
" d['study_name'] = study.study_name\n",
" \n",
" d[\"study\"] = study\n",
" d[\"study_name\"] = study.study_name\n",
"\n",
" trial_state_counter = Counter(t.state for t in study.trials)\n",
" n_completed_trials = trial_state_counter[TrialState.COMPLETE]\n",
" d['trials'] = n_completed_trials\n",
" d['trials_running'] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
" d['trials_failed'] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
" d['all_trials'] = len(study.trials)\n",
" \n",
" d[\"trials\"] = n_completed_trials\n",
" d[\"trials_running\"] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
" d[\"trials_failed\"] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
" d[\"all_trials\"] = len(study.trials)\n",
"\n",
" if n_completed_trials > 0:\n",
" d['best_value'] = round(study.best_trial.value, 2)\n",
" \n",
" d[\"best_value\"] = round(study.best_trial.value, 2)\n",
"\n",
" assert \"_\" in study.study_name\n",
" study_segments = study.study_name.split(\"_\") \n",
" study_segments = study.study_name.split(\"_\")\n",
" assert len(study_segments) > 3\n",
" tuning, algo, with_ = study_segments[:3]\n",
" assert (tuning, with_) == (\"tuning\", \"with\")\n",
" \n",
" d['algo'] = algo\n",
" d['env'] = \"_\".join(study_segments[3:])\n",
" d['best_trial_duration'] = study.best_trial.duration\n",
" d['mean_duration'] = sum([t.duration for t in study.trials if t.state == TrialState.COMPLETE], datetime.timedelta())/n_completed_trials\n",
" \n",
"\n",
" d[\"algo\"] = algo\n",
" d[\"env\"] = \"_\".join(study_segments[3:])\n",
" d[\"best_trial_duration\"] = study.best_trial.duration\n",
" d[\"mean_duration\"] = (\n",
" sum(\n",
" [t.duration for t in study.trials if t.state == TrialState.COMPLETE],\n",
" datetime.timedelta(),\n",
" )\n",
" / n_completed_trials\n",
" )\n",
"\n",
" reruns_folder = log_file.parent / \"reruns\"\n",
" rerun_results = [round(run['result']['imit_stats']['monitor_return_mean'], 2)\n",
" for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)]\n",
" d['rerun_values'] = rerun_results\n",
" \n",
" rerun_results = [\n",
" round(run[\"result\"][\"imit_stats\"][\"monitor_return_mean\"], 2)\n",
" for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)\n",
" ]\n",
" d[\"rerun_values\"] = rerun_results\n",
"\n",
" raw_study_data.append(d)\n",
" \n",
"\n",
"study_data = pd.DataFrame(raw_study_data)"
]
},
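One small aside on the bookkeeping in this cell: it builds `trial_state_counter` once but then constructs fresh `Counter`s for the RUNNING and FAIL lookups; the first counter already answers all three queries. A tiny sketch with a toy state list standing in for `(t.state for t in study.trials)`:

```python
from collections import Counter

from optuna.trial import TrialState

# Stand-in for the states of `study.trials`.
states = [
    TrialState.COMPLETE,
    TrialState.COMPLETE,
    TrialState.RUNNING,
    TrialState.FAIL,
]

trial_state_counter = Counter(states)
print(trial_state_counter[TrialState.COMPLETE])  # 2
print(trial_state_counter[TrialState.RUNNING])  # 1
print(trial_state_counter[TrialState.FAIL])  # 1
```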
@@ -103,7 +112,7 @@
" \"seals_humanoid\",\n",
" \"seals_cartpole\",\n",
" \"pendulum\",\n",
" \"seals_mountain_car\"\n",
" \"seals_mountain_car\",\n",
"]\n",
"\n",
"pc_paper_700 = dict(\n",
@@ -163,12 +172,14 @@
" for env, value in values_by_env.items():\n",
" if value == \"-\":\n",
" continue\n",
" raw_study_data.append(dict(\n",
" algo=algo,\n",
" env=env,\n",
" best_value=value,\n",
" ))\n",
" \n",
" raw_study_data.append(\n",
" dict(\n",
" algo=algo,\n",
" env=env,\n",
" best_value=value,\n",
" )\n",
" )\n",
"\n",
"study_data = pd.DataFrame(raw_study_data)"
]
},
@@ -185,7 +196,11 @@
"display(study_data[[\"algo\", \"env\", \"best_value\"]])\n",
"\n",
"print(\"Rerun Data\")\n",
"display(study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][study_data[\"rerun_values\"].map(np.std) > 0])"
"display(\n",
" study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][\n",
" study_data[\"rerun_values\"].map(np.std) > 0\n",
" ]\n",
")"
]
}
],
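The reformatted `display` call above filters for studies whose rerun returns actually vary across seeds: `study_data["rerun_values"].map(np.std)` computes the standard deviation of each row's list of returns, and rows with zero spread are dropped. A self-contained sketch on a toy frame (the rows are invented for illustration):

```python
import numpy as np
import pandas as pd

study_data = pd.DataFrame(
    {
        "algo": ["pc", "bc"],
        "env": ["pendulum", "seals_cartpole"],
        "best_value": [-210.5, 500.0],
        "rerun_values": [[-205.1, -220.3, -215.8], [500.0, 500.0, 500.0]],
    }
)

# np.std over each list of rerun returns; the "bc" row has zero spread
# and is dropped.
varying = study_data[study_data["rerun_values"].map(np.std) > 0]
print(varying[["algo", "env", "best_value", "rerun_values"]])
```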
