WIP: Introduce interactive policies to gather data from a user
michalzajac-ml committed Sep 6, 2023
1 parent 4872ceb commit b0efd61
Showing 6 changed files with 200 additions and 7 deletions.
49 changes: 49 additions & 0 deletions examples/train_dagger_atari_interactive_policy.py
@@ -0,0 +1,49 @@
import tempfile

import gym
import numpy as np
from stable_baselines3.common import vec_env

from imitation.algorithms import bc, dagger
from imitation.policies import interactive

if __name__ == "__main__":
    rng = np.random.default_rng(0)

    env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)])
    env.seed(0)

    action_names = env.envs[0].get_action_meanings()
    names_to_keys = {
        "NOOP": "n",
        "FIRE": "f",
        "LEFT": "w",
        "RIGHT": "e",
        "LEFTFIRE": "q",
        "RIGHTFIRE": "r",
    }
    action_keys = list(map(names_to_keys.get, action_names))

    expert = interactive.ImageObsDiscreteInteractivePolicy(
        env.observation_space, env.action_space, action_names, action_keys
    )

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        rng=rng,
    )

    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
        dagger_trainer = dagger.SimpleDAggerTrainer(
            venv=env,
            scratch_dir=tmpdir,
            expert_policy=expert,
            bc_trainer=bc_trainer,
            rng=rng,
        )
        dagger_trainer.train(
            total_timesteps=20,
            rollout_round_min_episodes=1,
            rollout_round_min_timesteps=10,
        )
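
Not part of this commit: once dagger_trainer.train(...) returns, the policy learned from the human's demonstrations can be rolled out without further keyboard input. A minimal sketch of such a follow-up, assuming stable-baselines3's evaluate_policy helper and that bc_trainer.policy is the policy DAgger has been updating; it could be appended inside the with block above.

from stable_baselines3.common.evaluation import evaluate_policy

# Hypothetical follow-up (not in the commit): roll out the policy cloned from
# the interactive expert and report its average episode reward.
mean_reward, std_reward = evaluate_policy(bc_trainer.policy, env, n_eval_episodes=3)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")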
10 changes: 5 additions & 5 deletions src/imitation/policies/base.py
@@ -13,11 +13,11 @@
 from imitation.util import networks
 
 
-class HardCodedPolicy(policies.BasePolicy, abc.ABC):
-    """Abstract class for hard-coded (non-trainable) policies."""
+class NonTrainablePolicy(policies.BasePolicy, abc.ABC):
+    """Abstract class for non-trainable (e.g. hard-coded or interactive) policies."""
 
     def __init__(self, observation_space: gym.Space, action_space: gym.Space):
-        """Builds HardcodedPolicy with specified observation and action space."""
+        """Builds NonTrainablePolicy with specified observation and action space."""
         super().__init__(
             observation_space=observation_space,
             action_space=action_space,
@@ -43,14 +43,14 @@ def forward(self, *args):
         raise NotImplementedError  # pragma: no cover
 
 
-class RandomPolicy(HardCodedPolicy):
+class RandomPolicy(NonTrainablePolicy):
     """Returns random actions."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
         return self.action_space.sample()
 
 
-class ZeroPolicy(HardCodedPolicy):
+class ZeroPolicy(NonTrainablePolicy):
     """Returns constant zero action."""
 
     def _choose_action(self, obs: np.ndarray) -> np.ndarray:
84 changes: 84 additions & 0 deletions src/imitation/policies/interactive.py
@@ -0,0 +1,84 @@
"""Interactive policies that query the user for actions."""

import abc
from typing import List, Optional

import gym
import matplotlib.pyplot as plt
import numpy as np

import imitation.policies.base as base_policies
from imitation.util import util


class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
    """Policy that asks the user to pick a discrete action for each observation."""

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        action_names: List[str],
        action_keys: List[str],
        clear_screen_on_query: bool = True,
    ):
        """Builds DiscreteInteractivePolicy with given actions and key bindings."""
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
        )

        assert isinstance(action_space, gym.spaces.Discrete)
        assert len(action_names) == len(action_keys) == action_space.n
        # Names and keys should be unique.
        assert len(set(action_names)) == len(set(action_keys)) == action_space.n

        self.action_names = action_names
        self.action_keys = action_keys
        self.action_key_to_index = {k: i for i, k in enumerate(action_keys)}
        self.clear_screen_on_query = clear_screen_on_query

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        if self.clear_screen_on_query:
            util.clear_screen()

        context = self._render(obs)
        key = self._get_input_key()
        self._clean_up(context)

        return np.array([self.action_key_to_index[key]])

    def _get_input_key(self) -> str:
        print(
            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
            ", ".join(
                [f"{n}:{k}" for n, k in zip(self.action_names, self.action_keys)]
            ),
        )

        key = input("Your choice (enter key):")
        while key not in self.action_keys:
            key = input("Invalid key, please try again! Your choice (enter key):")

        return key

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> Optional[object]:
        """Renders the observation.

        Optionally returns a context object that is passed to ``_clean_up``.
        """

    def _clean_up(self, context: object) -> None:
        """Cleans up after the input has been captured, e.g. stops showing the image."""


class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """DiscreteInteractivePolicy that displays image observations with matplotlib."""

    def _render(self, obs: np.ndarray) -> plt.Figure:
        img = self._prepare_obs_image(obs)

        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis("off")
        fig.show()

        return fig

    def _clean_up(self, context: plt.Figure) -> None:
        plt.close(context)

    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
        """Converts the observation into an image to show; identity by default."""
        return obs
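
The abstract hooks above are the whole extension surface: a subclass only needs _render, plus _clean_up when the rendering has to be torn down. A minimal sketch, not part of this commit, of a hypothetical text-based variant for environments with low-dimensional observations:

class TextObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """Prints the observation as text instead of plotting it (illustration only)."""

    def _render(self, obs: np.ndarray) -> None:
        print("Current observation:", obs)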
7 changes: 7 additions & 0 deletions src/imitation/util/util.py
@@ -460,3 +460,10 @@ def split_in_half(x: int) -> Tuple[int, int]:
     """
     half = x // 2
     return half, x - half
+
+
+def clear_screen() -> None:
+    if os.name == "nt":  # Windows
+        os.system("cls")
+    else:
+        os.system("clear")
53 changes: 53 additions & 0 deletions tests/policies/test_interactive.py
@@ -0,0 +1,53 @@
"""Tests interactive policies."""
import random
from unittest.mock import patch

import gym
import numpy as np
import pytest
from stable_baselines3.common import vec_env

from imitation.policies import interactive

ENVS = [
"Pong-v4",
]


class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
def _render(self, obs: np.ndarray) -> None:
pass


@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy(env_name: str):
env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)])
env.seed(0)

num_actions = env.envs[0].action_space.n
action_names = [f"n{i}" for i in range(num_actions)]
action_keys = [f"k{i}" for i in range(num_actions)]
interactive_policy = NoRenderingDiscreteInteractivePolicy(
env.observation_space,
env.action_space,
action_names,
action_keys,
)

obs = env.reset()
done = np.array([False])

def mock_input_valid(_):
return random.choice(action_keys)

with patch("builtins.input", mock_input_valid):
while not done.all():
action, _ = interactive_policy.predict(obs)
assert isinstance(action, np.ndarray)
assert all(env.action_space.contains(a) for a in action)

obs, reward, done, info = env.step(action)
assert isinstance(obs, np.ndarray)
assert all(env.observation_space.contains(o) for o in obs)
assert isinstance(reward, np.ndarray)
assert isinstance(done, np.ndarray)
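
The test above only exercises valid key presses; the retry loop in _get_input_key could be covered separately with a mocked input that fails once before succeeding. A possible additional test, not in this commit, using hypothetical action names and keys on CartPole:

def test_get_input_key_rejects_invalid_keys():
    # Hypothetical extra test (not in the commit): the first response is an
    # invalid key, so _get_input_key should re-prompt and return the second one.
    env = gym.make("CartPole-v1")
    policy = NoRenderingDiscreteInteractivePolicy(
        env.observation_space, env.action_space, ["left", "right"], ["a", "d"]
    )
    responses = iter(["x", "d"])
    with patch("builtins.input", lambda _: next(responses)):
        assert policy._get_input_key() == "d"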
4 changes: 2 additions & 2 deletions tests/policies/test_policies.py
@@ -17,13 +17,13 @@
 SIMPLE_DISCRETE_ENV = "CartPole-v0"  # Discrete(2) action space
 SIMPLE_CONTINUOUS_ENV = "MountainCarContinuous-v0"  # Box(1) action space
 SIMPLE_ENVS = [SIMPLE_DISCRETE_ENV, SIMPLE_CONTINUOUS_ENV]
-HARDCODED_TYPES = ["random", "zero"]
+NONTRAINABLE_TYPES = ["random", "zero"]
 
 assert_equal = functools.partial(th.testing.assert_close, rtol=0, atol=0)
 
 
 @pytest.mark.parametrize("env_name", SIMPLE_ENVS)
-@pytest.mark.parametrize("policy_type", HARDCODED_TYPES)
+@pytest.mark.parametrize("policy_type", NONTRAINABLE_TYPES)
 def test_actions_valid(env_name, policy_type, rng):
     """Test output actions of our custom policies always lie in action space."""
     venv = util.make_vec_env(
