diff --git a/src/imitation_cli/config/__init__.py b/src/imitation_cli/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/imitation_cli/dagger.py b/src/imitation_cli/dagger.py index 955c71991..67e267d91 100644 --- a/src/imitation_cli/dagger.py +++ b/src/imitation_cli/dagger.py @@ -57,11 +57,6 @@ class RunConfig: ) -@hydra.main( - version_base=None, - config_path="config", - config_name="dagger_run", -) def run_dagger(cfg: RunConfig): dagger_trainer: dagger.DAggerTrainer = instantiate(cfg.dagger) @@ -73,5 +68,14 @@ def run_dagger(cfg: RunConfig): dagger_trainer.save_trainer() +@hydra.main( + version_base=None, + config_path="config", + config_name="dagger_run", +) +def main(cfg: RunConfig): + run_dagger(cfg) + + if __name__ == "__main__": - run_dagger() + main() diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/cli/test_airl_cli.py b/tests/cli/test_airl_cli.py new file mode 100644 index 000000000..d4f793687 --- /dev/null +++ b/tests/cli/test_airl_cli.py @@ -0,0 +1,39 @@ +import hydra +import pytest + +from imitation.algorithms.adversarial import airl + +# Note: this import is needed to ensure that configurations are properly registered +from imitation_cli.airl import RunConfig + + +@pytest.fixture +def airl_run_config(tmpdir) -> RunConfig: + """A AIRL run config with a temporary directory as the output directory.""" + with hydra.initialize_config_module( + version_base=None, + config_module="imitation_cli.config", + ): + yield hydra.compose( + config_name="airl_run", + overrides=[f"hydra.run.dir={tmpdir}"], + # This is needed to ensure that variable interpolation to hydra.run.dir + # works properly + return_hydra_config=True, + ) + + +@pytest.fixture +def airl_trainer(airl_run_config: RunConfig) -> airl.AIRL: + return hydra.utils.instantiate(airl_run_config.airl) + + +def test_train_airl_trainer_some_steps_smoke(airl_trainer: airl.AIRL): + # WHEN + # Note: any value lower than 16386 will raise an exception + airl_trainer.train(16386) + + # THEN + # No exception is raised + + diff --git a/tests/cli/test_bc_cli.py b/tests/cli/test_bc_cli.py new file mode 100644 index 000000000..78e3fbcde --- /dev/null +++ b/tests/cli/test_bc_cli.py @@ -0,0 +1,38 @@ +import hydra +import pytest + +from imitation.algorithms import bc + +# Note: this import is needed to ensure that configurations are properly registered +from imitation_cli.bc import RunConfig + + +@pytest.fixture +def bc_run_config(tmpdir) -> RunConfig: + """A BC run config with a temporary directory as the output directory.""" + with hydra.initialize_config_module( + version_base=None, + config_module="imitation_cli.config", + ): + yield hydra.compose( + config_name="bc_run", + overrides=[f"hydra.run.dir={tmpdir}"], + # This is needed to ensure that variable interpolation to hydra.run.dir + # works properly + return_hydra_config=True, + ) + + +@pytest.fixture +def bc_trainer(bc_run_config: RunConfig) -> bc.BC: + return hydra.utils.instantiate(bc_run_config.bc) + + +def test_train_bc_trainer_one_batch_smoke(bc_trainer: bc.BC): + # WHEN + bc_trainer.train(n_batches=1) + + # THEN + # No exception is raised + + diff --git a/tests/cli/test_dagger_cli.py b/tests/cli/test_dagger_cli.py new file mode 100644 index 000000000..e74c0fc7a --- /dev/null +++ b/tests/cli/test_dagger_cli.py @@ -0,0 +1,38 @@ +import hydra +import pytest + +from imitation.algorithms import dagger + +# Note: this import is needed to ensure that configurations are properly registered +from imitation_cli.dagger import RunConfig + + +@pytest.fixture +def dagger_run_config(tmpdir) -> RunConfig: + """A DAgger run config with a temporary directory as the output directory.""" + with hydra.initialize_config_module( + version_base=None, + config_module="imitation_cli.config", + ): + yield hydra.compose( + config_name="dagger_run", + overrides=[f"hydra.run.dir={tmpdir}"], + # This is needed to ensure that variable interpolation to hydra.run.dir + # works properly + return_hydra_config=True, + ) + + +@pytest.fixture +def simple_dagger_trainer(dagger_run_config: RunConfig) -> dagger.SimpleDAggerTrainer: + return hydra.utils.instantiate(dagger_run_config.dagger) + + +def test_train_dagger_one_step_smoke(simple_dagger_trainer): + # WHEN + simple_dagger_trainer.train(1) + + # THEN + # No exception is raised + +