From 39647d73b820b515add3f03699efdcfdf7577536 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 11 Dec 2023 15:02:42 +0100 Subject: [PATCH] Make the default device for policies "cpu". --- src/imitation/algorithms/bc.py | 4 ++-- src/imitation/algorithms/dagger.py | 2 +- src/imitation/algorithms/sqil.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 29f12588d..4bb4057df 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -249,7 +249,7 @@ def __getstate__(self): def reconstruct_policy( policy_path: str, - device: Union[th.device, str] = "auto", + device: Union[th.device, str] = "cpu", ) -> policies.ActorCriticPolicy: """Reconstruct a saved policy. @@ -285,7 +285,7 @@ def __init__( optimizer_kwargs: Optional[Mapping[str, Any]] = None, ent_weight: float = 1e-3, l2_weight: float = 0.0, - device: Union[str, th.device] = "auto", + device: Union[str, th.device] = "cpu", custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Builds BC. diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py index fb68713e6..aa496e877 100644 --- a/src/imitation/algorithms/dagger.py +++ b/src/imitation/algorithms/dagger.py @@ -100,7 +100,7 @@ def reconstruct_trainer( scratch_dir: types.AnyPath, venv: vec_env.VecEnv, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - device: Union[th.device, str] = "auto", + device: Union[th.device, str] = "cpu", ) -> "DAggerTrainer": """Reconstruct trainer from the latest snapshot in some working directory. diff --git a/src/imitation/algorithms/sqil.py b/src/imitation/algorithms/sqil.py index 55a8ad080..d043c8b40 100644 --- a/src/imitation/algorithms/sqil.py +++ b/src/imitation/algorithms/sqil.py @@ -119,7 +119,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, demonstrations: algo_base.AnyTransitions, - device: Union[th.device, str] = "auto", + device: Union[th.device, str] = "cpu", n_envs: int = 1, optimize_memory_usage: bool = False, ):