diff --git a/examples/train_dagger_atari_interactive_policy.py b/examples/train_dagger_atari_interactive_policy.py
new file mode 100644
index 000000000..83d20be21
--- /dev/null
+++ b/examples/train_dagger_atari_interactive_policy.py
@@ -0,0 +1,49 @@
+import tempfile
+
+import gym
+import numpy as np
+from stable_baselines3.common import vec_env
+
+from imitation.algorithms import bc, dagger
+from imitation.policies import interactive
+
+if __name__ == "__main__":
+    rng = np.random.default_rng(0)
+
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
+    env.seed(0)
+
+    action_names = env.envs[0].get_action_meanings()
+    names_to_keys = {
+        "NOOP": "n",
+        "FIRE": "f",
+        "LEFT": "w",
+        "RIGHT": "e",
+        "LEFTFIRE": "q",
+        "RIGHTFIRE": "r",
+    }
+    action_keys = list(map(names_to_keys.get, action_names))
+
+    expert = interactive.ImageObsDiscreteInteractivePolicy(
+        env.observation_space, env.action_space, action_names, action_keys
+    )
+
+    bc_trainer = bc.BC(
+        observation_space=env.observation_space,
+        action_space=env.action_space,
+        rng=rng,
+    )
+
+    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
+        dagger_trainer = dagger.SimpleDAggerTrainer(
+            venv=env,
+            scratch_dir=tmpdir,
+            expert_policy=expert,
+            bc_trainer=bc_trainer,
+            rng=rng,
+        )
+        dagger_trainer.train(
+            total_timesteps=20,
+            rollout_round_min_episodes=1,
+            rollout_round_min_timesteps=10,
+        )
diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py
index 3101cf2c7..60db89f50 100644
--- a/src/imitation/policies/base.py
+++ b/src/imitation/policies/base.py
@@ -13,11 +13,11 @@
 from imitation.util import networks
 
 
-class HardCodedPolicy(policies.BasePolicy, abc.ABC):
-    """Abstract class for hard-coded (non-trainable) policies."""
+class NonTrainablePolicy(policies.BasePolicy, abc.ABC):
+    """Abstract class for non-trainable (e.g. hard-coded or interactive) policies."""
 
     def __init__(self, observation_space: gym.Space, action_space: gym.Space):
