Skip to content

Commit

Permalink
Make the default device for policies "cpu".
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Dec 11, 2023
1 parent 629ef9a commit 39647d7
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/imitation/algorithms/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def __getstate__(self):

def reconstruct_policy(
policy_path: str,
device: Union[th.device, str] = "auto",
device: Union[th.device, str] = "cpu",
) -> policies.ActorCriticPolicy:
"""Reconstruct a saved policy.
Expand Down Expand Up @@ -285,7 +285,7 @@ def __init__(
optimizer_kwargs: Optional[Mapping[str, Any]] = None,
ent_weight: float = 1e-3,
l2_weight: float = 0.0,
device: Union[str, th.device] = "auto",
device: Union[str, th.device] = "cpu",
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
):
"""Builds BC.
Expand Down
2 changes: 1 addition & 1 deletion src/imitation/algorithms/dagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def reconstruct_trainer(
scratch_dir: types.AnyPath,
venv: vec_env.VecEnv,
custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
device: Union[th.device, str] = "auto",
device: Union[th.device, str] = "cpu",
) -> "DAggerTrainer":
"""Reconstruct trainer from the latest snapshot in some working directory.
Expand Down
2 changes: 1 addition & 1 deletion src/imitation/algorithms/sqil.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
observation_space: spaces.Space,
action_space: spaces.Space,
demonstrations: algo_base.AnyTransitions,
device: Union[th.device, str] = "auto",
device: Union[th.device, str] = "cpu",
n_envs: int = 1,
optimize_memory_usage: bool = False,
):
Expand Down

0 comments on commit 39647d7

Please sign in to comment.