-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Introduce interactive policies to gather data from a user
- Loading branch information
1 parent
4872ceb
commit b0efd61
Showing
6 changed files
with
200 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import tempfile | ||
|
||
import gym | ||
import numpy as np | ||
from stable_baselines3.common import vec_env | ||
|
||
from imitation.algorithms import bc, dagger | ||
from imitation.policies import interactive | ||
|
||
if __name__ == "__main__":
    # Interactive DAgger demo: a human at the keyboard acts as the expert.
    rng = np.random.default_rng(0)

    # Single-env vectorized Pong, with episodes truncated after 10 steps.
    env = vec_env.DummyVecEnv(
        [lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)],
    )
    env.seed(0)

    action_names = env.envs[0].get_action_meanings()
    names_to_keys = {
        "NOOP": "n",
        "FIRE": "f",
        "LEFT": "w",
        "RIGHT": "e",
        "LEFTFIRE": "q",
        "RIGHTFIRE": "r",
    }
    # Index the mapping directly so that an action without a key binding
    # raises KeyError here, instead of silently becoming a ``None`` key that
    # the user could never type.
    action_keys = [names_to_keys[name] for name in action_names]

    expert = interactive.ImageObsDiscreteInteractivePolicy(
        env.observation_space,
        env.action_space,
        action_names,
        action_keys,
    )

    bc_trainer = bc.BC(
        observation_space=env.observation_space,
        action_space=env.action_space,
        rng=rng,
    )

    # DAgger scratch data lives in a temporary directory removed on exit.
    with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir:
        dagger_trainer = dagger.SimpleDAggerTrainer(
            venv=env,
            scratch_dir=tmpdir,
            expert_policy=expert,
            bc_trainer=bc_trainer,
            rng=rng,
        )
        dagger_trainer.train(
            total_timesteps=20,
            rollout_round_min_episodes=1,
            rollout_round_min_timesteps=10,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import abc | ||
from typing import Optional, List | ||
|
||
import gym | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import imitation.policies.base as base_policies | ||
from imitation.util import util | ||
|
||
|
||
class DiscreteInteractivePolicy(base_policies.NonTrainablePolicy, abc.ABC):
    """Abstract policy that obtains discrete actions from a human via stdin.

    For each observation the policy renders it with ``_render``, prompts the
    user until one of ``action_keys`` is entered, and returns the index of the
    chosen key as the action. Subclasses must implement ``_render`` and may
    override ``_clean_up`` to release whatever ``_render`` created.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        action_names: List[str],
        action_keys: List[str],
        clear_screen_on_query: bool = True,
    ):
        """Builds the interactive policy.

        Args:
            observation_space: Observation space, forwarded to the base class.
            action_space: Action space; must be a ``gym.spaces.Discrete``.
            action_names: Display name of each action; position ``i`` names
                action index ``i``.
            action_keys: Key the user types to select each action, parallel
                to ``action_names``.
            clear_screen_on_query: If True, ``util.clear_screen()`` is called
                before every query.

        Raises:
            AssertionError: If ``action_space`` is not discrete, or if the
                names/keys are not unique or do not number ``action_space.n``.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
        )

        assert isinstance(action_space, gym.spaces.Discrete)
        assert len(action_names) == len(action_keys) == action_space.n
        # Names and keys should be unique.
        assert len(set(action_names)) == len(set(action_keys)) == action_space.n

        self.action_names = action_names
        self.action_keys = action_keys
        self.action_key_to_index = {k: i for i, k in enumerate(action_keys)}
        self.clear_screen_on_query = clear_screen_on_query

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        """Shows ``obs`` to the user and returns the action they pick.

        Args:
            obs: The observation to render.

        Returns:
            A one-element array holding the index of the selected action.
        """
        if self.clear_screen_on_query:
            util.clear_screen()

        context = self._render(obs)
        try:
            key = self._get_input_key()
        finally:
            # Always release the rendering context, even when the prompt is
            # aborted (e.g. KeyboardInterrupt or EOF on stdin).
            self._clean_up(context)

        return np.array([self.action_key_to_index[key]])

    def _get_input_key(self) -> str:
        """Prompts on stdin until a valid action key is entered; returns it."""
        choices = ", ".join(
            f"{n}:{k}" for n, k in zip(self.action_names, self.action_keys)
        )
        print(
            "Please select an action. Possible choices in [ACTION_NAME:KEY] format:",
            choices,
        )

        key = input("Your choice (enter key):")
        while key not in self.action_keys:
            key = input("Invalid key, please try again! Your choice (enter key):")

        return key

    @abc.abstractmethod
    def _render(self, obs: np.ndarray) -> Optional[object]:
        """Renders an observation.

        Optionally returns a context object that is later passed to
        ``_clean_up``.
        """

    def _clean_up(self, context: object) -> None:
        """Cleans up after the input has been captured.

        For example, stops showing the image. The default does nothing.
        """
