diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 4dfd554a0..8049cca97 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -2,7 +2,7 @@ import abc import dataclasses import logging -from typing import Iterable, Iterator, Mapping, Optional, Type, List, overload +from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload import numpy as np import torch as th @@ -15,8 +15,8 @@ policies, vec_env, ) -from stable_baselines3.common.type_aliases import MaybeCallback from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback +from stable_baselines3.common.type_aliases import MaybeCallback from stable_baselines3.sac import policies as sac_policies from torch.nn import functional as F