Skip to content

Commit

Permalink
fixes
Browse files — browse the repository at this point in the history
  • Branch information
michalzajac-ml committed Sep 7, 2023
1 parent 6a9389e commit 8cb822a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/train_dagger_atari_interactive_policy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Training DAgger with an interactive policy that queries the user for actions.
Note that this is a toy example that does not lead to training a reasonable policy."""
Note that this is a toy example that does not lead to training a reasonable policy.
"""

import tempfile

Expand Down
9 changes: 5 additions & 4 deletions src/imitation/policies/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
observation_space: gym.Space,
action_space: gym.Space,
action_keys_names: collections.OrderedDict[str, str],
action_keys_names: collections.OrderedDict,
clear_screen_on_query: bool = True,
):
"""Builds DiscreteInteractivePolicy.
Expand Down Expand Up @@ -80,7 +80,7 @@ def _get_input_key(self) -> str:

@abc.abstractmethod
def _render(self, obs: np.ndarray) -> typing.Optional[object]:
"""Renders an observation, optionally returns a context object for later cleanup."""
"""Renders an observation, optionally returns a context for later cleanup."""

def _clean_up(self, context: object) -> None:
"""Cleans up after the input has been captured, e.g. stops showing the image."""
Expand Down Expand Up @@ -134,14 +134,15 @@ class AtariInteractivePolicy(ImageObsDiscreteInteractivePolicy):
"""Interactive policy for Atari environments."""

def __init__(self, env: typing.Union[gym.Env, vec_env.VecEnv], *args, **kwargs):
"""Builds AtariInteractivePolicy."""
action_names = (
env.get_action_meanings()
if isinstance(env, gym.Env)
else env.envs[0].get_action_meanings()
)
action_keys_names = collections.OrderedDict(
[(ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names]
[(ATARI_ACTION_NAMES_TO_KEYS[name], name) for name in action_names],
)
super().__init__(
env.observation_space, env.action_space, action_keys_names, *args, **kwargs
env.observation_space, env.action_space, action_keys_names, *args, **kwargs,
)
4 changes: 2 additions & 2 deletions tests/policies/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def _render(self, obs: np.ndarray) -> None:
def _get_interactive_policy(env: vec_env.VecEnv):
num_actions = env.envs[0].action_space.n
action_keys_names = collections.OrderedDict(
[(f"k{i}", f"n{i}") for i in range(num_actions)]
[(f"k{i}", f"n{i}") for i in range(num_actions)],
)
interactive_policy = NoRenderingDiscreteInteractivePolicy(
env.observation_space, env.action_space, action_keys_names
env.observation_space, env.action_space, action_keys_names,
)
return interactive_policy

Expand Down

0 comments on commit 8cb822a

Please sign in to comment.