-        """Builds HardcodedPolicy with specified observation and action space."""
+        """Builds NonTrainablePolicy with specified observation and action space."""
         super().__init__(
             observation_space=observation_space,
             action_space=action_space,
@@ -43,14 +43,14 @@ def forward(self, *args):
         raise NotImplementedError  # pragma: no cover
 
 
-class RandomPolicy(HardCodedPolicy):
+class RandomPolicy(NonTrainablePolicy):
     """Returns random actions."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
         return self.action_space.sample()
 
 
-class ZeroPolicy(HardCodedPolicy):
+class ZeroPolicy(NonTrainablePolicy):
     """Returns constant zero action."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py
new file mode 100644
index 000000000..2ee63347d
--- /dev/null
+++ b/src/imitation/policies/interactive.py
@@ -0,0 +1,84 @@
+import abc
+from typing import Optional, List
+
+import gym
+import matplotlib.pyplot as plt
+import numpy as np
+
+import imitation.policies.base as base_policies
+from imitation.util import util
+
+
+class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        action_names: List[str],
+        action_keys: List[str],
+        clear_screen_on_query: bool = True,
+    ):
+        super().__init__(
+            observation_space=observation_space,
+            action_space=action_space,
+        )
+
+        assert isinstance(action_space, gym.spaces.Discrete)
+        assert len(action_names) == len(action_keys) == action_space.n
+        # Names and keys should be unique.
+        assert len(set(action_names)) == len(set(action_keys)) == action_space.n
+
+        self.action_names = action_names
+        self.action_keys = action_keys
+        self.action_key_to_index = {k: i for i, k in enumerate(action_keys)}
+        self.clear_screen_on_query = clear_screen_on_query
+
+    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
+        if self.clear_screen_on_query:
+            util.clear_screen()
+
+        context = self._render(obs)
+        key = self._get_input_key()
+        self._clean_up(context)
+
+        return np.array([self.action_key_to_index[key]])
+
+    def _get_input_key(self) -> str:
+        print(
+            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
+            ", ".join(
+                [f"{n}:{k}" for n, k in zip(self.action_names, self.action_keys)]
+            ),
+        )
+
+        key = input("Your choice (enter key):")
+        while key not in self.action_keys:
+            key = input("Invalid key, please try again! Your choice (enter key):")
+
+        return key
+
+    @abc.abstractmethod
+    def _render(self, obs: np.ndarray) -> Optional[object]:
+        """Renders an observation, optionally returns a context object for later cleanup."""
+
+    def _clean_up(self, context: object) -> None:
+        """Cleans up after the input has been captured, e.g. stops showing the image."""
+        pass
+
+
+class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
+    def _render(self, obs: np.ndarray) -> plt.Figure:
+        img = self._prepare_obs_image(obs)
+
+        fig, ax = plt.subplots()
+        ax.imshow(img)
+        ax.axis("off")
+        fig.show()
+
+        return fig
+
+    def _clean_up(self, context: plt.Figure) -> None:
+        plt.close(context)
+
+    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
+        return obs
diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py
index 83696028d..fb7855dbb 100644
--- a/src/imitation/util/util.py
+++ b/src/imitation/util/util.py
@@ -460,3 +460,10 @@ def split_in_half(x: int) -> Tuple[int, int]:
     """
     half = x // 2
     return half, x - half
+
+
+def clear_screen() -> None:
+    if os.name == "nt":  # Windows
+        os.system("cls")
+    else:
+        os.system("clear")
diff --git a/tests/policies/test_interactive.py b/tests/policies/test_interactive.py
new file mode 100644
index 000000000..8916363a6
--- /dev/null
+++ b/tests/policies/test_interactive.py
@@ -0,0 +1,53 @@
+"""Tests interactive policies."""
+import random
+from unittest.mock import patch
+
+import gym
+import numpy as np
+import pytest
+from stable_baselines3.common import vec_env
+
+from imitation.policies import interactive
+
+ENVS = [
+    "Pong-v4",
+]
+
+
+class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
+    def _render(self, obs: np.ndarray) -> None:
+        pass
+
+
+@pytest.mark.parametrize("env_name", ENVS)
+def test_interactive_policy(env_name: str):
+    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
+    env.seed(0)
+
+    num_actions = env.envs[0].action_space.n
+    action_names = [f"n{i}" for i in range(num_actions)]
+    action_keys = [f"k{i}" for i in range(num_actions)]
+    interactive_policy = NoRenderingDiscreteInteractivePolicy(
+        env.observation_space,
+        env.action_space,
+        action_names,
+        action_keys,
+    )
+
+    obs = env.reset()
+    done = np.array([False])
+
+    def mock_input_valid(_):
+        return random.choice(action_keys)
+
+    with patch("builtins.input", mock_input_valid):
+        while not done.all():
+            action, _ = interactive_policy.predict(obs)
+            assert isinstance(action, np.ndarray)
+            assert all(env.action_space.contains(a) for a in action)
+
+            obs, reward, done, info = env.step(action)
+            assert isinstance(obs, np.ndarray)
+            assert all(env.observation_space.contains(o) for o in obs)
+            assert isinstance(reward, np.ndarray)
+            assert isinstance(done, np.ndarray)
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index a79b134f1..e957eaf52 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -17,13 +17,13 @@
 SIMPLE_DISCRETE_ENV = "CartPole-v0"  # Discrete(2) action space
 SIMPLE_CONTINUOUS_ENV = "MountainCarContinuous-v0"  # Box(1) action space
 SIMPLE_ENVS = [SIMPLE_DISCRETE_ENV, SIMPLE_CONTINUOUS_ENV]
-HARDCODED_TYPES = ["random", "zero"]
+NONTRAINABLE_TYPES = ["random", "zero"]
 
 assert_equal = functools.partial(th.testing.assert_close, rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("env_name", SIMPLE_ENVS)
-@pytest.mark.parametrize("policy_type", HARDCODED_TYPES)
+@pytest.mark.parametrize("policy_type", NONTRAINABLE_TYPES)
 def test_actions_valid(env_name, policy_type, rng):
     """Test output actions of our custom policies always lie in action space."""
     venv = util.make_vec_env(