Skip to content

Commit

Permalink
Set up smoke tests for the algorithm CLIs
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Jul 26, 2023
1 parent affba71 commit f9b01e3
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 6 deletions.
Empty file.
16 changes: 10 additions & 6 deletions src/imitation_cli/dagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ class RunConfig:
)


@hydra.main(
version_base=None,
config_path="config",
config_name="dagger_run",
)
def run_dagger(cfg: RunConfig):
dagger_trainer: dagger.DAggerTrainer = instantiate(cfg.dagger)

Expand All @@ -73,5 +68,14 @@ def run_dagger(cfg: RunConfig):
dagger_trainer.save_trainer()


@hydra.main(
version_base=None,
config_path="config",
config_name="dagger_run",
)
def main(cfg: RunConfig):
run_dagger(cfg)


if __name__ == "__main__":
run_dagger()
main()
Empty file added tests/cli/__init__.py
Empty file.
39 changes: 39 additions & 0 deletions tests/cli/test_airl_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import hydra
import pytest

from imitation.algorithms.adversarial import airl

# Note: this import is needed to ensure that configurations are properly registered
from imitation_cli.airl import RunConfig


@pytest.fixture
def airl_run_config(tmpdir) -> RunConfig:
"""A AIRL run config with a temporary directory as the output directory."""
with hydra.initialize_config_module(
version_base=None,
config_module="imitation_cli.config",
):
yield hydra.compose(
config_name="airl_run",
overrides=[f"hydra.run.dir={tmpdir}"],
# This is needed to ensure that variable interpolation to hydra.run.dir
# works properly
return_hydra_config=True,
)


@pytest.fixture
def airl_trainer(airl_run_config: RunConfig) -> airl.AIRL:
return hydra.utils.instantiate(airl_run_config.airl)


def test_train_airl_trainer_some_steps_smoke(airl_trainer: airl.AIRL):
# WHEN
# Note: any value lower than 16386 will raise an exception
airl_trainer.train(16386)

# THEN
# No exception is raised


38 changes: 38 additions & 0 deletions tests/cli/test_bc_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import hydra
import pytest

from imitation.algorithms import bc

# Note: this import is needed to ensure that configurations are properly registered
from imitation_cli.bc import RunConfig


@pytest.fixture
def bc_run_config(tmpdir) -> RunConfig:
"""A BC run config with a temporary directory as the output directory."""
with hydra.initialize_config_module(
version_base=None,
config_module="imitation_cli.config",
):
yield hydra.compose(
config_name="bc_run",
overrides=[f"hydra.run.dir={tmpdir}"],
# This is needed to ensure that variable interpolation to hydra.run.dir
# works properly
return_hydra_config=True,
)


@pytest.fixture
def bc_trainer(bc_run_config: RunConfig) -> bc.BC:
return hydra.utils.instantiate(bc_run_config.bc)


def test_train_bc_trainer_one_batch_smoke(bc_trainer: bc.BC):
# WHEN
bc_trainer.train(n_batches=1)

# THEN
# No exception is raised


38 changes: 38 additions & 0 deletions tests/cli/test_dagger_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import hydra
import pytest

from imitation.algorithms import dagger

# Note: this import is needed to ensure that configurations are properly registered
from imitation_cli.dagger import RunConfig


@pytest.fixture
def dagger_run_config(tmpdir) -> RunConfig:
"""A DAgger run config with a temporary directory as the output directory."""
with hydra.initialize_config_module(
version_base=None,
config_module="imitation_cli.config",
):
yield hydra.compose(
config_name="dagger_run",
overrides=[f"hydra.run.dir={tmpdir}"],
# This is needed to ensure that variable interpolation to hydra.run.dir
# works properly
return_hydra_config=True,
)


@pytest.fixture
def simple_dagger_trainer(dagger_run_config: RunConfig) -> dagger.SimpleDAggerTrainer:
return hydra.utils.instantiate(dagger_run_config.dagger)


def test_train_dagger_one_step_smoke(simple_dagger_trainer):
# WHEN
simple_dagger_trainer.train(1)

# THEN
# No exception is raised


0 comments on commit f9b01e3

Please sign in to comment.