-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Set up smoke tests for the algorithm CLIs
- Loading branch information
Showing
6 changed files
with
125 additions
and
6 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import hydra | ||
import pytest | ||
|
||
from imitation.algorithms.adversarial import airl | ||
|
||
# Note: this import is needed to ensure that configurations are properly registered | ||
from imitation_cli.airl import RunConfig | ||
|
||
|
||
@pytest.fixture | ||
def airl_run_config(tmpdir) -> RunConfig: | ||
"""A AIRL run config with a temporary directory as the output directory.""" | ||
with hydra.initialize_config_module( | ||
version_base=None, | ||
config_module="imitation_cli.config", | ||
): | ||
yield hydra.compose( | ||
config_name="airl_run", | ||
overrides=[f"hydra.run.dir={tmpdir}"], | ||
# This is needed to ensure that variable interpolation to hydra.run.dir | ||
# works properly | ||
return_hydra_config=True, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def airl_trainer(airl_run_config: RunConfig) -> airl.AIRL: | ||
return hydra.utils.instantiate(airl_run_config.airl) | ||
|
||
|
||
def test_train_airl_trainer_some_steps_smoke(airl_trainer: airl.AIRL): | ||
# WHEN | ||
# Note: any value lower than 16386 will raise an exception | ||
airl_trainer.train(16386) | ||
|
||
# THEN | ||
# No exception is raised | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import hydra | ||
import pytest | ||
|
||
from imitation.algorithms import bc | ||
|
||
# Note: this import is needed to ensure that configurations are properly registered | ||
from imitation_cli.bc import RunConfig | ||
|
||
|
||
@pytest.fixture | ||
def bc_run_config(tmpdir) -> RunConfig: | ||
"""A BC run config with a temporary directory as the output directory.""" | ||
with hydra.initialize_config_module( | ||
version_base=None, | ||
config_module="imitation_cli.config", | ||
): | ||
yield hydra.compose( | ||
config_name="bc_run", | ||
overrides=[f"hydra.run.dir={tmpdir}"], | ||
# This is needed to ensure that variable interpolation to hydra.run.dir | ||
# works properly | ||
return_hydra_config=True, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def bc_trainer(bc_run_config: RunConfig) -> bc.BC: | ||
return hydra.utils.instantiate(bc_run_config.bc) | ||
|
||
|
||
def test_train_bc_trainer_one_batch_smoke(bc_trainer: bc.BC): | ||
# WHEN | ||
bc_trainer.train(n_batches=1) | ||
|
||
# THEN | ||
# No exception is raised | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import hydra | ||
import pytest | ||
|
||
from imitation.algorithms import dagger | ||
|
||
# Note: this import is needed to ensure that configurations are properly registered | ||
from imitation_cli.dagger import RunConfig | ||
|
||
|
||
@pytest.fixture | ||
def dagger_run_config(tmpdir) -> RunConfig: | ||
"""A DAgger run config with a temporary directory as the output directory.""" | ||
with hydra.initialize_config_module( | ||
version_base=None, | ||
config_module="imitation_cli.config", | ||
): | ||
yield hydra.compose( | ||
config_name="dagger_run", | ||
overrides=[f"hydra.run.dir={tmpdir}"], | ||
# This is needed to ensure that variable interpolation to hydra.run.dir | ||
# works properly | ||
return_hydra_config=True, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def simple_dagger_trainer(dagger_run_config: RunConfig) -> dagger.SimpleDAggerTrainer: | ||
return hydra.utils.instantiate(dagger_run_config.dagger) | ||
|
||
|
||
def test_train_dagger_one_step_smoke(simple_dagger_trainer): | ||
# WHEN | ||
simple_dagger_trainer.train(1) | ||
|
||
# THEN | ||
# No exception is raised | ||
|
||
|