Skip to content

Commit

Permalink
Fix typing issue in interactive.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Sep 12, 2023
1 parent ab67c84 commit 029ef72
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/imitation/policies/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from shimmy import atari_env
from stable_baselines3.common import vec_env

import imitation.policies.base as base_policies
Expand Down Expand Up @@ -133,11 +134,11 @@ def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
"""Interactive policy for Atari environments."""

def __init__(self, env: Union[gym.Env, vec_env.VecEnv], *args, **kwargs):
def __init__(self, env: Union[atari_env.AtariEnv, vec_env.VecEnv], *args, **kwargs):
"""Builds AtariInteractivePolicy."""
action_names = (
env.get_action_meanings()
if isinstance(env, gym.Env)
if isinstance(env, atari_env.AtariEnv)
else env.env_method("get_action_meanings", indices=[0])[0]
)
action_keys_names = collections.OrderedDict(
Expand Down

0 comments on commit 029ef72

Please sign in to comment.