diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 115de314f..29f12588d 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -9,6 +9,7 @@ from typing import ( Any, Callable, + Dict, Iterable, Iterator, Mapping, @@ -22,7 +23,7 @@ import numpy as np import torch as th import tqdm -from stable_baselines3.common import policies, utils, vec_env +from stable_baselines3.common import policies, torch_layers, utils, vec_env from imitation.algorithms import base as algo_base from imitation.data import rollout, types @@ -99,7 +100,12 @@ class BehaviorCloningLossCalculator: def __call__( self, policy: policies.ActorCriticPolicy, - obs: Union[th.Tensor, np.ndarray], + obs: Union[ + types.AnyTensor, + types.DictObs, + Dict[str, np.ndarray], + Dict[str, th.Tensor], + ], acts: Union[th.Tensor, np.ndarray], ) -> BCTrainingMetrics: """Calculate the supervised learning loss used to train the behavioral clone. @@ -113,9 +119,18 @@ def __call__( A BCTrainingMetrics object with the loss and all the components it consists of. """ - obs = util.safe_to_tensor(obs) + tensor_obs = types.map_maybe_dict( + util.safe_to_tensor, + types.maybe_unwrap_dictobs(obs), + ) acts = util.safe_to_tensor(acts) - _, log_prob, entropy = policy.evaluate_actions(obs, acts) + + # policy.evaluate_actions's type signatures are incorrect. + # See https://github.com/DLR-RM/stable-baselines3/issues/1679 + (_, log_prob, entropy) = policy.evaluate_actions( + tensor_obs, # type: ignore[arg-type] + acts, + ) prob_true_act = th.exp(log_prob).mean() log_prob = log_prob.mean() entropy = entropy.mean() if entropy is not None else None @@ -324,12 +339,18 @@ def __init__( self.rng = rng if policy is None: + extractor = ( + torch_layers.CombinedExtractor + if isinstance(observation_space, gym.spaces.Dict) + else torch_layers.FlattenExtractor + ) policy = policy_base.FeedForward32Policy( observation_space=observation_space, action_space=action_space, # Set lr_schedule to max value to force error if policy.optimizer # is used by mistake (should use self.optimizer instead). lr_schedule=lambda _: th.finfo(th.float32).max, + features_extractor_class=extractor, ) self._policy = policy.to(utils.get_device(device)) # TODO(adam): make policy mandatory and delete observation/action space params? @@ -464,9 +485,14 @@ def process_batch(): minibatch_size, num_samples_so_far, ), batch in batches_with_stats: - obs = th.as_tensor(batch["obs"], device=self.policy.device).detach() - acts = th.as_tensor(batch["acts"], device=self.policy.device).detach() - training_metrics = self.loss_calculator(self.policy, obs, acts) + obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]] + # unwraps the observation if it's a dictobs and converts arrays to tensors + obs_tensor = types.map_maybe_dict( + lambda x: util.safe_to_tensor(x, device=self.policy.device), + types.maybe_unwrap_dictobs(batch["obs"]), + ) + acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) + training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) # Renormalise the loss to be averaged over the whole # batch size instead of the minibatch size. 
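Note on the bc.py hunks above: they convert a (possibly dict-valued) observation batch to tensors via types.maybe_unwrap_dictobs and types.map_maybe_dict before calling policy.evaluate_actions. A minimal sketch of that conversion path follows; only DictObs, map_maybe_dict, maybe_unwrap_dictobs, and safe_to_tensor come from this PR, while the key names, shapes, and batch size are made up for illustration.

    import numpy as np
    import torch as th

    from imitation.data import types
    from imitation.util import util

    # Hypothetical minibatch of dict observations (key names and shapes are
    # illustrative only, not taken from this PR).
    batch_obs = types.DictObs(
        {
            "position": np.zeros((32, 3), dtype=np.float32),
            "velocity": np.zeros((32, 2), dtype=np.float32),
        },
    )

    # maybe_unwrap_dictobs returns the underlying dict for a DictObs and passes
    # plain arrays through unchanged; map_maybe_dict then applies the conversion
    # per key, mirroring BehaviorCloningLossCalculator.__call__ above.
    tensor_obs = types.map_maybe_dict(
        util.safe_to_tensor,
        types.maybe_unwrap_dictobs(batch_obs),
    )
    assert isinstance(tensor_obs, dict)
    assert all(isinstance(v, th.Tensor) for v in tensor_obs.values())

    # An ordinary array observation batch takes the same code path and simply
    # becomes a single tensor.
    array_obs = np.zeros((32, 4), dtype=np.float32)
    converted = types.map_maybe_dict(
        util.safe_to_tensor,
        types.maybe_unwrap_dictobs(array_obs),
    )
    assert isinstance(converted, th.Tensor)
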
diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index 88dd962e5..377467ce9 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -134,9 +134,9 @@ def __init__( def _get_demo_from_batch( self, - obs_b: np.ndarray, + obs_b: types.Observation, act_b: np.ndarray, - next_obs_b: Optional[np.ndarray], + next_obs_b: Optional[types.Observation], ) -> Dict[Optional[int], List[np.ndarray]]: if next_obs_b is None and self.density_type == DensityType.STATE_STATE_DENSITY: raise ValueError( @@ -145,11 +145,18 @@ def _get_demo_from_batch( ) assert act_b.shape[1:] == self.venv.action_space.shape - assert obs_b.shape[1:] == self.venv.observation_space.shape + ob_space = self.venv.observation_space + if isinstance(obs_b, types.DictObs): + exp_shape = { + k: v.shape for k, v in ob_space.items() # type: ignore[attr-defined] + } + obs_shape = {k: v.shape[1:] for k, v in obs_b.items()} + assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}" + else: + assert obs_b.shape[1:] == ob_space.shape assert len(act_b) == len(obs_b) if next_obs_b is not None: - assert next_obs_b.shape[1:] == self.venv.observation_space.shape - assert len(next_obs_b) == len(obs_b) + assert next_obs_b.shape == obs_b.shape if next_obs_b is not None: next_obs_b_iterator: Iterable = next_obs_b @@ -200,14 +207,17 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: # analogous to cast above. demonstrations = cast(Iterable[types.TransitionMapping], demonstrations) + def to_np_maybe_dictobs(x): + if isinstance(x, types.DictObs): + return x + else: + return util.safe_to_numpy(x, warn=True) + for batch in demonstrations: - transitions.update( - self._get_demo_from_batch( - util.safe_to_numpy(batch["obs"], warn=True), - util.safe_to_numpy(batch["acts"], warn=True), - util.safe_to_numpy(batch.get("next_obs"), warn=True), - ), - ) + obs = to_np_maybe_dictobs(batch["obs"]) + acts = util.safe_to_numpy(batch["acts"], warn=True) + next_obs = to_np_maybe_dictobs(batch.get("next_obs")) + transitions.update(self._get_demo_from_batch(obs, acts, next_obs)) else: raise TypeError( f"Unsupported demonstration type {type(demonstrations)}", @@ -253,65 +263,40 @@ def _fit_density(self, transitions: np.ndarray) -> neighbors.KernelDensity: def _preprocess_transition( self, - obs: np.ndarray, + obs: types.Observation, act: np.ndarray, - next_obs: Optional[np.ndarray], + next_obs: Optional[types.Observation], ) -> np.ndarray: """Compute flattened transition on subset specified by `self.density_type`.""" + flattened_obs = space_utils.flatten( + self.venv.observation_space, + types.maybe_unwrap_dictobs(obs), + ) + flattened_obs = _check_data_is_np_array(flattened_obs, "observation") if self.density_type == DensityType.STATE_DENSITY: - flat_observations = space_utils.flatten(self.venv.observation_space, obs) - if not isinstance(flat_observations, np.ndarray): - raise ValueError( - "The density estimator only supports spaces that " - "flatten to a numpy array but the observation space " - f"flattens to {type(flat_observations)}", - ) - - return flat_observations + return flattened_obs elif self.density_type == DensityType.STATE_ACTION_DENSITY: - flat_observation = space_utils.flatten(self.venv.observation_space, obs) - flat_action = space_utils.flatten(self.venv.action_space, act) - - if not isinstance(flat_observation, np.ndarray): - raise ValueError( - "The density estimator only supports spaces that " - "flatten to a numpy array but the observation 
space " - f"flattens to {type(flat_observation)}", - ) - if not isinstance(flat_action, np.ndarray): - raise ValueError( - "The density estimator only supports spaces that " - "flatten to a numpy array but the action space " - f"flattens to {type(flat_action)}", - ) - - return np.concatenate([flat_observation, flat_action]) + flattened_action = space_utils.flatten(self.venv.action_space, act) + flattened_action = _check_data_is_np_array(flattened_action, "action") + return np.concatenate([flattened_obs, flattened_action]) elif self.density_type == DensityType.STATE_STATE_DENSITY: assert next_obs is not None - flat_observation = space_utils.flatten(self.venv.observation_space, obs) - flat_next_observation = space_utils.flatten( + flat_next_obs = space_utils.flatten( self.venv.observation_space, - next_obs, + types.maybe_unwrap_dictobs(next_obs), ) + flat_next_obs = _check_data_is_np_array(flat_next_obs, "observation") + assert type(flattened_obs) is type(flat_next_obs) - if not isinstance(flat_observation, np.ndarray): - raise ValueError( - "The density estimator only supports spaces that " - "flatten to a numpy array but the observation space " - f"flattens to {type(flat_observation)}", - ) - - assert type(flat_observation) is type(flat_next_observation) - - return np.concatenate([flat_observation, flat_next_observation]) + return np.concatenate([flattened_obs, flat_next_obs]) else: raise ValueError(f"Unknown density type {self.density_type}") def __call__( self, - state: np.ndarray, + state: types.Observation, action: np.ndarray, - next_state: np.ndarray, + next_state: types.Observation, done: np.ndarray, steps: Optional[np.ndarray] = None, ) -> np.ndarray: @@ -347,6 +332,8 @@ def __call__( rew_list = [] assert len(state) == len(action) and len(state) == len(next_state) + state = types.maybe_wrap_in_dictobs(state) + next_state = types.maybe_wrap_in_dictobs(next_state) for idx, (obs, act, next_obs) in enumerate(zip(state, action, next_state)): flat_trans = self._preprocess_transition(obs, act, next_obs) assert self._scaler is not None @@ -424,3 +411,13 @@ def policy(self) -> base_class.BasePolicy: assert self.rl_algo is not None assert self.rl_algo.policy is not None return self.rl_algo.policy + + +def _check_data_is_np_array(data: space_utils.FlatType, name: str) -> np.ndarray: + """Raises error if the flattened data is not a numpy array.""" + assert isinstance(data, np.ndarray), ( + "The density estimator only supports spaces that " + f"flatten to a numpy array but the {name} space " + f"flattens to {type(data)}", + ) + return data diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index 7a25f86c4..21ebd8642 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -7,7 +7,18 @@ """ import collections import warnings -from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import gymnasium as gym import numpy as np @@ -347,7 +358,7 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None num_demos = 0 for traj in trajs: cum_discount = 1.0 - for obs in traj.obs: + for obs in types.assert_not_dictobs(traj.obs): self.demo_state_om[obs] += cum_discount cum_discount *= self.discount num_demos += 1 @@ -411,23 +422,32 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: if isinstance(demonstrations, types.Transitions): 
self._set_demo_from_obs( - demonstrations.obs, + types.assert_not_dictobs(demonstrations.obs), demonstrations.dones, - demonstrations.next_obs, + types.assert_not_dictobs(demonstrations.next_obs), ) elif isinstance(demonstrations, types.TransitionsMinimal): - self._set_demo_from_obs(demonstrations.obs, None, None) + self._set_demo_from_obs( + types.assert_not_dictobs(demonstrations.obs), + None, + None, + ) elif isinstance(demonstrations, Iterable): # Demonstrations are a Torch DataLoader or other Mapping iterable # Collect them together into one big NumPy array. This is inefficient, # we could compute the running statistics instead, but in practice do # not expect large dataset sizes together with MCE IRL. - collated_list = collections.defaultdict(list) + collated_list: Dict[ + str, + List[types.AnyTensor], + ] = collections.defaultdict(list) for batch in demonstrations: assert isinstance(batch, Mapping) for k in ("obs", "dones", "next_obs"): - if k in batch: - collated_list[k].append(batch[k]) + x = batch.get(k) + if x is not None: + assert isinstance(x, (np.ndarray, th.Tensor)) + collated_list[k].append(x) collated = {k: np.concatenate(v) for k, v in collated_list.items()} assert "obs" in collated diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 413cd979a..14a8fad5b 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -465,9 +465,9 @@ def rewards(self, transitions: Transitions) -> th.Tensor: Shape - (num_transitions, ) for Single reward network and (num_transitions, num_networks) for ensemble of networks. """ - state = transitions.obs + state = types.assert_not_dictobs(transitions.obs) action = transitions.acts - next_state = transitions.next_obs + next_state = types.assert_not_dictobs(transitions.next_obs) done = transitions.dones if self.ensemble_model is not None: rews_np = self.ensemble_model.predict_processed_all( diff --git a/src/imitation/data/buffer.py b/src/imitation/data/buffer.py index 336a46252..f09ca50ca 100644 --- a/src/imitation/data/buffer.py +++ b/src/imitation/data/buffer.py @@ -1,6 +1,5 @@ """Buffers to store NumPy arrays and transitions in.""" -import dataclasses from typing import Any, Mapping, Optional, Tuple import numpy as np @@ -368,15 +367,16 @@ def from_data( Returns: A new ReplayBuffer. """ - obs_shape = transitions.obs.shape[1:] + obs = types.assert_not_dictobs(transitions.obs) + obs_shape = obs.shape[1:] act_shape = transitions.acts.shape[1:] if capacity is None: - capacity = transitions.obs.shape[0] + capacity = obs.shape[0] instance = cls( capacity=capacity, obs_shape=obs_shape, act_shape=act_shape, - obs_dtype=transitions.obs.dtype, + obs_dtype=obs.dtype, act_dtype=transitions.acts.dtype, ) instance.store(transitions, truncate_ok=truncate_ok) @@ -406,7 +406,7 @@ def store(self, transitions: types.Transitions, truncate_ok: bool = True) -> Non Raises: ValueError: The arguments didn't have the same length. 
""" # noqa: DAR402 - trans_dict = dataclasses.asdict(transitions) + trans_dict = types.dataclass_quick_asdict(transitions) # Remove unnecessary fields trans_dict = {k: trans_dict[k] for k in self._buffer.sample_shapes.keys()} self._buffer.store(trans_dict, truncate_ok=truncate_ok) diff --git a/src/imitation/data/huggingface_utils.py b/src/imitation/data/huggingface_utils.py index af6a45dae..ef3ae7d4c 100644 --- a/src/imitation/data/huggingface_utils.py +++ b/src/imitation/data/huggingface_utils.py @@ -124,6 +124,8 @@ def trajectories_to_dict( ], terminal=[traj.terminal for traj in trajectories], ) + if any(isinstance(traj.obs, types.DictObs) for traj in trajectories): + raise ValueError("DictObs are not currently supported") # Encode infos as jsonpickled strings trajectory_dict["infos"] = [ diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index e560f7cbc..78007630e 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -18,6 +18,7 @@ ) import numpy as np +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.utils import check_for_correct_spaces @@ -69,7 +70,7 @@ def __init__(self): def add_step( self, - step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any]]], + step_dict: Mapping[str, Union[types.Observation, Mapping[str, Any]]], key: Hashable = None, ) -> None: """Add a single step to the partial trajectory identified by `key`. @@ -107,17 +108,19 @@ def finish_trajectory( for part_dict in part_dicts: for k, array in part_dict.items(): out_dict_unstacked[k].append(array) + out_dict_stacked = { - k: np.stack(arr_list, axis=0) for k, arr_list in out_dict_unstacked.items() + k: types.stack_maybe_dictobs(arr_list) + for k, arr_list in out_dict_unstacked.items() } traj = types.TrajectoryWithRew(**out_dict_stacked, terminal=terminal) - assert traj.rews.shape[0] == traj.acts.shape[0] == traj.obs.shape[0] - 1 + assert traj.rews.shape[0] == traj.acts.shape[0] == len(traj.obs) - 1 return traj def add_steps_and_auto_finish( self, acts: np.ndarray, - obs: np.ndarray, + obs: Union[types.Observation, Dict[str, np.ndarray]], rews: np.ndarray, dones: np.ndarray, infos: List[dict], @@ -142,20 +145,24 @@ def add_steps_and_auto_finish( each `True` in the `dones` argument. """ trajs: List[types.TrajectoryWithRew] = [] - for env_idx in range(len(obs)): + wrapped_obs = types.maybe_wrap_in_dictobs(obs) + + # iterate through environments + for env_idx in range(len(wrapped_obs)): assert env_idx in self.partial_trajectories assert list(self.partial_trajectories[env_idx][0].keys()) == ["obs"], ( "Need to first initialize partial trajectory using " "self._traj_accum.add_step({'obs': ob}, key=env_idx)" ) - zip_iter = enumerate(zip(acts, obs, rews, dones, infos)) + # iterate through steps + zip_iter = enumerate(zip(acts, wrapped_obs, rews, dones, infos)) for env_idx, (act, ob, rew, done, info) in zip_iter: if done: # When dones[i] from VecEnv.step() is True, obs[i] is the first # observation following reset() of the ith VecEnv, and # infos[i]["terminal_observation"] is the actual final observation. - real_ob = info["terminal_observation"] + real_ob = types.maybe_wrap_in_dictobs(info["terminal_observation"]) else: real_ob = ob @@ -268,8 +275,12 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool: # array of states, and an optional array of episode starts and returns an array of # corresponding actions. 
PolicyCallable = Callable[ - [np.ndarray, Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray]], - Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]], + [ + Union[np.ndarray, Dict[str, np.ndarray]], # observations + Optional[Tuple[np.ndarray, ...]], # states + Optional[np.ndarray], # episode_starts + ], + Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]], # actions, states ] AnyPolicy = Union[BaseAlgorithm, BasePolicy, PolicyCallable, None] @@ -284,7 +295,7 @@ def policy_to_callable( if policy is None: def get_actions( - observations: np.ndarray, + observations: Union[np.ndarray, Dict[str, np.ndarray]], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -298,7 +309,7 @@ def get_actions( # (which would call .forward()). So this elif clause must come first! def get_actions( - observations: np.ndarray, + observations: Union[np.ndarray, Dict[str, np.ndarray]], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -397,17 +408,7 @@ def generate_trajectories( Sequence of trajectories, satisfying `sample_until`. Additional trajectories may be collected to avoid biasing process towards short episodes; the user should truncate if required. - - Raises: - ValueError: If the observation or action space has no shape or the observations - are not a numpy array. """ - if venv.observation_space.shape is None: - raise ValueError("Observation space must have a shape.") - - if venv.action_space.shape is None: - raise ValueError("Action space must have a shape.") - get_actions = policy_to_callable(policy, venv, deterministic_policy) # Collect rollout tuples. @@ -415,7 +416,14 @@ def generate_trajectories( # accumulator for incomplete trajectories trajectories_accum = TrajectoryAccumulator() obs = venv.reset() - for env_idx, ob in enumerate(obs): + assert isinstance( + obs, + (np.ndarray, dict), + ), "Tuple observations are not supported." + wrapped_obs = types.maybe_wrap_in_dictobs(obs) + + # we use dictobs to iterate over the envs in a vecenv + for env_idx, ob in enumerate(wrapped_obs): # Seed with first obs only. Inside loop, we'll only add second obs from # each (s,a,r,s') tuple, under the same "obs" key again. That way we still # get all observations, but they're not duplicated into "next obs" and @@ -431,18 +439,17 @@ def generate_trajectories( # # To start with, all environments are active. active = np.ones(venv.num_envs, dtype=bool) - if not isinstance(obs, np.ndarray): - raise ValueError( - "Dict/tuple observations are not supported." - "Currently only np.ndarray observations are supported.", - ) - state = None dones = np.zeros(venv.num_envs, dtype=bool) while np.any(active): + # policy gets unwrapped observations (eg as dict, not dictobs) acts, state = get_actions(obs, state, dones) obs, rews, dones, infos = venv.step(acts) - assert isinstance(obs, np.ndarray) + assert isinstance( + obs, + (np.ndarray, dict), + ), "Tuple observations are not supported." + wrapped_obs = types.maybe_wrap_in_dictobs(obs) # If an environment is inactive, i.e. 
the episode completed for that # environment after `sample_until(trajectories)` was true, then we do @@ -452,7 +459,7 @@ def generate_trajectories( new_trajs = trajectories_accum.add_steps_and_auto_finish( acts, - obs, + wrapped_obs, rews, dones, infos, @@ -477,9 +484,18 @@ def generate_trajectories( for trajectory in trajectories: n_steps = len(trajectory.acts) # extra 1 for the end - exp_obs = (n_steps + 1,) + venv.observation_space.shape + if isinstance(venv.observation_space, spaces.Dict): + exp_obs = {} + for k, v in venv.observation_space.items(): + assert v.shape is not None + exp_obs[k] = (n_steps + 1,) + v.shape + else: + obs_space_shape = venv.observation_space.shape + assert obs_space_shape is not None + exp_obs = (n_steps + 1,) + obs_space_shape # type: ignore[assignment] real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" + assert venv.action_space.shape is not None exp_act = (n_steps,) + venv.action_space.shape real_act = trajectory.acts.shape assert real_act == exp_act, f"expected shape {exp_act}, got {real_act}" @@ -555,8 +571,19 @@ def flatten_trajectories( Returns: The trajectories flattened into a single batch of Transitions. """ + + def all_of_type(key, desired_type): + return all( + isinstance(getattr(traj, key), desired_type) for traj in trajectories + ) + + assert all_of_type("obs", types.DictObs) or all_of_type("obs", np.ndarray) + assert all_of_type("acts", np.ndarray) + + # mypy struggles without Any annotation here. + # The necessary constraints are enforced above. keys = ["obs", "next_obs", "acts", "dones", "infos"] - parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys} + parts: Mapping[str, List[Any]] = {key: [] for key in keys} for traj in trajectories: parts["acts"].append(traj.acts) @@ -575,7 +602,8 @@ def flatten_trajectories( parts["infos"].append(infos) cat_parts = { - key: np.concatenate(part_list, axis=0) for key, part_list in parts.items() + key: types.concatenate_maybe_dictobs(part_list) + for key, part_list in parts.items() } lengths = set(map(len, cat_parts.values())) assert len(lengths) == 1, f"expected one length, got {lengths}" @@ -587,7 +615,10 @@ def flatten_trajectories_with_rew( ) -> types.TransitionsWithRew: transitions = flatten_trajectories(trajectories) rews = np.concatenate([traj.rews for traj in trajectories]) - return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews) + return types.TransitionsWithRew( + **types.dataclass_quick_asdict(transitions), + rews=rews, + ) def generate_transitions( @@ -628,7 +659,7 @@ def generate_transitions( ) transitions = flatten_trajectories_with_rew(traj) if truncate and n_timesteps is not None: - as_dict = dataclasses.asdict(transitions) + as_dict = types.dataclass_quick_asdict(transitions) truncated = {k: arr[:n_timesteps] for k, arr in as_dict.items()} transitions = types.TransitionsWithRew(**truncated) return transitions diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 97d1b950b..573176ffe 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -1,15 +1,23 @@ """Types and helper methods for transitions and trajectories.""" +import collections import dataclasses +import itertools +import numbers import os import warnings from typing import ( Any, + Callable, Dict, + Iterable, + Iterator, + List, Mapping, Optional, Sequence, Tuple, + TypedDict, TypeVar, Union, overload, @@ -22,7 +30,286 @@ T = TypeVar("T") AnyPath = Union[str, bytes, os.PathLike] -TransitionMapping = 
Mapping[str, Union[np.ndarray, th.Tensor]] +AnyTensor = Union[np.ndarray, th.Tensor] +TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor) + + +@dataclasses.dataclass(frozen=True) +class DictObs: + """Stores observations from an environment with a dictionary observation space. + + Provides an interface that is similar to observations in a numpy array. + Length, slicing, indexing, and iterating operations will operate on the first + dimension of the constituent arrays, as they would for observations in a single + array. + + There are also utility functions for mapping / stacking / concatenating + lists of dictobs. + """ + + _d: Dict[str, np.ndarray] + + @classmethod + def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]) -> "DictObs": + """Stacks the observation list into a single DictObs.""" + return cls.stack(map(cls, obs_list)) + + def __post_init__(self): + if not all( + isinstance(v, (np.ndarray, numbers.Number)) for v in self._d.values() + ): + raise TypeError("Values must be NumPy arrays") + + def __len__(self): + """Returns the first dimension of constituent arrays. + + Only defined if there is at least one array, and all arrays have the same + length of first dimension. Otherwise raises ValueError. + + Len of a DictObs usually represents number of timesteps, or number of + environments in a VecEnv. + + Use `dict_len` to get the number of entries in the dictionary. + + Raises: + RuntimeError: if the arrays have different lengths or there are no arrays. + + Returns: + The length (first dimension) of the constituent arrays + """ + lens = set(len(v) for v in self._d.values()) + if len(lens) == 1: + return lens.pop() + elif len(lens) == 0: + raise RuntimeError("Length not defined as DictObs is empty") + else: + raise RuntimeError( + f"Length not defined; arrays have conflicting first dimensions: {lens}", + ) + + @property + def dict_len(self): + """Returns the number of arrays in the DictObs.""" + return len(self._d) + + def __getitem__( + self, + key: Union[int, slice, Tuple[Union[int, slice], ...]], + ) -> "DictObs": + """Indexes or slices each array. + + See `.get` for accessing a value from the underlying dictionary. + + Note that it will still return singleton values as np.arrays, not scalars, + to be consistent with DictObs type signature. + + Args: + key: a single slice + + Returns: + A new DictObj object with each array indexed. + """ + # asarray handles case where we slice to a single array element. + return self.__class__({k: np.asarray(v[key]) for k, v in self._d.items()}) + + def __iter__(self) -> Iterator["DictObs"]: + """Iterates over the first dimension of each array. + + Raises: + ValueError if len() is not defined. + + Returns: + Iterator of dictobjs by first dimension. 
+ """ + return (self[i] for i in range(len(self))) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + if not self.keys() == other.keys(): + return False + return all(np.array_equal(self.get(k), other.get(k)) for k in self.keys()) + + @property + def shape(self) -> Dict[str, Tuple[int, ...]]: + """Returns a dictionary with shape-tuples in place of the arrays.""" + return {k: v.shape for k, v in self.items()} + + @property + def dtype(self) -> Dict[str, np.dtype]: + """Returns a dictionary with dtype-tuples in place of the arrays.""" + return {k: v.dtype for k, v in self.items()} + + def keys(self): + return self._d.keys() + + def values(self): + return self._d.values() + + def items(self): + return self._d.items() + + def __contains__(self, key): + return key in self._d + + def get(self, key: str) -> np.ndarray: + """Returns the array for the given key, or raises KeyError.""" + return self._d[key] + + def unwrap(self) -> Dict[str, np.ndarray]: + """Returns a copy of the underlying dictionary (arrays are not copied).""" + return dict(self._d) + + def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": + """Returns a new DictObs with `fn` applied to every array.""" + return self.__class__({k: fn(v) for k, v in self.items()}) + + @staticmethod + def _unravel(dictobs_list: Iterable["DictObs"]) -> Dict[str, List[np.ndarray]]: + """Converts a list of DictObs into a dictionary of lists of arrays.""" + it1, it2 = itertools.tee(dictobs_list) + # assert all have same keys + key_set = set(frozenset(obs.keys()) for obs in it1) + if len(key_set) == 0: + raise ValueError("Empty list of DictObs") + if not len(key_set) == 1: + raise ValueError(f"Inconsistent keys: {key_set}") + + unraveled: Dict[str, List[np.ndarray]] = collections.defaultdict(list) + for ob_dict in it2: + for k, array in ob_dict._d.items(): + unraveled[k].append(array) + return unraveled + + @classmethod + def stack(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": + """Returns a single dictobs stacking the arrays by key.""" + return cls( + { + k: np.stack(arr_list, axis=axis) + for k, arr_list in cls._unravel(dictobs_list).items() + }, + ) + + @classmethod + def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": + """Returns a single dictobs concatenating the arrays by key.""" + return cls( + { + k: np.concatenate(arr_list, axis=axis) + for k, arr_list in cls._unravel(dictobs_list).items() + }, + ) + + +# DictObs utilities + + +Observation = Union[np.ndarray, DictObs] +ObsVar = TypeVar("ObsVar", np.ndarray, DictObs) + + +def assert_not_dictobs(x: Observation) -> np.ndarray: + """Typeguard to assert `x` is an array, not a DictObs.""" + assert not isinstance(x, DictObs), "Dictionary observations are not supported here." + return x + + +def concatenate_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + """Concatenates a list of observations appropriately (depending on type).""" + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.concatenate(arrs) + else: + return np.concatenate(arrs) + + +def stack_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + """Stacks a list of observations appropriately (depending on type).""" + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.stack(arrs) + else: + return np.stack(arrs) + + +# the following overloads have a type error as a DictObs matches both definitions, but +# the return types are incompatible. Ideally T would exclude DictObs but that's not +# possible. 
+@overload +def maybe_unwrap_dictobs( # type: ignore[misc] + maybe_dictobs: DictObs, +) -> Dict[str, np.ndarray]: + ... + + +@overload +def maybe_unwrap_dictobs(maybe_dictobs: T) -> T: + ... + + +def maybe_unwrap_dictobs(maybe_dictobs): + """Unwraps if a DictObs, otherwise returns the object.""" + if isinstance(maybe_dictobs, DictObs): + return maybe_dictobs.unwrap() + else: + if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor, int)): + warnings.warn(f"trying to unwrap object of type {type(maybe_dictobs)}") + return maybe_dictobs + + +@overload +def maybe_wrap_in_dictobs(obs: Union[Dict[str, np.ndarray], DictObs]) -> DictObs: + ... + + +@overload +def maybe_wrap_in_dictobs(obs: np.ndarray) -> np.ndarray: + ... + + +def maybe_wrap_in_dictobs( + obs: Union[Dict[str, np.ndarray], np.ndarray, DictObs], +) -> Observation: + """Converts an observation into a DictObs, if necessary.""" + if isinstance(obs, dict): + return DictObs(obs) + else: + if not isinstance(obs, (np.ndarray, DictObs, float, int)): + warnings.warn(f"tried to wrap {type(obs)} as an observation") + return obs + + +def map_maybe_dict(fn, maybe_dict): + """Either maps fn over dictionary values or applies fn to `maybe_dict`. + + Args: + fn: function to apply. Must take a single argument. + maybe_dict: either a dict or a value that can be passed to fn. + + Returns: + Either a dict (if maybe_dict was a dict) or `fn(maybe_dict)`. + """ + if isinstance(maybe_dict, dict): + return {k: fn(v) for k, v in maybe_dict.items()} + else: + return fn(maybe_dict) + + +class TransitionMappingNoNextObs(TypedDict): + """Dictionary with `obs` and `acts`.""" + + obs: Union[Observation, th.Tensor] + acts: AnyTensor + + +# inheritance with total=False so these are not required +class TransitionMapping(TransitionMappingNoNextObs, total=False): + """Dictionary with `obs` and `acts`, maybe also `next_obs`, `dones`, `rew`.""" + + next_obs: Union[Observation, th.Tensor] + dones: AnyTensor + rew: AnyTensor def dataclass_quick_asdict(obj) -> Dict[str, Any]: @@ -32,6 +319,9 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]: undocumentedly deep-copies every numpy array value. See https://stackoverflow.com/a/52229565/1091722. + This is also used to preserve DictObj objects, as `dataclasses.asdict` + unwraps them recursively. + Args: obj: A dataclass instance. @@ -46,7 +336,7 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]: class Trajectory: """A trajectory, e.g. 
a one episode rollout from an expert policy.""" - obs: np.ndarray + obs: Observation """Observations, shape (trajectory_len + 1, ) + observation_shape.""" acts: np.ndarray @@ -75,7 +365,8 @@ def __eq__(self, other) -> bool: if not isinstance(other, Trajectory): return False - dict_self, dict_other = dataclasses.asdict(self), dataclasses.asdict(other) + dict_self = dataclass_quick_asdict(self) + dict_other = dataclass_quick_asdict(other) # Trajectory objects may still have different keys if different subclasses if dict_self.keys() != dict_other.keys(): return False @@ -91,6 +382,9 @@ def __eq__(self, other) -> bool: # Treat None equivalent to sequence of empty dicts self_v = [{}] * len(self) if self_v is None else self_v other_v = [{}] * len(other) if other_v is None else other_v + if isinstance(self_v, DictObs): + if not self_v == other_v: + return False if not np.array_equal(self_v, other_v): return False @@ -152,7 +446,7 @@ def __post_init__(self): def transitions_collate_fn( batch: Sequence[Mapping[str, np.ndarray]], -) -> Mapping[str, Union[np.ndarray, th.Tensor]]: +) -> Mapping[str, AnyTensor]: """Custom `torch.utils.data.DataLoader` collate_fn for `TransitionsMinimal`. Use this as the `collate_fn` argument to `DataLoader` if using an instance of @@ -167,13 +461,16 @@ def transitions_collate_fn( list of dicts. (The default behavior would recursively collate every info dict into a single dict, which is incorrect.) """ - batch_no_infos = [ - {k: np.array(v) for k, v in sample.items() if k != "infos"} for sample in batch + batch_acts_and_dones = [ + {k: np.array(v) for k, v in sample.items() if k in ["acts", "dones"]} + for sample in batch ] - result = th_data.dataloader.default_collate(batch_no_infos) + result = th_data.dataloader.default_collate(batch_acts_and_dones) assert isinstance(result, dict) result["infos"] = [sample["infos"] for sample in batch] + result["obs"] = stack_maybe_dictobs([sample["obs"] for sample in batch]) + result["next_obs"] = stack_maybe_dictobs([sample["next_obs"] for sample in batch]) return result @@ -196,7 +493,7 @@ class TransitionsMinimal(th_data.Dataset, Sequence[Mapping[str, np.ndarray]]): field has been sliced. """ - obs: np.ndarray + obs: Observation """ Previous observations. Shape: (batch_size, ) + observation_shape. @@ -283,7 +580,7 @@ def __getitem__(self, key): class Transitions(TransitionsMinimal): """A batch of obs-act-obs-done transitions.""" - next_obs: np.ndarray + next_obs: Observation """New observation. Shape: (batch_size, ) + observation_shape. 
The i'th observation `next_obs[i]` in this array is the observation diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index d00ea14b4..94c88111d 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -53,9 +53,11 @@ def reset(self, **kwargs): self.n_transitions = 0 obs = self.venv.reset(**kwargs) self._traj_accum = rollout.TrajectoryAccumulator() + obs = types.maybe_wrap_in_dictobs(obs) for i, ob in enumerate(obs): self._traj_accum.add_step({"obs": ob}, key=i) self._timesteps = np.zeros((len(obs),), dtype=int) + obs = types.maybe_unwrap_dictobs(obs) return obs def step_async(self, actions): @@ -187,20 +189,20 @@ def __init__(self, env: gym.Env): def reset(self, **kwargs): new_obs, info = super().reset(**kwargs) - self._obs = [new_obs] + self._obs = [types.maybe_wrap_in_dictobs(new_obs)] self._rews = [] return new_obs, info def step(self, action): obs, rew, terminated, truncated, info = self.env.step(action) done = terminated or truncated - self._obs.append(obs) + self._obs.append(types.maybe_wrap_in_dictobs(obs)) self._rews.append(rew) if done: assert "rollout" not in info info["rollout"] = { - "obs": np.stack(self._obs), + "obs": types.stack_maybe_dictobs(self._obs), "rews": np.stack(self._rews), } return obs, rew, terminated, truncated, info diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 4a4c6e2ab..ba1f550df 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -1,7 +1,7 @@ """Custom policy classes and convenience methods.""" import abc -from typing import Type +from typing import Dict, Type, Union import gymnasium as gym import numpy as np @@ -10,6 +10,7 @@ from stable_baselines3.sac import policies as sac_policies from torch import nn +from imitation.data import types from imitation.util import networks @@ -23,18 +24,31 @@ def __init__(self, observation_space: gym.Space, action_space: gym.Space): action_space=action_space, ) - def _predict(self, obs: th.Tensor, deterministic: bool = False): + def _predict( + self, + obs: Union[th.Tensor, Dict[str, th.Tensor]], + deterministic: bool = False, + ): np_actions = [] - np_obs = obs.detach().cpu().numpy() + if isinstance(obs, dict): + np_obs = types.DictObs( + {k: v.detach().cpu().numpy() for k, v in obs.items()}, + ) + else: + np_obs = obs.detach().cpu().numpy() for np_ob in np_obs: - assert self.observation_space.contains(np_ob) - np_actions.append(self._choose_action(np_ob)) + np_ob_unwrapped = types.maybe_unwrap_dictobs(np_ob) + assert self.observation_space.contains(np_ob_unwrapped) + np_actions.append(self._choose_action(np_ob_unwrapped)) np_actions = np.stack(np_actions, axis=0) th_actions = th.as_tensor(np_actions, device=self.device) return th_actions @abc.abstractmethod - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: """Chooses an action, optionally based on observation obs.""" def forward(self, *args): @@ -46,7 +60,10 @@ def forward(self, *args): class RandomPolicy(NonTrainablePolicy): """Returns random actions.""" - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: return self.action_space.sample() @@ -65,7 +82,10 @@ def __init__(self, observation_space: gym.Space, action_space: gym.Space): f"Zero action {self._zero_action} not in action space {action_space}", ) - def _choose_action(self, obs: 
np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: return self._zero_action diff --git a/src/imitation/policies/exploration_wrapper.py b/src/imitation/policies/exploration_wrapper.py index 9151d5971..cde576466 100644 --- a/src/imitation/policies/exploration_wrapper.py +++ b/src/imitation/policies/exploration_wrapper.py @@ -1,6 +1,6 @@ """Wrapper to turn a policy into a more exploratory version.""" -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple, Union import numpy as np from stable_baselines3.common import vec_env @@ -57,7 +57,7 @@ def __init__( def _random_policy( self, - obs: np.ndarray, + obs: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -74,7 +74,7 @@ def _switch(self) -> None: def __call__( self, - observation: np.ndarray, + observation: Union[np.ndarray, Dict[str, np.ndarray]], input_state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py index 648a1874f..d9934f9a0 100644 --- a/src/imitation/policies/interactive.py +++ b/src/imitation/policies/interactive.py @@ -2,7 +2,7 @@ import abc import collections -from typing import Optional, Union +from typing import Dict, Optional, Union import gymnasium as gym import matplotlib.pyplot as plt @@ -57,10 +57,16 @@ def __init__( } self.clear_screen_on_query = clear_screen_on_query - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: if self.clear_screen_on_query: util.clear_screen() + if isinstance(obs, dict): + raise ValueError("Dictionary observations are not supported here") + context = self._render(obs) key = self._get_input_key() self._clean_up(context) diff --git a/src/imitation/rewards/reward_wrapper.py b/src/imitation/rewards/reward_wrapper.py index 7afa551b3..b7db34a5c 100644 --- a/src/imitation/rewards/reward_wrapper.py +++ b/src/imitation/rewards/reward_wrapper.py @@ -8,6 +8,7 @@ from stable_baselines3.common import logger as sb_logger from stable_baselines3.common import vec_env +from imitation.data import types from imitation.rewards import reward_function @@ -95,14 +96,23 @@ def step_wait(self): # encounter a `done`, in which case the last observation corresponding to # the `done` is dropped. We're going to pull it back out of the info dict! 
obs_fixed = [] + obs = types.maybe_wrap_in_dictobs(obs) for single_obs, single_done, single_infos in zip(obs, dones, infos): if single_done: single_obs = single_infos["terminal_observation"] - obs_fixed.append(single_obs) - obs_fixed = np.stack(obs_fixed) - - rews = self.reward_fn(self._old_obs, self._actions, obs_fixed, np.array(dones)) + obs_fixed.append(types.maybe_wrap_in_dictobs(single_obs)) + obs_fixed = ( + types.DictObs.stack(obs_fixed) + if isinstance(obs, types.DictObs) + else np.stack(obs_fixed) + ) + rews = self.reward_fn( + self._old_obs, + self._actions, + types.maybe_unwrap_dictobs(obs_fixed), + np.array(dones), + ) assert len(rews) == len(obs), "must return one rew for each env" done_mask = np.asarray(dones, dtype="bool").reshape((len(dones),)) @@ -116,6 +126,7 @@ def step_wait(self): # we can just use obs instead of obs_fixed because on the next iteration # after a reset we DO want to access the first observation of the new # trajectory, not the last observation of the old trajectory + obs = types.maybe_unwrap_dictobs(obs) self._old_obs = obs for info_dict, old_rew in zip(infos, old_rews): info_dict["original_env_rew"] = old_rew diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index b3d4589fc..a453f047d 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -1,9 +1,11 @@ """Fixtures common across algorithm tests.""" from typing import Sequence +import gymnasium as gym import pytest +from stable_baselines3.common import envs from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv from imitation.algorithms import bc from imitation.data.types import TrajectoryWithRew @@ -109,3 +111,22 @@ def pendulum_single_venv(rng) -> VecEnv: post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], rng=rng, ) + + +# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 +# merged and released. 
+class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + +@pytest.fixture +def multi_obs_venv() -> VecEnv: + def make_env(): + env = envs.SimpleMultiObsEnv(channel_last=False) + env = FloatReward(env) + return RolloutInfoWrapper(env) + + return DummyVecEnv([make_env, make_env]) diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py index d3609efaa..68d8e6208 100644 --- a/tests/algorithms/test_adversarial.py +++ b/tests/algorithms/test_adversarial.py @@ -393,9 +393,9 @@ def test_logits_expert_is_high_log_policy_act_prob( ) obs, acts, next_obs, dones = trainer_diverse_env.reward_train.preprocess( - trans.obs, + types.assert_not_dictobs(trans.obs), trans.acts, - trans.next_obs, + types.assert_not_dictobs(trans.next_obs), trans.dones, ) log_act_prob_non_none = np.log(0.1 + 0.9 * np.random.rand(n_timesteps)) diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 802ac0d7f..1654a1482 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -41,7 +41,10 @@ def test_check_fixed_horizon_flag(custom_logger): def _make_and_iterate_loader(*args, **kwargs): - loader = base.make_data_loader(*args, **kwargs) + # our pytype version doesn't understand optional arguments in TypedDict + # this is fixed in 2023.04.11, but we require 2022.7.26 + # See https://github.com/google/pytype/issues/1195 + loader = base.make_data_loader(*args, **kwargs) # pytype: disable=wrong-arg-types for batch in loader: pass diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 2adf12a22..8de49c66e 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -4,12 +4,15 @@ import os from typing import Any, Callable, Optional, Sequence +import gymnasium as gym import hypothesis import hypothesis.strategies as st import numpy as np import pytest import torch as th -from stable_baselines3.common import evaluation, vec_env +from stable_baselines3.common import evaluation +from stable_baselines3.common import policies as sb_policies +from stable_baselines3.common import vec_env from imitation.algorithms import bc from imitation.data import rollout, types @@ -287,6 +290,36 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) +def test_dict_space(multi_obs_venv: vec_env.VecEnv): + # multi-input policy to accept dict observations + assert isinstance(multi_obs_venv.observation_space, gym.spaces.Dict) + policy = sb_policies.MultiInputActorCriticPolicy( + multi_obs_venv.observation_space, + multi_obs_venv.action_space, + lambda _: 0.001, + ) + rng = np.random.default_rng() + + # sample random transitions + rollouts = rollout.rollout( + policy=None, + venv=multi_obs_venv, + sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + unwrap=True, + ) + transitions = rollout.flatten_trajectories(rollouts) + bc_trainer = bc.BC( + observation_space=multi_obs_venv.observation_space, + policy=policy, + action_space=multi_obs_venv.action_space, + rng=rng, + demonstrations=transitions, + ) + # confirm that training works + bc_trainer.train(n_epochs=1) + + ############################################# # ENSURE EXCEPTIONS ARE THROWN WHEN EXPECTED ############################################# diff --git a/tests/algorithms/test_dagger.py b/tests/algorithms/test_dagger.py index 92aaf2da2..01c1c5088 100644 --- a/tests/algorithms/test_dagger.py +++ 
b/tests/algorithms/test_dagger.py @@ -114,7 +114,7 @@ def get_random_acts(obs): for info in infos: assert isinstance(info, dict) # roll out 5 * venv.num_envs episodes (Pendulum-v1 has 200 timestep episodes) - for i in range(1000): + for _ in range(1000): _, _, dones, _ = collector.step(zero_acts) num_episodes += np.sum(dones) @@ -163,7 +163,7 @@ def test_traj_collector_reproducible(tmpdir, pendulum_venv): (pendulum_venv.num_envs,) + pendulum_venv.action_space.shape, dtype=pendulum_venv.action_space.dtype, ) - for i in range(1000): + for _ in range(1000): _, _, dones, _ = collector.step(zero_acts) # Get the observations from all the collected trajectories. @@ -377,7 +377,7 @@ def test_trainer_makes_progress(init_trainer_fn, pendulum_venv, pendulum_expert_ assert np.mean(novice_rewards) < -1000 # Train for 5 iterations. (4 or fewer causes test to fail on some configs.) # see https://github.com/HumanCompatibleAI/imitation/issues/580 for details - for i in range(5): + for _ in range(5): # roll out a few trajectories for dataset, then train for a few steps collector = trainer.create_trajectory_collector() for _ in range(4): @@ -447,7 +447,7 @@ def test_trainer_reproducible( rng, ) - for i in range(2): + for _ in range(2): collector = trainer.create_trajectory_collector() obs = collector.reset() dones = [False] * pendulum_venv.num_envs diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index b64e14f0e..5c92feb58 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -1,12 +1,13 @@ """Tests for `imitation.algorithms.density_baselines`.""" from dataclasses import asdict -from typing import Sequence +from typing import Sequence, cast +import gymnasium as gym import numpy as np import pytest import stable_baselines3 -from stable_baselines3.common import policies +from stable_baselines3.common import policies, vec_env from imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types @@ -31,7 +32,8 @@ def score_trajectories( dones = np.zeros(len(traj), dtype=bool) dones[-1] = True steps = np.arange(0, len(traj.acts)) - rewards = density_reward(traj.obs[:-1], traj.acts, traj.obs[1:], dones, steps) + obs = types.assert_not_dictobs(traj.obs) + rewards = density_reward(obs[:-1], traj.acts, obs[1:], dones, steps) ret = np.sum(rewards) returns.append(ret) return returns @@ -118,15 +120,7 @@ def test_density_with_other_trajectory_types( ) rollouts = pendulum_expert_trajectories[:2] transitions = rollout.flatten_trajectories_with_rew(rollouts) - transitions_mappings = [ - asdict(transitions), - ] - - minimal_transitions = types.TransitionsMinimal( - obs=transitions.obs, - acts=transitions.acts, - infos=transitions.infos, - ) + transitions_mappings = [cast(types.TransitionMapping, asdict(transitions))] d = DensityAlgorithm( demonstrations=transitions_mappings, venv=pendulum_venv, @@ -137,6 +131,11 @@ def test_density_with_other_trajectory_types( d.train_policy(n_timesteps=2) d.test_policy(n_trajectories=2) + minimal_transitions = types.TransitionsMinimal( + obs=transitions.obs, + acts=transitions.acts, + infos=transitions.infos, + ) d = DensityAlgorithm( demonstrations=minimal_transitions, venv=pendulum_venv, @@ -169,3 +168,39 @@ def test_density_trainer_raises( with pytest.raises(TypeError, match="Unsupported demonstration type"): density_trainer.set_demonstrations("foo") # type: ignore[arg-type] + + +def test_dict_space(multi_obs_venv: vec_env.VecEnv): + # 
multi-input policy to accept dict observations + assert isinstance(multi_obs_venv.observation_space, gym.spaces.Dict) + rl_algo = stable_baselines3.PPO( + policies.MultiInputActorCriticPolicy, + multi_obs_venv, + n_steps=10, # small value to make test faster + n_epochs=2, # small value to make test faster + ) + rng = np.random.default_rng() + + # sample random transitions + sample_until = rollout.make_min_episodes(15) + rollouts = rollout.rollout( + policy=None, + venv=multi_obs_venv, + sample_until=sample_until, + rng=rng, + ) + density_trainer = DensityAlgorithm( + demonstrations=rollouts, + kernel="gaussian", + venv=multi_obs_venv, + rl_algo=rl_algo, + kernel_bandwidth=0.2, + standardise_inputs=True, + rng=rng, + # SimpleMultiObsEnv has early stop (issue #40) + allow_variable_horizon=True, + ) + # confirm that training works + density_trainer.train() + density_trainer.train_policy(n_timesteps=2) + density_trainer.test_policy(n_trajectories=2) diff --git a/tests/algorithms/test_mce_irl.py b/tests/algorithms/test_mce_irl.py index 5f9549ea4..e8e99edbe 100644 --- a/tests/algorithms/test_mce_irl.py +++ b/tests/algorithms/test_mce_irl.py @@ -24,7 +24,7 @@ def rollouts(env, n=10, seed=None): rv = [] - for i in range(n): + for _ in range(n): done = False # if a seed is given, then we use the same seed each time (should # give same trajectory each time) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 7cb043811..ef3444f72 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -94,7 +94,10 @@ def _check_trajs_equal( ): assert len(trajs1) == len(trajs2) for traj1, traj2 in zip(trajs1, trajs2): - assert np.array_equal(traj1.obs, traj2.obs) + assert np.array_equal( + types.assert_not_dictobs(traj1.obs), + types.assert_not_dictobs(traj2.obs), + ) assert np.array_equal(traj1.acts, traj2.acts) assert np.array_equal(traj1.rews, traj2.rews) assert traj1.infos is not None diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index 95a045fc0..e7615461e 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -135,28 +135,33 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): sample = buf.sample(100) info_vals = np.array([info["a"] for info in sample.infos]) - assert sample.obs.shape == sample.next_obs.shape == (100,) + obs_shape + # dictobs not supported for buffers, or by current code in + # this test file (eg `_get_fill_from_chunk`) + obs = types.assert_not_dictobs(sample.obs) + next_obs = types.assert_not_dictobs(sample.next_obs) + + assert obs.shape == next_obs.shape == (100,) + obs_shape assert sample.acts.shape == (100,) + act_shape assert sample.dones.shape == (100,) assert info_vals.shape == (100,) # Are samples right data type? - assert sample.obs.dtype == dtype + assert obs.dtype == dtype assert sample.acts.dtype == dtype - assert sample.next_obs.dtype == dtype + assert next_obs.dtype == dtype assert info_vals.dtype == dtype assert sample.dones.dtype == bool assert sample.infos.dtype == object # Are samples in range? - _check_bound(i + chunk_len, capacity, sample.obs) - _check_bound(i + chunk_len, capacity, sample.next_obs, 3 * capacity) + _check_bound(i + chunk_len, capacity, obs) + _check_bound(i + chunk_len, capacity, next_obs, 3 * capacity) _check_bound(i + chunk_len, capacity, sample.acts, 6 * capacity) _check_bound(i + chunk_len, capacity, info_vals, 9 * capacity) # Are samples in-order? 
- obs_fill = _get_fill_from_chunk(sample.obs) - next_obs_fill = _get_fill_from_chunk(sample.next_obs) + obs_fill = _get_fill_from_chunk(obs) + next_obs_fill = _get_fill_from_chunk(next_obs) act_fill = _get_fill_from_chunk(sample.acts) info_vals_fill = _get_fill_from_chunk(info_vals) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 1f062469c..c8c8cd021 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -380,3 +380,49 @@ def test_rollout_normal_error_for_other_shape_mismatch(rng): rollout.make_sample_until(min_timesteps=None, min_episodes=2), rng=rng, ) + + +class DictObsWrapper(gym.ObservationWrapper): + """Simple wrapper that turns the observation into a dictionary. + + The observation is duplicated, with "b" rescaled. + """ + + def __init__(self, env: gym.Env): + """Builds DictObsWrapper. + + Args: + env: The wrapped Env. + """ + super().__init__(env) + self.observation_space = gym.spaces.Dict( + {"a": env.observation_space, "b": env.observation_space}, + ) + + def observation(self, observation): + return {"a": observation, "b": observation / 2} + + +def test_dictionary_observations(rng): + """Test we can generate a rollout for a dict-type observation environment. + + Args: + rng: Random state to use (with fixed seed). + """ + env = gym.make("CartPole-v1") + env = monitor.Monitor(env, None) + env = DictObsWrapper(env) + venv = vec_env.DummyVecEnv([lambda: env]) + + policy = serialize.load_policy("zero", venv) + trajs = rollout.generate_trajectories( + policy, + venv, + rollout.make_min_episodes(10), + rng=rng, + ) + for traj in trajs: + assert isinstance(traj.obs, types.DictObs) + for obs in traj.obs: + assert venv.observation_space.contains(dict(obs.items())) + np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b")) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 0580c8186..74c658c26 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -22,7 +22,11 @@ gym.spaces.Box(-1, 1, shape=(2,)), gym.spaces.Box(-np.inf, np.inf, shape=(2,)), ] -OBS_SPACES = SPACES +DICT_SPACE = gym.spaces.Dict( + {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))}, +) + +OBS_SPACES = SPACES + [DICT_SPACE] ACT_SPACES = SPACES LENGTHS = [0, 1, 2, 10] @@ -42,7 +46,12 @@ def trajectory( """Fixture to generate trajectory of length `length` iid sampled from spaces.""" if length == 0: pytest.skip() - obs = np.array([obs_space.sample() for _ in range(length + 1)]) + + raw_obs = [obs_space.sample() for _ in range(length + 1)] + if isinstance(obs_space, gym.spaces.Dict): + obs: types.Observation = types.DictObs.from_obs_list(raw_obs) + else: + obs = np.array(raw_obs) acts = np.array([act_space.sample() for _ in range(length)]) infos = np.array([{f"key{i}": i} for i in range(length)]) return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True) @@ -52,7 +61,10 @@ def trajectory( def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: """Like `trajectory` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(len(trajectory)) - return types.TrajectoryWithRew(**dataclasses.asdict(trajectory), rews=rews) + return types.TrajectoryWithRew( + **types.dataclass_quick_asdict(trajectory), + rews=rews, + ) @pytest.fixture @@ -77,7 +89,7 @@ def transitions( next_obs = np.array([obs_space.sample() for _ in range(length)]) dones = np.zeros(length, dtype=bool) return types.Transitions( - **dataclasses.asdict(transitions_min), + 
**types.dataclass_quick_asdict(transitions_min), next_obs=next_obs, dones=dones, ) @@ -90,7 +102,10 @@ def transitions_rew( ) -> types.TransitionsWithRew: """Like `transitions` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(length) - return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews) + return types.TransitionsWithRew( + **types.dataclass_quick_asdict(transitions), + rews=rews, + ) def _check_transitions_get_item(trans, key): @@ -179,20 +194,24 @@ def test_traj_unequal_to_perturbations( assert trajectory != types.Trajectory( obs=trajectory.obs[: new_length + 1], acts=trajectory.acts[:new_length], - infos=trajectory.obs[:new_length], + infos=trajectory.infos[:new_length] + if trajectory.infos is not None + else None, terminal=trajectory.terminal, ) # Or with contents changed for t in [trajectory, trajectory_rew]: - as_dict = dataclasses.asdict(t) + as_dict = types.dataclass_quick_asdict(t) for k in as_dict.keys(): perturbed = dict(as_dict) if k == "infos": perturbed["infos"] = [{"foo": 42}] * len(as_dict["infos"]) + elif isinstance(as_dict[k], types.DictObs): + perturbed[k] = as_dict[k].map_arrays(lambda x: x + 1) else: perturbed[k] = as_dict[k] + 1 - assert trajectory != type(t)(**perturbed) + assert t != type(t)(**perturbed) @pytest.mark.parametrize("type_safe", [False, True]) @pytest.mark.parametrize("use_pickle", [False, True]) @@ -208,6 +227,9 @@ def test_save_trajectories( use_rewards, type_safe, ): + if isinstance(trajectory.obs, types.DictObs): + pytest.xfail("Saving/loading dictobs trajectories not yet supported") + chdir_context: contextlib.AbstractContextManager """Check that trajectories are properly saved.""" if use_chdir: @@ -434,3 +456,66 @@ def test_parse_path(): # Parse optional path. Works the same way but passes None down the line. 
assert util.parse_optional_path(None) is None assert util.parse_optional_path("/foo/bar") == util.parse_path("/foo/bar") + + +def test_dict_obs(): + A = np.random.rand(3, 4) + B = np.random.rand(3, 7, 1) + C = np.random.rand(4) + + ab = types.DictObs({"a": A, "b": B}) + abc = types.DictObs({"a": A, "b": B, "c": C}) + + # len + assert len(ab) == 3 + with pytest.raises(RuntimeError): + len(abc) + with pytest.raises(RuntimeError): + len(types.DictObs({})) + + assert abc.dict_len == 3 + + # slicing + np.testing.assert_equal(abc[0].get("a"), A[0]) + np.testing.assert_equal(abc[0].get("c"), np.array(C[0])) + np.testing.assert_equal(abc[0:2].get("a"), np.array(A[0:2])) + np.testing.assert_equal(ab[:, 0].get("a"), np.array(A[:, 0])) + with pytest.raises(IndexError): + abc[:, 0] + + # iter + for i, a_row in enumerate(A): + np.testing.assert_equal(a_row, ab[i].get("a")) + assert ab[0] == next(iter(ab)) + + # eq + assert abc == types.DictObs({"a": A, "b": B, "c": C}) + assert abc == types.DictObs({"a": np.array(A), "b": np.array(B), "c": np.array(C)}) + assert abc != types.DictObs({"a": A, "c": B, "b": C}) # diff keys + assert abc != types.DictObs({"a": A, "b": B + 1, "c": C}) # diff values + assert abc != {"a": A, "b": B + 1, "c": C} # diff type + assert abc != ab # diff keys + + # shape / dtype + assert abc.shape == {"a": A.shape, "b": B.shape, "c": C.shape} + assert abc.dtype == {"a": A.dtype, "b": B.dtype, "c": C.dtype} + + # wrap + assert types.maybe_wrap_in_dictobs({"a": A, "b": B, "c": C}) == abc + assert abc.unwrap() == {"a": A, "b": B, "c": C} + + # map, stack, concat + assert abc.map_arrays(lambda arr: arr + 1) == types.DictObs( + {"a": A + 1, "b": B + 1, "c": C + 1}, + ) + assert types.DictObs.stack(list(iter(ab))) == ab + np.testing.assert_equal( + types.DictObs.concatenate([abc, abc]).get("a"), + np.concatenate([A, A]), + ) + + with pytest.raises(AssertionError): + types.assert_not_dictobs(abc) + + with pytest.raises(TypeError): + types.DictObs({"a": "not an array"}) # type: ignore[wrong-arg-types] diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index dc1905596..33677c68f 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,6 +1,6 @@ """Tests for `imitation.data.wrappers`.""" -from typing import List, Sequence +from typing import List, Sequence, Type import gymnasium as gym import numpy as np @@ -48,10 +48,40 @@ def step(self, action): return t, t * 10, done, False, {} +class _CountingDictEnv(_CountingEnv): # pragma: no cover + """Similar to _CountingEnv, but with Dict observation.""" + + def __init__(self, episode_length=5): + super().__init__(episode_length) + self.observation_space = gym.spaces.Dict( + spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, + ) + + def reset(self, seed=None): + t, self.timestep = 0.0, 1.0 + return {"t": t}, {} + + def step(self, action): + if self.timestep is None: + raise RuntimeError("Need to reset before first step().") + if self.timestep > self.episode_length: + raise RuntimeError("Episode is over. 
Need to step().") + if np.array(action) not in self.action_space: + raise ValueError(f"Invalid action {action}") + + t, self.timestep = self.timestep, self.timestep + 1 + done = t == self.episode_length + return {"t": t}, t * 10, done, False, {} + + +Envs = [_CountingEnv, _CountingDictEnv] + + def _make_buffering_venv( + Env: Type[gym.Env], error_on_premature_reset: bool, ) -> BufferingWrapper: - venv = DummyVecEnv([_CountingEnv] * 2) + venv = DummyVecEnv([Env] * 2) wrapped_venv = BufferingWrapper(venv, error_on_premature_reset) wrapped_venv.reset() return wrapped_venv @@ -86,10 +116,12 @@ def concat(x): ) +@pytest.mark.parametrize("Env", Envs) @pytest.mark.parametrize("episode_lengths", [(1,), (6, 5, 1, 2), (2, 2)]) @pytest.mark.parametrize("n_steps", [1, 2, 20, 21]) @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)]) def test_pop( + Env: Type[gym.Env], episode_lengths: Sequence[int], n_steps: int, extra_pop_timesteps: Sequence[int], @@ -119,6 +151,7 @@ def test_pop( ``` Args: + Env: Environment class type. episode_lengths: The number of timesteps before episode end in each dummy environment. n_steps: Number of times to call `step()` on the dummy environment. @@ -141,7 +174,7 @@ def test_pop( pytest.skip("pop timesteps out of bounds for this test case") def make_env(ep_len): - return lambda: _CountingEnv(episode_length=ep_len) + return lambda: Env(episode_length=ep_len) venv = DummyVecEnv([make_env(ep_len) for ep_len in episode_lengths]) venv_buffer = BufferingWrapper(venv) @@ -153,10 +186,13 @@ def make_env(ep_len): # Initial observation (only matters for pop_transitions()). obs = venv_buffer.reset() - np.testing.assert_array_equal(obs, [0] * venv.num_envs) + if Env == _CountingEnv: + np.testing.assert_array_equal(obs, [0] * venv.num_envs) + else: + np.testing.assert_array_equal(obs["t"], [0] * venv.num_envs) for t in range(1, n_steps + 1): - acts = obs * 2.1 + acts = obs * 2.1 if Env == _CountingEnv else obs["t"] * 2.1 venv_buffer.step_async(acts) obs, *_ = venv_buffer.step_wait() @@ -179,29 +215,35 @@ def make_env(ep_len): # Check `pop_transitions()` trans = _join_transitions(transitions_list) - - _assert_equal_scrambled_vectors(trans.obs, expect_obs) - _assert_equal_scrambled_vectors(trans.next_obs, expect_next_obs) + if Env == _CountingEnv: + actual_obs = types.assert_not_dictobs(trans.obs) + actual_next_obs = types.assert_not_dictobs(trans.next_obs) + else: + actual_obs = types.DictObs.stack(trans.obs).get("t") + actual_next_obs = types.DictObs.stack(trans.next_obs).get("t") + _assert_equal_scrambled_vectors(actual_obs, expect_obs) + _assert_equal_scrambled_vectors(actual_next_obs, expect_next_obs) _assert_equal_scrambled_vectors(trans.acts, expect_acts) _assert_equal_scrambled_vectors(trans.rews, expect_rews) -def test_reset_error(): +@pytest.mark.parametrize("Env", Envs) +def test_reset_error(Env: Type[gym.Env]): # Resetting before a `step()` is okay. for flag in [True, False]: - venv = _make_buffering_venv(flag) + venv = _make_buffering_venv(Env, flag) for _ in range(10): venv.reset() # Resetting after a `step()` is not okay if error flag is True. - venv = _make_buffering_venv(True) + venv = _make_buffering_venv(Env, True) zeros = np.array([0.0, 0.0], dtype=venv.action_space.dtype) venv.step(zeros) with pytest.raises(RuntimeError, match="before samples were accessed"): venv.reset() # Same as previous case, but insert a `pop_transitions()` in between. 
- venv = _make_buffering_venv(True) + venv = _make_buffering_venv(Env, True) venv.step(zeros) venv.pop_transitions() venv.step(zeros) @@ -209,20 +251,21 @@ def test_reset_error(): venv.reset() # Resetting after a `step()` is ok if error flag is False. - venv = _make_buffering_venv(False) + venv = _make_buffering_venv(Env, False) venv.step(zeros) venv.reset() # Resetting after a `step()` is ok if transitions are first collected. for flag in [True, False]: - venv = _make_buffering_venv(flag) + venv = _make_buffering_venv(Env, flag) venv.step(zeros) venv.pop_transitions() venv.reset() -def test_n_transitions_and_empty_error(): - venv = _make_buffering_venv(True) +@pytest.mark.parametrize("Env", Envs) +def test_n_transitions_and_empty_error(Env: Type[gym.Env]): + venv = _make_buffering_venv(Env, True) trajs, ep_lens = venv.pop_trajectories() assert trajs == [] assert ep_lens == []
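
As a quick orientation for reviewers, the snippet below is a minimal usage sketch of the new `DictObs` container, based only on the behaviours asserted in `test_dict_obs` above (construction from a dict of arrays, batch-axis `len` and indexing, `get`, `map_arrays`, and `stack`); it is illustrative and not part of the diff:

import numpy as np

from imitation.data import types

# A DictObs wraps a dict of arrays that share a leading (batch) axis.
obs = types.DictObs({"a": np.zeros((3, 4)), "b": np.ones((3, 7, 1))})

assert len(obs) == 3                           # length of the shared batch axis
row = obs[0]                                   # indexing slices every entry
assert row.get("a").shape == (4,)

shifted = obs.map_arrays(lambda arr: arr + 1)  # apply a function to each array
assert float(shifted.get("b")[0, 0, 0]) == 2.0

assert types.DictObs.stack(list(obs)) == obs   # stacking rows round-trips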