Skip to content

Commit

Permalink
Adapt train_dagger_with_human_demos.py to changed paths
Browse files (browse the repository at this point in the history)
  • Loading branch information
jas-ho committed Aug 10, 2023
1 parent 5affa39 commit 077bb43
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
4 changes: 2 additions & 2 deletions examples/train_dagger_with_human_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from imitation.algorithms import bc
from imitation.algorithms.dagger import SimpleDAggerTrainer
from imitation.policies.interactive import InteractivePolicy
from imitation.policies import interactive_text
from imitation.util.util import make_vec_env

# todo: also test with gym.env
Expand All @@ -25,7 +25,7 @@
is_slippery=True,
),
)
expert = InteractivePolicy(env)
expert = interactive_text.TextInteractivePolicy(env)


bc_trainer = bc.BC(
Expand Down
26 changes: 9 additions & 17 deletions tests/policies/test_interactive.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Tests InteractivePolicy."""
import random
import threading
from functools import partial
from unittest.mock import patch

import numpy as np
import pytest
from stable_baselines3.common.vec_env import VecEnv

from imitation.policies.interactive import InteractivePolicy, query_human
from imitation.policies import interactive_text
from imitation.util.util import make_vec_env

_make_vec_env = partial(make_vec_env, n_envs=1, rng=np.random.default_rng(42))
Expand All @@ -18,21 +19,9 @@
]


def get_interactive_policy(env: VecEnv):
    """Build an ``InteractivePolicy`` whose actions come from ``query_human``.

    Args:
        env: Vectorized environment supplying the observation and action
            spaces; it is also rendered before every query.

    Returns:
        An ``InteractivePolicy`` wired to the environment's spaces and to a
        query function that renders ``env`` and then calls ``query_human``.
    """

    def _render_then_ask(obs):
        # The observation itself is not used: the environment is rendered
        # and the action is whatever ``query_human`` returns.
        env.render()
        chosen_action = query_human()
        return chosen_action

    policy = InteractivePolicy(
        observation_space=env.observation_space,
        action_space=env.action_space,
        query_fn=_render_then_ask,
    )
    return policy


@pytest.mark.parametrize("env", ENVS)
def test_interactive_policy_valid(env: VecEnv):
interactive_policy = get_interactive_policy(env)
interactive_policy = interactive_text.TextInteractivePolicy(env)
obs = env.reset()
done = np.array([False])

Expand All @@ -54,12 +43,15 @@ def mock_input_valid(_):

@pytest.mark.parametrize("env", ENVS)
def test_interactive_policy_invalid(capsys, env: VecEnv):
    """Check that invalid keyboard input is reported as ``"Invalid input."``.

    ``builtins.input`` is patched to return a random key from ``x``/``y``/``z``
    and ``predict`` is run in a background thread that is joined with a short
    timeout; the captured stdout must then contain ``"Invalid input."``.

    Args:
        capsys: pytest fixture used to read what was printed to stdout.
        env: Vectorized environment the policy is built for.
    """
    interactive_policy = interactive_text.TextInteractivePolicy(env)
    obs = env.reset()

    def mock_input_invalid(_):
        return random.choice(["x", "y", "z"])

    with patch("builtins.input", mock_input_invalid):
        # NOTE(review): presumably ``predict`` keeps re-prompting while the
        # input is invalid, which is why it runs in a thread with a bounded
        # join instead of being called directly -- confirm against
        # ``TextInteractivePolicy``.
        # ``daemon=True`` ensures a still-running ``predict`` call cannot
        # keep the interpreter alive once the test has finished.
        test_thread = threading.Thread(
            target=interactive_policy.predict,
            args=(obs,),
            daemon=True,
        )
        test_thread.start()
        test_thread.join(timeout=0.1)
        captured = capsys.readouterr()
        assert "Invalid input." in captured.out

0 comments on commit 077bb43

Please sign in to comment.