From 029ef72ae088b847f8fa3594d2238e318a2deb9d Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 12 Sep 2023 12:18:48 +0200 Subject: [PATCH] Fix typing issue in interactive.py --- src/imitation/policies/interactive.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py index f68f22a68..648a1874f 100644 --- a/src/imitation/policies/interactive.py +++ b/src/imitation/policies/interactive.py @@ -7,6 +7,7 @@ import gymnasium as gym import matplotlib.pyplot as plt import numpy as np +from shimmy import atari_env from stable_baselines3.common import vec_env import imitation.policies.base as base_policies @@ -133,11 +134,11 @@ def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray: class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy): """Interactive policy for Atari environments.""" - def __init__(self, env: Union[gym.Env, vec_env.VecEnv], *args, **kwargs): + def __init__(self, env: Union[atari_env.AtariEnv, vec_env.VecEnv], *args, **kwargs): """Builds AtariInteractivePolicy.""" action_names = ( env.get_action_meanings() - if isinstance(env, gym.Env) + if isinstance(env, atari_env.AtariEnv) else env.env_method("get_action_meanings", indices=[0])[0] ) action_keys_names = collections.OrderedDict(