|
||
|
||
class ImageObsDiscreteInteractivePolicy(DiscreteInteractivePolicy):
    """Interactive policy that displays each observation as a matplotlib image."""

    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
        """Returns the array passed to ``imshow``; here the observation itself."""
        return obs

    def _render(self, obs: np.ndarray) -> plt.Figure:
        """Draws the observation in a new axis-free figure and returns the figure."""
        image = self._prepare_obs_image(obs)

        figure, axes = plt.subplots()
        axes.imshow(image)
        axes.axis("off")
        figure.show()

        return figure

    def _clean_up(self, context: plt.Figure) -> None:
        """Closes the figure produced by ``_render``."""
        plt.close(context)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""Tests interactive policies.""" | ||
import random | ||
from unittest.mock import patch | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
from stable_baselines3.common import vec_env | ||
|
||
from imitation.policies import interactive | ||
|
||
# Environment IDs the interactive-policy test is parametrized over.
ENVS = ["Pong-v4"]
|
||
|
||
class NoRenderingDiscreteInteractivePolicy(interactive.DiscreteInteractivePolicy):
    """Interactive policy whose rendering step does nothing."""

    def _render(self, obs: np.ndarray) -> None:
        """Displays nothing and returns no cleanup context."""
        return None
|
||
|
||
@pytest.mark.parametrize("env_name", ENVS)
def test_interactive_policy(env_name: str):
    """Runs one short episode driven by mocked keyboard input.

    Checks that the policy's predictions are valid actions and that stepping
    the environment with them yields well-formed outputs.
    """
    env = vec_env.DummyVecEnv(
        [lambda: gym.wrappers.TimeLimit(gym.make(env_name), 10)],
    )
    env.seed(0)

    try:
        num_actions = env.envs[0].action_space.n
        action_names = [f"n{i}" for i in range(num_actions)]
        action_keys = [f"k{i}" for i in range(num_actions)]
        interactive_policy = NoRenderingDiscreteInteractivePolicy(
            env.observation_space,
            env.action_space,
            action_names,
            action_keys,
        )

        obs = env.reset()
        done = np.array([False])

        # Seeded local generator keeps the simulated key presses reproducible
        # without touching the global ``random`` state.
        key_rng = random.Random(0)

        def mock_input_valid(_prompt):
            return key_rng.choice(action_keys)

        with patch("builtins.input", mock_input_valid):
            while not done.all():
                action, _ = interactive_policy.predict(obs)
                assert isinstance(action, np.ndarray)
                assert all(env.action_space.contains(a) for a in action)

                obs, reward, done, info = env.step(action)
                assert isinstance(obs, np.ndarray)
                assert all(env.observation_space.contains(o) for o in obs)
                assert isinstance(reward, np.ndarray)
                assert isinstance(done, np.ndarray)
    finally:
        # Release the emulator even if an assertion above fails.
        env.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters