From 5182ecffc9030f2df08091a91360e0b035ca68a4 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:40:31 -0700 Subject: [PATCH 01/85] first pass of dict obs functionality --- src/imitation/algorithms/bc.py | 26 ++- src/imitation/algorithms/density.py | 7 +- src/imitation/algorithms/mce_irl.py | 18 +- .../algorithms/preference_comparisons.py | 4 +- src/imitation/data/buffer.py | 5 + src/imitation/data/rollout.py | 110 +++++++--- src/imitation/data/types.py | 193 ++++++++++++++++-- src/imitation/policies/exploration_wrapper.py | 8 +- src/imitation/util/util.py | 7 +- tests/algorithms/test_adversarial.py | 4 +- tests/algorithms/test_density_baselines.py | 3 +- .../algorithms/test_preference_comparisons.py | 4 +- tests/data/test_types.py | 4 +- tests/data/test_wrappers.py | 7 +- 14 files changed, 338 insertions(+), 62 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index a940d9cd9..52786a9e6 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -100,7 +100,7 @@ class BehaviorCloningLossCalculator: def __call__( self, policy: policies.ActorCriticPolicy, - obs: Union[th.Tensor, np.ndarray], + obs: Union[th.Tensor, np.ndarray, types.DictObs], acts: Union[th.Tensor, np.ndarray], ) -> BCTrainingMetrics: """Calculate the supervised learning loss used to train the behavioral clone. @@ -114,9 +114,18 @@ def __call__( A BCTrainingMetrics object with the loss and all the components it consists of. """ - obs = util.safe_to_tensor(obs) + tensor_obs: Union[th.Tensor, dict[str, th.Tensor]] + if isinstance(obs, types.DictObs): + tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap()} + else: + tensor_obs = util.safe_to_tensor(obs) acts = util.safe_to_tensor(acts) - _, log_prob, entropy = policy.evaluate_actions(obs, acts) + # TODO: add check obs is proper type? + # policy.evaluate_actions's type signature seems wrong to me. + # it declares it only takes a tensor but it calls + # extract_features which is happy with Dict[str, tensor]. + # In reality the required type of obs depends on the feature extractor. + _, log_prob, entropy = policy.evaluate_actions(tensor_obs, acts) # type: ignore prob_true_act = th.exp(log_prob).mean() log_prob = log_prob.mean() entropy = entropy.mean() if entropy is not None else None @@ -325,6 +334,7 @@ def __init__( self.rng = rng if policy is None: + # TODO: maybe default to comb. dict when dict obs space? 
policy = policy_base.FeedForward32Policy( observation_space=observation_space, action_space=action_space, @@ -465,8 +475,14 @@ def process_batch(): minibatch_size, num_samples_so_far, ), batch in batches_with_stats: - obs = th.as_tensor(batch["obs"], device=self.policy.device).detach() - acts = th.as_tensor(batch["acts"], device=self.policy.device).detach() + obs = types.DictObs.map( + batch["obs"], + lambda o: util.safe_to_tensor(o, device=self.policy.device).detach(), + ) + acts = util.safe_to_tensor( + batch["acts"], + device=self.policy.device, + ).detach() training_metrics = self.loss_calculator(self.policy, obs, acts) # Renormalise the loss to be averaged over the whole diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index fcc5e5ac9..378c6bf80 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -168,9 +168,11 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: if isinstance(demonstrations, types.TransitionsMinimal): next_obs_b = getattr(demonstrations, "next_obs", None) + if next_obs_b is not None: + next_obs_b = types.assert_not_dictobs(next_obs_b) transitions.update( self._get_demo_from_batch( - demonstrations.obs, + types.assert_not_dictobs(demonstrations.obs), demonstrations.acts, next_obs_b, ), @@ -191,8 +193,9 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: demonstrations = cast(Iterable[types.Trajectory], demonstrations) for traj in demonstrations: + traj_obs = types.assert_not_dictobs(traj.obs) for i, (obs, act, next_obs) in enumerate( - zip(traj.obs[:-1], traj.acts, traj.obs[1:]), + zip(traj_obs[:-1], traj.acts, traj_obs[1:]), ): flat_trans = self._preprocess_transition(obs, act, next_obs) transitions.setdefault(i, []).append(flat_trans) diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index fde7ac228..fba45e0d3 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -347,6 +347,10 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None num_demos = 0 for traj in trajs: cum_discount = 1.0 + if isinstance(traj.obs, types.DictObs): + raise ValueError( + "Dictionary observations are not currently supported for mce_irl" + ) for obs in traj.obs: self.demo_state_om[obs] += cum_discount cum_discount *= self.discount @@ -411,12 +415,14 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: if isinstance(demonstrations, types.Transitions): self._set_demo_from_obs( - demonstrations.obs, + types.assert_not_dictobs(demonstrations.obs), demonstrations.dones, - demonstrations.next_obs, + types.assert_not_dictobs(demonstrations.next_obs), ) elif isinstance(demonstrations, types.TransitionsMinimal): - self._set_demo_from_obs(demonstrations.obs, None, None) + self._set_demo_from_obs( + types.assert_not_dictobs(demonstrations.obs), None, None + ) elif isinstance(demonstrations, Iterable): # Demonstrations are a Torch DataLoader or other Mapping iterable # Collect them together into one big NumPy array. 
This is inefficient,
@@ -427,7 +433,11 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
             assert isinstance(batch, Mapping)
             for k in ("obs", "dones", "next_obs"):
                 if k in batch:
-                    collated_list[k].append(batch[k])
+                    if isinstance(batch[k], types.DictObs):
+                        raise ValueError(
+                            "Dictionary observations are not currently supported for mce_irl"
+                        )
+                    collated_list[k].append(batch[k])
 
         collated = {k: np.concatenate(v) for k, v in collated_list.items()}
         assert "obs" in collated
diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py
index 413cd979a..14a8fad5b 100644
--- a/src/imitation/algorithms/preference_comparisons.py
+++ b/src/imitation/algorithms/preference_comparisons.py
@@ -465,9 +465,9 @@ def rewards(self, transitions: Transitions) -> th.Tensor:
             Shape - (num_transitions, ) for Single reward network and
             (num_transitions, num_networks) for ensemble of networks.
         """
-        state = transitions.obs
+        state = types.assert_not_dictobs(transitions.obs)
         action = transitions.acts
-        next_state = transitions.next_obs
+        next_state = types.assert_not_dictobs(transitions.next_obs)
         done = transitions.dones
         if self.ensemble_model is not None:
             rews_np = self.ensemble_model.predict_processed_all(
diff --git a/src/imitation/data/buffer.py b/src/imitation/data/buffer.py
index 5554f403a..c6a3783dc 100644
--- a/src/imitation/data/buffer.py
+++ b/src/imitation/data/buffer.py
@@ -345,6 +345,11 @@ def from_data(
         Returns:
             A new ReplayBuffer.
         """
+        if isinstance(transitions.obs, types.DictObs):
+            raise ValueError(
+                "Dictionary observations are not currently supported for buffers"
+            )
+
         obs_shape = transitions.obs.shape[1:]
         act_shape = transitions.acts.shape[1:]
         if capacity is None:
diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py
index add281a65..40d41df0e 100644
--- a/src/imitation/data/rollout.py
+++ b/src/imitation/data/rollout.py
@@ -15,9 +15,12 @@
     Sequence,
     Tuple,
     Union,
+    cast,
+    overload,
 )
 
 import numpy as np
+from gym import spaces
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.utils import check_for_correct_spaces
@@ -69,7 +72,7 @@ def __init__(self):
 
     def add_step(
         self,
-        step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any]]],
+        step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any], types.DictObs]],
         key: Hashable = None,
     ) -> None:
         """Add a single step to the partial trajectory identified by `key`.
@@ -107,17 +110,22 @@ def finish_trajectory(
         for part_dict in part_dicts:
             for k, array in part_dict.items():
                 out_dict_unstacked[k].append(array)
-        out_dict_stacked = {
-            k: np.stack(arr_list, axis=0) for k, arr_list in out_dict_unstacked.items()
-        }
-        traj = types.TrajectoryWithRew(**out_dict_stacked, terminal=terminal)
-        assert traj.rews.shape[0] == traj.acts.shape[0] == traj.obs.shape[0] - 1
+
+        # TODO: what about infos? Does this actually handle them well?
+ traj = types.TrajectoryWithRew( + obs=types.stack(out_dict_unstacked["obs"]), + acts=np.stack(out_dict_unstacked["acts"], axis=0), + infos=np.stack(out_dict_unstacked["infos"], axis=0), # TODO: confused + rews=np.stack(out_dict_unstacked["rews"], axis=0), + terminal=terminal, + ) + assert traj.rews.shape[0] == traj.acts.shape[0] == len(traj.obs) - 1 return traj def add_steps_and_auto_finish( self, acts: np.ndarray, - obs: np.ndarray, + obs: Union[np.ndarray, dict[str, np.ndarray], types.DictObs], rews: np.ndarray, dones: np.ndarray, infos: List[dict], @@ -142,20 +150,26 @@ def add_steps_and_auto_finish( each `True` in the `dones` argument. """ trajs: List[types.TrajectoryWithRew] = [] - for env_idx in range(len(obs)): + wrapped_obs = types.DictObs.maybe_wrap(obs) + + # len of dictobs is the shape[0] of each value array - which here is # of envs + for env_idx in range(len(wrapped_obs)): assert env_idx in self.partial_trajectories assert list(self.partial_trajectories[env_idx][0].keys()) == ["obs"], ( "Need to first initialize partial trajectory using " "self._traj_accum.add_step({'obs': ob}, key=env_idx)" ) - zip_iter = enumerate(zip(acts, obs, rews, dones, infos)) + zip_iter = enumerate(zip(acts, wrapped_obs, rews, dones, infos)) for env_idx, (act, ob, rew, done, info) in zip_iter: if done: # When dones[i] from VecEnv.step() is True, obs[i] is the first # observation following reset() of the ith VecEnv, and # infos[i]["terminal_observation"] is the actual final observation. real_ob = info["terminal_observation"] + if isinstance(real_ob, dict): + # TODO: does this need to be unsqueezed or something? + real_ob = types.DictObs(real_ob) else: real_ob = ob @@ -268,7 +282,11 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool: # array of states, and an optional array of episode starts and returns an array of # corresponding actions. PolicyCallable = Callable[ - [np.ndarray, Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray]], + [ + Union[np.ndarray, types.DictObs], + Optional[Tuple[np.ndarray, ...]], + Optional[np.ndarray], + ], Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]], ] AnyPolicy = Union[BaseAlgorithm, BasePolicy, PolicyCallable, None] @@ -284,7 +302,7 @@ def policy_to_callable( if policy is None: def get_actions( - observations: np.ndarray, + observations: Union[np.ndarray, types.DictObs], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -298,7 +316,7 @@ def get_actions( # (which would call .forward()). So this elif clause must come first! def get_actions( - observations: np.ndarray, + observations: Union[np.ndarray, types.DictObs], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -306,7 +324,7 @@ def get_actions( # pytype doesn't seem to understand that policy is a BaseAlgorithm # or BasePolicy here, rather than a Callable (acts, states) = policy.predict( # pytype: disable=attribute-error - observations, + types.DictObs.maybe_unwrap(observations), state=states, episode_start=episode_starts, deterministic=deterministic_policy, @@ -403,7 +421,21 @@ def generate_trajectories( # accumulator for incomplete trajectories trajectories_accum = TrajectoryAccumulator() obs = venv.reset() - for env_idx, ob in enumerate(obs): + + assert isinstance( + obs, + ( + np.ndarray, + dict, + ), + ), "Tuple observations are not supported." 
+ + # need to wrap here to iterate over envs properly + wrapped_obs = types.DictObs.maybe_wrap(obs) + # TODO: make this nicer, it's currently non-mypy compliant + # probably want helper + + for env_idx, ob in enumerate(wrapped_obs): # Seed with first obs only. Inside loop, we'll only add second obs from # each (s,a,r,s') tuple, under the same "obs" key again. That way we still # get all observations, but they're not duplicated into "next obs" and @@ -419,13 +451,14 @@ def generate_trajectories( # # To start with, all environments are active. active = np.ones(venv.num_envs, dtype=bool) - assert isinstance(obs, np.ndarray), "Dict/tuple observations are not supported." state = None dones = np.zeros(venv.num_envs, dtype=bool) while np.any(active): acts, state = get_actions(obs, state, dones) obs, rews, dones, infos = venv.step(acts) - assert isinstance(obs, np.ndarray) + assert isinstance( + obs, (np.ndarray, dict) + ), "Tuple observations are not supported." # If an environment is inactive, i.e. the episode completed for that # environment after `sample_until(trajectories)` was true, then we do @@ -460,9 +493,10 @@ def generate_trajectories( for trajectory in trajectories: n_steps = len(trajectory.acts) # extra 1 for the end - exp_obs = (n_steps + 1,) + venv.observation_space.shape - real_obs = trajectory.obs.shape - assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" + if not isinstance(venv.observation_space, spaces.dict.Dict): + exp_obs = (n_steps + 1,) + venv.observation_space.shape + real_obs = types.assert_not_dictobs(trajectory.obs).shape + assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" exp_act = (n_steps,) + venv.action_space.shape real_act = trajectory.acts.shape assert real_act == exp_act, f"expected shape {exp_act}, got {real_act}" @@ -527,6 +561,30 @@ def rollout_stats( return out_stats +# TODO: I don't love these helpers here. + + +@overload +def concat_arrays_or_dictobs(arrs: Iterable[types.DictObs]) -> types.DictObs: + ... + + +@overload +def concat_arrays_or_dictobs(arrs: Iterable[np.ndarray]) -> np.ndarray: + ... + + +# TODO: awkward that it officially accepts union of +def concat_arrays_or_dictobs(arrs): + if isinstance(arrs[0], types.DictObs): + assert all((isinstance(a, types.DictObs) for a in arrs)) + return types.DictObs.concatenate(arrs) + else: + assert all((isinstance(a, np.ndarray) for a in arrs)) + cast_arrs = cast(Iterable[np.ndarray], arrs) + return np.concatenate(cast_arrs) + + def flatten_trajectories( trajectories: Iterable[types.Trajectory], ) -> types.Transitions: @@ -539,7 +597,8 @@ def flatten_trajectories( The trajectories flattened into a single batch of Transitions. 
""" keys = ["obs", "next_obs", "acts", "dones", "infos"] - parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys} + # TODO: sad to use Any here + parts: Mapping[str, List[Any]] = {key: [] for key in keys} for traj in trajectories: parts["acts"].append(traj.acts) @@ -557,12 +616,17 @@ def flatten_trajectories( infos = traj.infos parts["infos"].append(infos) - cat_parts = { - key: np.concatenate(part_list, axis=0) for key, part_list in parts.items() - } + cat_parts = {key: types.concatenate(part_list) for key, part_list in parts.items()} lengths = set(map(len, cat_parts.values())) assert len(lengths) == 1, f"expected one length, got {lengths}" return types.Transitions(**cat_parts) + # TODO: clean + # cat_parts["obs"], + # types.assert_not_dictobs(cat_parts["acts"]), + # types.assert_not_dictobs(cat_parts["infos"]), + # cat_parts["next_obs"], + # types.assert_not_dictobs(cat_parts["done"]), + # ) def flatten_trajectories_with_rew( diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 97d1b950b..068a264c8 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -1,11 +1,15 @@ """Types and helper methods for transitions and trajectories.""" +import collections import dataclasses import os import warnings from typing import ( Any, + Callable, Dict, + Iterable, + List, Mapping, Optional, Sequence, @@ -22,6 +26,124 @@ T = TypeVar("T") AnyPath = Union[str, bytes, os.PathLike] + + +@dataclasses.dataclass(frozen=True) +class DictObs: + # TODO: Docs! + """ + Stores observations from an environment with a dictionary observation space. + This class serves two purposes: + 1. enforcing invariants on the observations + 2. providing an interface that more closely reflects that of observations + stored in a numpy array. + + Observations are in the format dict[str, np.ndarray]. + DictObs enforces that: + - all arrays have equal first dimension. This dimension usually represents + timesteps, but sometimes represents different environments in a vecenv. + + For the inteface, DictObs provides: + - len(DictObs) returns the first dimension of the arrays(enforced to be equal) + - slicing/indexing along this first dimension (returning a dictobs) + - iterating (yeilds a series of dictobs, iterating over first dimension of each array) + + There are some other convenience functions for mapping / stacking / concatenating + lists of dictobs. + """ + + # TODO: should support th.tensor? 
+ d: dict[str, np.ndarray] + + def __len__(self): + lens = set(len(v) for v in self.d.values()) + if len(lens) == 1: + return lens.pop() + else: + raise ValueError(f"observations of conflicting lengths found: {lens}") + + @classmethod + def map(cls, obj: Union["DictObs", np.ndarray, th.Tensor], fn: Callable): + if isinstance(obj, cls): + return cls({k: fn(v) for k, v in obj.d.items()}) + else: + return fn(obj) + + @classmethod + def stack(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + unstacked: dict[str, list[np.ndarray]] = collections.defaultdict(list) + for do in dictobs_list: + for k, array in do.d.items(): + unstacked[k].append(array) + stacked = {k: np.stack(arr_list) for k, arr_list in unstacked.items()} + return cls(stacked) + + @classmethod + def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + unstacked: dict[str, list[np.ndarray]] = collections.defaultdict(list) + for do in dictobs_list: + for k, array in do.d.items(): + unstacked[k].append(array) + catted = { + k: np.concatenate(arr_list, axis=0) for k, arr_list in unstacked.items() + } + return cls(catted) + + def __getitem__(self, key: Union[slice, int]) -> "DictObs": + # TODO assert just one slice? (multi dimensional slices are sketchy) + # TODO test + # TODO compare to adam's below + + # TODO -- what if you slice for a single timestep? This would presumably invalidate + # the equal-first-dimension invariant. Do we want that invariant? + # maybe it should be optional at least? + return DictObs.map(self, lambda a: a[key]) + + def __iter__(self): + return (self[i] for i in range(len(self))) + + def shape(self) -> dict[str, tuple[int, ...]]: + return {k: v.shape for k, v in self.d.items()} + + def dtype(self) -> dict[str, tuple[int, ...]]: + return {k: v.dtype for k, v in self.d.items()} + + def unwrap(self) -> dict: + return self.d + + @classmethod + def maybe_wrap( + cls, obs: Union[dict[str, np.ndarray], np.ndarray, "DictObs"] + ) -> Union["DictObs", np.ndarray]: + """If `obs` is a dict, wraps in a dict obs. + If `obs` is an array or already an obsdict, returns it unchanged""" + if isinstance(obs, dict): + return cls(obs) + else: + assert isinstance(obs, (np.ndarray, cls)) + return obs + + @classmethod + def maybe_unwrap( + cls, maybe_dictobs: Union["DictObs", np.ndarray] + ) -> Union[dict[str, np.ndarray], np.ndarray]: + if isinstance(maybe_dictobs, cls): + return maybe_dictobs.unwrap() + else: + assert isinstance(maybe_dictobs, (np.ndarray, th.Tensor)) + return maybe_dictobs + + # TODO: eq + # TODO: post_init len check? + + +def assert_not_dictobs(x: Union[np.ndarray, DictObs]) -> np.ndarray: + if isinstance(x, DictObs): + raise ValueError("Dictionary observations are not supported here.") + return x + + +# TODO: maybe should support DictObs? TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]] @@ -46,7 +168,7 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]: class Trajectory: """A trajectory, e.g. a one episode rollout from an expert policy.""" - obs: np.ndarray + obs: Union[np.ndarray, DictObs] """Observations, shape (trajectory_len + 1, ) + observation_shape.""" acts: np.ndarray @@ -168,12 +290,16 @@ def transitions_collate_fn( info dict into a single dict, which is incorrect.) 
""" batch_no_infos = [ - {k: np.array(v) for k, v in sample.items() if k != "infos"} for sample in batch + {k: np.array(v) for k, v in sample.items() if k in ["acts", "dones"]} + for sample in batch ] result = th_data.dataloader.default_collate(batch_no_infos) assert isinstance(result, dict) result["infos"] = [sample["infos"] for sample in batch] + result["obs"] = stack([sample["obs"] for sample in batch]) + result["next_obs"] = stack([sample["next_obs"] for sample in batch]) + # TODO: clean names, docs return result @@ -196,7 +322,7 @@ class TransitionsMinimal(th_data.Dataset, Sequence[Mapping[str, np.ndarray]]): field has been sliced. """ - obs: np.ndarray + obs: Union[np.ndarray, DictObs] """ Previous observations. Shape: (batch_size, ) + observation_shape. @@ -283,7 +409,7 @@ def __getitem__(self, key): class Transitions(TransitionsMinimal): """A batch of obs-act-obs-done transitions.""" - next_obs: np.ndarray + next_obs: Union[np.ndarray, DictObs] """New observation. Shape: (batch_size, ) + observation_shape. The i'th observation `next_obs[i]` in this array is the observation @@ -304,16 +430,19 @@ class Transitions(TransitionsMinimal): def __post_init__(self): """Performs input validation: check shapes & dtypes match docstring.""" super().__post_init__() - if self.obs.shape != self.next_obs.shape: - raise ValueError( - "obs and next_obs must have same shape: " - f"{self.obs.shape} != {self.next_obs.shape}", - ) - if self.obs.dtype != self.next_obs.dtype: - raise ValueError( - "obs and next_obs must have the same dtype: " - f"{self.obs.dtype} != {self.next_obs.dtype}", - ) + # TODO: could add support for checking dictobs shape/dtype of each array + # would be nice for debugging to have a dictobs.shapes -> dict{str: shape} + if isinstance(self.obs, np.ndarray): + if self.obs.shape != self.next_obs.shape: + raise ValueError( + "obs and next_obs must have same shape: " + f"{self.obs.shape} != {self.next_obs.shape}", + ) + if self.obs.dtype != self.next_obs.dtype: + raise ValueError( + "obs and next_obs must have the same dtype: " + f"{self.obs.dtype} != {self.next_obs.dtype}", + ) if self.dones.shape != (len(self.acts),): raise ValueError( "dones must be 1D array, one entry for each timestep: " @@ -339,3 +468,39 @@ def __post_init__(self): """Performs input validation, including for rews.""" super().__post_init__() _rews_validation(self.rews, self.acts) + + +ObsType = TypeVar("ObsType", np.ndarray, DictObs) + + +def concatenate(arrs: List[ObsType]) -> ObsType: + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.concatenate(arrs) + else: + return np.concatenate(arrs) + + +def stack(arrs: List[ObsType]) -> ObsType: + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.stack(arrs) + else: + return np.stack(arrs) + + +# class ObsList(collections.UserList[O]): +# def isDict(self) -> bool: +# return (len(self) > 0) and isinstance(self[0], DictObs) + +# def concatenate(self) -> O: +# if self.isDict(): +# return DictObs.concatenate(self) +# else: +# return np.concatenate(self) + +# def stack(self) -> O: +# if self.isDict(): +# return DictObs.concatenate(self) +# else: +# return np.concatenate(self) diff --git a/src/imitation/policies/exploration_wrapper.py b/src/imitation/policies/exploration_wrapper.py index 9151d5971..447708745 100644 --- a/src/imitation/policies/exploration_wrapper.py +++ b/src/imitation/policies/exploration_wrapper.py @@ -1,11 +1,11 @@ """Wrapper to turn a policy into a more exploratory version.""" -from typing import Optional, 
Tuple +from typing import Optional, Tuple, Union import numpy as np from stable_baselines3.common import vec_env -from imitation.data import rollout +from imitation.data import rollout, types from imitation.util import util @@ -57,7 +57,7 @@ def __init__( def _random_policy( self, - obs: np.ndarray, + obs: Union[np.ndarray, types.DictObs], state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -74,7 +74,7 @@ def _switch(self) -> None: def __call__( self, - observation: np.ndarray, + observation: Union[np.ndarray, types.DictObs], input_state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index 2abae1605..443333e81 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -230,13 +230,18 @@ def endless_iter(iterable: Iterable[T]) -> Iterator[T]: return itertools.chain.from_iterable(itertools.repeat(iterable)) -def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor: +def safe_to_tensor( + array: Union[np.ndarray, th.Tensor], + **kwargs, +) -> th.Tensor: """Converts a NumPy array to a PyTorch tensor. The data is copied in the case where the array is non-writable. Unfortunately if you just use `th.as_tensor` for this, an ugly warning is logged and there's undefined behavior if you try to write to the tensor. + `array` can also be a dictionary, in which case all values will be converted. + Args: array: The array to convert to a PyTorch tensor. kwargs: Additional keyword arguments to pass to `th.as_tensor`. diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py index d3609efaa..68d8e6208 100644 --- a/tests/algorithms/test_adversarial.py +++ b/tests/algorithms/test_adversarial.py @@ -393,9 +393,9 @@ def test_logits_expert_is_high_log_policy_act_prob( ) obs, acts, next_obs, dones = trainer_diverse_env.reward_train.preprocess( - trans.obs, + types.assert_not_dictobs(trans.obs), trans.acts, - trans.next_obs, + types.assert_not_dictobs(trans.next_obs), trans.dones, ) log_act_prob_non_none = np.log(0.1 + 0.9 * np.random.rand(n_timesteps)) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index b64e14f0e..a7288e1bf 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -31,7 +31,8 @@ def score_trajectories( dones = np.zeros(len(traj), dtype=bool) dones[-1] = True steps = np.arange(0, len(traj.acts)) - rewards = density_reward(traj.obs[:-1], traj.acts, traj.obs[1:], dones, steps) + obs = types.assert_not_dictobs(traj.obs) + rewards = density_reward(obs[:-1], traj.acts, obs[1:], dones, steps) ret = np.sum(rewards) returns.append(ret) return returns diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 5f237812f..6baeb90f8 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -94,7 +94,9 @@ def _check_trajs_equal( ): assert len(trajs1) == len(trajs2) for traj1, traj2 in zip(trajs1, trajs2): - assert np.array_equal(traj1.obs, traj2.obs) + assert np.array_equal( + types.assert_not_dictobs(traj1.obs), types.assert_not_dictobs(traj2.obs) + ) assert np.array_equal(traj1.acts, traj2.acts) assert np.array_equal(traj1.rews, traj2.rews) assert traj1.infos is not None diff --git 
a/tests/data/test_types.py b/tests/data/test_types.py index 4ff4b1681..1f5bcc8ba 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -179,7 +179,9 @@ def test_traj_unequal_to_perturbations( assert trajectory != types.Trajectory( obs=trajectory.obs[: new_length + 1], acts=trajectory.acts[:new_length], - infos=trajectory.obs[:new_length], + infos=trajectory.infos[:new_length] + if trajectory.infos is not None + else None, terminal=trajectory.terminal, ) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 14a7626c8..03c584559 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -180,8 +180,11 @@ def make_env(ep_len): # Check `pop_transitions()` trans = _join_transitions(transitions_list) - _assert_equal_scrambled_vectors(trans.obs, expect_obs) - _assert_equal_scrambled_vectors(trans.next_obs, expect_next_obs) + _assert_equal_scrambled_vectors(types.assert_not_dictobs(trans.obs), expect_obs) + _assert_equal_scrambled_vectors( + types.assert_not_dictobs(trans.next_obs), + expect_next_obs, + ) _assert_equal_scrambled_vectors(trans.acts, expect_acts) _assert_equal_scrambled_vectors(trans.rews, expect_rews) From 61d816b634cecc5c938d6e51ce5d2dbd1c4d7c94 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Wed, 13 Sep 2023 16:34:01 -0700 Subject: [PATCH 02/85] cleanup DictObs --- src/imitation/algorithms/bc.py | 27 ++--- src/imitation/data/rollout.py | 10 +- src/imitation/data/types.py | 176 +++++++++++++++++---------------- src/imitation/util/util.py | 1 + 4 files changed, 117 insertions(+), 97 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 52786a9e6..08e2611b2 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -9,6 +9,7 @@ from typing import ( Any, Callable, + Dict, Iterable, Iterator, Mapping, @@ -114,9 +115,9 @@ def __call__( A BCTrainingMetrics object with the loss and all the components it consists of. """ - tensor_obs: Union[th.Tensor, dict[str, th.Tensor]] + tensor_obs: Union[th.Tensor, Dict[str, th.Tensor]] if isinstance(obs, types.DictObs): - tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap()} + tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap().items()} else: tensor_obs = util.safe_to_tensor(obs) acts = util.safe_to_tensor(acts) @@ -475,15 +476,19 @@ def process_batch(): minibatch_size, num_samples_so_far, ), batch in batches_with_stats: - obs = types.DictObs.map( - batch["obs"], - lambda o: util.safe_to_tensor(o, device=self.policy.device).detach(), - ) - acts = util.safe_to_tensor( - batch["acts"], - device=self.policy.device, - ).detach() - training_metrics = self.loss_calculator(self.policy, obs, acts) + obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]] + if isinstance(batch["obs"], types.DictObs): + obs_dict = batch["obs"].unwrap() + obs_tensor = { + k: util.safe_to_tensor(v, device=self.policy.device) + for k, v in obs_dict.items() + } + else: + obs_tensor = util.safe_to_tensor( + batch["obs"], device=self.policy.device + ) + acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) + training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) # Renormalise the loss to be averaged over the whole # batch size instead of the minibatch size. 
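For reference, the per-key tensor conversion above is equivalent to the following minimal sketch (toy key names and shapes; `DictObs`, `unwrap`, and `util.safe_to_tensor` as defined elsewhere in this series, with the device argument illustrative):

    import numpy as np
    import torch as th

    from imitation.data import types
    from imitation.util import util

    batch_obs = types.DictObs(
        {"img": np.zeros((32, 4, 4)), "vec": np.zeros((32, 5))},
    )
    # Mirrors the DictObs branch: convert each constituent array separately,
    # producing the Dict[str, th.Tensor] that SB3 feature extractors accept.
    obs_tensor = {
        k: util.safe_to_tensor(v, device="cpu")
        for k, v in batch_obs.unwrap().items()
    }
    assert isinstance(obs_tensor["img"], th.Tensor)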
diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py
index 40d41df0e..09d130f86 100644
--- a/src/imitation/data/rollout.py
+++ b/src/imitation/data/rollout.py
@@ -113,7 +113,7 @@ def finish_trajectory(
 
         # TODO: what about infos? Does this actually handle them well?
         traj = types.TrajectoryWithRew(
-            obs=types.stack(out_dict_unstacked["obs"]),
+            obs=types.stack_maybe_dictobs(out_dict_unstacked["obs"]),
             acts=np.stack(out_dict_unstacked["acts"], axis=0),
             infos=np.stack(out_dict_unstacked["infos"], axis=0),  # TODO: confused
             rews=np.stack(out_dict_unstacked["rews"], axis=0),
@@ -457,7 +457,8 @@ def generate_trajectories(
             acts, state = get_actions(obs, state, dones)
             obs, rews, dones, infos = venv.step(acts)
             assert isinstance(
-                obs, (np.ndarray, dict)
+                obs,
+                (np.ndarray, dict),
             ), "Tuple observations are not supported."
@@ -616,7 +617,10 @@ def flatten_trajectories(
             infos = traj.infos
         parts["infos"].append(infos)
 
-    cat_parts = {key: types.concatenate(part_list) for key, part_list in parts.items()}
+    cat_parts = {
+        key: types.concatenate_maybe_dictobs(part_list)
+        for key, part_list in parts.items()
+    }
     lengths = set(map(len, cat_parts.values()))
     assert len(lengths) == 1, f"expected one length, got {lengths}"
     return types.Transitions(**cat_parts)
diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py
index 068a264c8..8917c2d77 100644
--- a/src/imitation/data/types.py
+++ b/src/imitation/data/types.py
@@ -6,9 +6,9 @@
 import warnings
 from typing import (
     Any,
-    Callable,
     Dict,
     Iterable,
+    Iterator,
     List,
     Mapping,
     Optional,
@@ -25,90 +25,91 @@
 
 T = TypeVar("T")
 
+TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor)
+
+
 AnyPath = Union[str, bytes, os.PathLike]
 
 
 @dataclasses.dataclass(frozen=True)
 class DictObs:
     """
     Stores observations from an environment with a dictionary observation space.
 
     Provides an interface that is similar to observations in a numpy array.
     Length, slicing, indexing, and iterating operations will operate on the first
     dimension of the constituent arrays, as they would for observations in a single
     array.
 
     There are also utility functions for mapping / stacking / concatenating
     lists of dictobs.
     """
 
     d: dict[str, np.ndarray]
 
     def __post_init__(self):
         if not all((isinstance(v, np.ndarray) for v in self.d.values())):
             raise ValueError("values must be numpy arrays")
 
     def __len__(self):
         """Returns the first dimension of constituent arrays.
 
         Only defined if there is at least one array, and all arrays have the same
         length of first dimension. Otherwise raises ValueError.
+ + Len of a DictObs usually represents number of timesteps, or number of + environments in a VecEnv. + + Use `dict_len` to get the number of entries in the dictionary. + + raises: + ValueError: if the arrays have different lengths or there are no arrays. + """ lens = set(len(v) for v in self.d.values()) if len(lens) == 1: return lens.pop() + elif len(lens) == 0: + raise ValueError("Length not defined as DictObs is empty") else: - raise ValueError(f"observations of conflicting lengths found: {lens}") - - @classmethod - def map(cls, obj: Union["DictObs", np.ndarray, th.Tensor], fn: Callable): - if isinstance(obj, cls): - return cls({k: fn(v) for k, v in obj.d.items()}) - else: - return fn(obj) - - @classmethod - def stack(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": - unstacked: dict[str, list[np.ndarray]] = collections.defaultdict(list) - for do in dictobs_list: - for k, array in do.d.items(): - unstacked[k].append(array) - stacked = {k: np.stack(arr_list) for k, arr_list in unstacked.items()} - return cls(stacked) + raise ValueError( + f"Length not defined; arrays have conflicting first dimensions: {lens}" + ) - @classmethod - def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": - unstacked: dict[str, list[np.ndarray]] = collections.defaultdict(list) - for do in dictobs_list: - for k, array in do.d.items(): - unstacked[k].append(array) - catted = { - k: np.concatenate(arr_list, axis=0) for k, arr_list in unstacked.items() - } - return cls(catted) + @property + def dict_len(self): + return len(self.d) def __getitem__(self, key: Union[slice, int]) -> "DictObs": - # TODO assert just one slice? (multi dimensional slices are sketchy) - # TODO test - # TODO compare to adam's below - - # TODO -- what if you slice for a single timestep? This would presumably invalidate - # the equal-first-dimension invariant. Do we want that invariant? - # maybe it should be optional at least? - return DictObs.map(self, lambda a: a[key]) + """ + Indexes or slices into the first element of every array. + Note that it will still return singleton values as np.arrays, not scalars. + """ + # asarray to handle case where we slice to a single array element. + return self.__class__({k: np.asarray(v[key]) for k, v in self.d.items()}) - def __iter__(self): + def __iter__(self) -> Iterator["DictObs"]: + """ + Iterates over the first dimension of each array. + """ return (self[i] for i in range(len(self))) + # TODO: eq? + + @property def shape(self) -> dict[str, tuple[int, ...]]: + """ + Returns a dictionary with shape-tuples in place of the arrays. + """ return {k: v.shape for k, v in self.d.items()} - def dtype(self) -> dict[str, tuple[int, ...]]: + @property + def dtype(self) -> dict[str, np._DType]: + """ + Returns a dictionary with shape-tuples in place of the arrays. + """ return {k: v.dtype for k, v in self.d.items()} - def unwrap(self) -> dict: + def unwrap(self) -> dict[str, np.ndarray]: return self.d @classmethod @@ -123,18 +124,44 @@ def maybe_wrap( assert isinstance(obs, (np.ndarray, cls)) return obs + @overload + @classmethod + def maybe_unwrap(cls, maybe_dictobs: "DictObs") -> dict[str, np.ndarray]: + ... + + @overload @classmethod - def maybe_unwrap( - cls, maybe_dictobs: Union["DictObs", np.ndarray] - ) -> Union[dict[str, np.ndarray], np.ndarray]: + def maybe_unwrap(cls, maybe_dictobs: TensorVar) -> TensorVar: + ... 
+ + @classmethod + def maybe_unwrap(cls, maybe_dictobs): if isinstance(maybe_dictobs, cls): return maybe_dictobs.unwrap() else: assert isinstance(maybe_dictobs, (np.ndarray, th.Tensor)) return maybe_dictobs - # TODO: eq - # TODO: post_init len check? + @staticmethod + def _unravel(dictobs_list: Iterable["DictObs"]) -> dict[str, list[np.ndarray]]: + """Converts a list of DictObs into a dictionary of lists of arrays.""" + unraveled: dict[str, list[np.ndarray]] = collections.defaultdict(list) + for do in dictobs_list: + for k, array in do.d.items(): + unraveled[k].append(array) + return unraveled + + @classmethod + def stack(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + return cls( + {k: np.stack(arr_list) for k, arr_list in cls._unravel(dictobs_list)} + ) + + @classmethod + def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + return cls( + {k: np.concatenate(arr_list) for k, arr_list in cls._unravel(dictobs_list)} + ) def assert_not_dictobs(x: Union[np.ndarray, DictObs]) -> np.ndarray: @@ -297,8 +324,8 @@ def transitions_collate_fn( result = th_data.dataloader.default_collate(batch_no_infos) assert isinstance(result, dict) result["infos"] = [sample["infos"] for sample in batch] - result["obs"] = stack([sample["obs"] for sample in batch]) - result["next_obs"] = stack([sample["next_obs"] for sample in batch]) + result["obs"] = stack_maybe_dictobs([sample["obs"] for sample in batch]) + result["next_obs"] = stack_maybe_dictobs([sample["next_obs"] for sample in batch]) # TODO: clean names, docs return result @@ -473,7 +500,7 @@ def __post_init__(self): ObsType = TypeVar("ObsType", np.ndarray, DictObs) -def concatenate(arrs: List[ObsType]) -> ObsType: +def concatenate_maybe_dictobs(arrs: List[ObsType]) -> ObsType: assert len(arrs) > 0 if isinstance(arrs[0], DictObs): return DictObs.concatenate(arrs) @@ -481,26 +508,9 @@ def concatenate(arrs: List[ObsType]) -> ObsType: return np.concatenate(arrs) -def stack(arrs: List[ObsType]) -> ObsType: +def stack_maybe_dictobs(arrs: List[ObsType]) -> ObsType: assert len(arrs) > 0 if isinstance(arrs[0], DictObs): return DictObs.stack(arrs) else: return np.stack(arrs) - - -# class ObsList(collections.UserList[O]): -# def isDict(self) -> bool: -# return (len(self) > 0) and isinstance(self[0], DictObs) - -# def concatenate(self) -> O: -# if self.isDict(): -# return DictObs.concatenate(self) -# else: -# return np.concatenate(self) - -# def stack(self) -> O: -# if self.isDict(): -# return DictObs.concatenate(self) -# else: -# return np.concatenate(self) diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index 443333e81..a741d8af9 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + Dict, Iterable, Iterator, List, From c3331f616e41ecfe6d43536368c184f257ac2c4e Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Wed, 13 Sep 2023 18:20:17 -0700 Subject: [PATCH 03/85] add dict space to test_types.py, fix some problems --- src/imitation/data/huggingface_utils.py | 2 ++ src/imitation/data/types.py | 28 +++++++++++++++++--- src/imitation/util/util.py | 1 - tests/data/test_types.py | 35 ++++++++++++++++++++----- 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/imitation/data/huggingface_utils.py b/src/imitation/data/huggingface_utils.py index 158984486..58fb8be2b 100644 --- a/src/imitation/data/huggingface_utils.py +++ b/src/imitation/data/huggingface_utils.py @@ -124,6 +124,8 @@ def 
trajectories_to_dict(
         ],
         terminal=[traj.terminal for traj in trajectories],
     )
+    if any(isinstance(traj.obs, types.DictObs) for traj in trajectories):
+        raise ValueError("DictObs are not currently supported")
 
     # Encode infos as jsonpickled strings
     trajectory_dict["infos"] = [
diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py
index 8917c2d77..c10e48fec 100644
--- a/src/imitation/data/types.py
+++ b/src/imitation/data/types.py
@@ -6,6 +6,7 @@
 import warnings
 from typing import (
     Any,
+    Callable,
     Dict,
     Iterable,
     Iterator,
     List,
     Mapping,
     Optional,
@@ -47,6 +48,12 @@ class DictObs:
 
     d: dict[str, np.ndarray]
 
+    @classmethod
+    def from_obs_list(cls, obs_list: Iterable[Dict[str, np.ndarray]]):
+        return cls(
+            {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()}
+        )
+
     def __post_init__(self):
         if not all((isinstance(v, np.ndarray) for v in self.d.values())):
             raise ValueError("values must be numpy arrays")
@@ -93,7 +100,12 @@ def __iter__(self) -> Iterator["DictObs"]:
         """
         return (self[i] for i in range(len(self)))
 
-    # TODO: eq?
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        if not self.d.keys() == other.d.keys():
+            return False
+        return all(np.array_equal(self.d[k], other.d[k]) for k in self.d.keys())
 
     @property
     def shape(self) -> dict[str, tuple[int, ...]]:
@@ -103,7 +115,7 @@ def shape(self) -> dict[str, tuple[int, ...]]:
         return {k: v.shape for k, v in self.d.items()}
 
     @property
-    def dtype(self) -> dict[str, np._DType]:
+    def dtype(self) -> dict[str, np.dtype]:
         """
         Returns a dictionary with shape-tuples in place of the arrays.
         """
         return {k: v.dtype for k, v in self.d.items()}
@@ -142,6 +154,9 @@ def maybe_unwrap(cls, maybe_dictobs):
         assert isinstance(maybe_dictobs, (np.ndarray, th.Tensor))
         return maybe_dictobs
 
+    def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs":
+        return self.__class__({k: fn(v) for k, v in self.d.items()})
+
     @staticmethod
     def _unravel(dictobs_list: Iterable["DictObs"]) -> dict[str, list[np.ndarray]]:
         """Converts a list of DictObs into a dictionary of lists of arrays."""
@@ -163,6 +178,8 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs":
             {k: np.concatenate(arr_list) for k, arr_list in cls._unravel(dictobs_list)}
         )
 
+    # TODO: add keys, values, items?
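A quick usage sketch of the methods added in this patch (toy values; assumes `DictObs` as defined at this point in the series):

    import numpy as np
    from imitation.data.types import DictObs

    obs = DictObs.from_obs_list(
        [{"a": np.array([0, 1]), "b": np.array([2.0])} for _ in range(3)],
    )
    assert len(obs) == 3  # common first dimension of every array
    assert obs.shape == {"a": (3, 2), "b": (3, 1)}
    # __eq__ requires matching keys and equal array contents.
    assert obs[0] == DictObs({"a": np.array([0, 1]), "b": np.array([2.0])})
    # map_arrays transforms each array while preserving keys.
    doubled = obs.map_arrays(lambda x: 2 * x)
    assert np.array_equal(doubled.d["a"][0], np.array([0, 2]))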
+ def assert_not_dictobs(x: Union[np.ndarray, DictObs]) -> np.ndarray: if isinstance(x, DictObs): @@ -224,7 +241,9 @@ def __eq__(self, other) -> bool: if not isinstance(other, Trajectory): return False - dict_self, dict_other = dataclasses.asdict(self), dataclasses.asdict(other) + dict_self, dict_other = dataclass_quick_asdict(self), dataclass_quick_asdict( + other + ) # Trajectory objects may still have different keys if different subclasses if dict_self.keys() != dict_other.keys(): return False @@ -240,6 +259,9 @@ def __eq__(self, other) -> bool: # Treat None equivalent to sequence of empty dicts self_v = [{}] * len(self) if self_v is None else self_v other_v = [{}] * len(other) if other_v is None else other_v + if isinstance(self_v, DictObs): + if not self_v == other_v: + return False if not np.array_equal(self_v, other_v): return False diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index a741d8af9..443333e81 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -10,7 +10,6 @@ from typing import ( Any, Callable, - Dict, Iterable, Iterator, List, diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 1f5bcc8ba..e44f795a9 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -22,7 +22,11 @@ gym.spaces.Box(-1, 1, shape=(2,)), gym.spaces.Box(-np.inf, np.inf, shape=(2,)), ] -OBS_SPACES = SPACES +DICT_SPACE = gym.spaces.Dict( + {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))} +) + +OBS_SPACES = SPACES + [DICT_SPACE] ACT_SPACES = SPACES LENGTHS = [0, 1, 2, 10] @@ -42,7 +46,12 @@ def trajectory( """Fixture to generate trajectory of length `length` iid sampled from spaces.""" if length == 0: pytest.skip() - obs = np.array([obs_space.sample() for _ in range(length + 1)]) + + raw_obs = [obs_space.sample() for _ in range(length + 1)] + if isinstance(obs_space, gym.spaces.Dict): + obs = types.DictObs.from_obs_list(raw_obs) + else: + obs = np.array(raw_obs) acts = np.array([act_space.sample() for _ in range(length)]) infos = np.array([{f"key{i}": i} for i in range(length)]) return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True) @@ -52,7 +61,9 @@ def trajectory( def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: """Like `trajectory` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(len(trajectory)) - return types.TrajectoryWithRew(**dataclasses.asdict(trajectory), rews=rews) + return types.TrajectoryWithRew( + **types.dataclass_quick_asdict(trajectory), rews=rews + ) @pytest.fixture @@ -77,7 +88,7 @@ def transitions( next_obs = np.array([obs_space.sample() for _ in range(length)]) dones = np.zeros(length, dtype=bool) return types.Transitions( - **dataclasses.asdict(transitions_min), + **types.dataclass_quick_asdict(transitions_min), next_obs=next_obs, dones=dones, ) @@ -90,7 +101,9 @@ def transitions_rew( ) -> types.TransitionsWithRew: """Like `transitions` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(length) - return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews) + return types.TransitionsWithRew( + **types.dataclass_quick_asdict(transitions), rews=rews + ) def _check_transitions_get_item(trans, key): @@ -187,14 +200,22 @@ def test_traj_unequal_to_perturbations( # Or with contents changed for t in [trajectory, trajectory_rew]: - as_dict = dataclasses.asdict(t) + as_dict = types.dataclass_quick_asdict(t) for k in as_dict.keys(): perturbed = dict(as_dict) if k == "infos": 
perturbed["infos"] = [{"foo": 42}] * len(as_dict["infos"]) + elif isinstance(as_dict[k], types.DictObs): + perturbed[k] = as_dict[k].map_arrays(lambda x: x + 1) + # print(getattr(t, k), perturbed[k]) else: perturbed[k] = as_dict[k] + 1 - assert trajectory != type(t)(**perturbed) + if t == type(t)(**perturbed): + print("\n\n\n") + print(t) + print(perturbed) + print(k) + assert t != type(t)(**perturbed) @pytest.mark.parametrize("type_safe", [False, True]) @pytest.mark.parametrize("use_pickle", [False, True]) From fc9838d32e4c6aaba48e7bd19c364f9bd4911b3a Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Wed, 13 Sep 2023 19:17:19 -0700 Subject: [PATCH 04/85] add dict-obs test for rollout --- src/imitation/data/rollout.py | 33 +++--------------------- src/imitation/data/types.py | 46 ++++++++++++++++------------------ src/imitation/policies/base.py | 38 ++++++++++++++++++++++------ tests/data/test_rollout.py | 36 ++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 61 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 09d130f86..d30efc0cd 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -15,8 +15,6 @@ Sequence, Tuple, Union, - cast, - overload, ) import numpy as np @@ -454,12 +452,13 @@ def generate_trajectories( state = None dones = np.zeros(venv.num_envs, dtype=bool) while np.any(active): - acts, state = get_actions(obs, state, dones) + acts, state = get_actions(wrapped_obs, state, dones) obs, rews, dones, infos = venv.step(acts) assert isinstance( obs, - (np.ndarray, dict), + (np.ndarray, types.DictObs), ), "Tuple observations are not supported." + wrapped_obs = types.DictObs.maybe_wrap(obs) # If an environment is inactive, i.e. the episode completed for that # environment after `sample_until(trajectories)` was true, then we do @@ -469,7 +468,7 @@ def generate_trajectories( new_trajs = trajectories_accum.add_steps_and_auto_finish( acts, - obs, + wrapped_obs, rews, dones, infos, @@ -562,30 +561,6 @@ def rollout_stats( return out_stats -# TODO: I don't love these helpers here. - - -@overload -def concat_arrays_or_dictobs(arrs: Iterable[types.DictObs]) -> types.DictObs: - ... - - -@overload -def concat_arrays_or_dictobs(arrs: Iterable[np.ndarray]) -> np.ndarray: - ... - - -# TODO: awkward that it officially accepts union of -def concat_arrays_or_dictobs(arrs): - if isinstance(arrs[0], types.DictObs): - assert all((isinstance(a, types.DictObs) for a in arrs)) - return types.DictObs.concatenate(arrs) - else: - assert all((isinstance(a, np.ndarray) for a in arrs)) - cast_arrs = cast(Iterable[np.ndarray], arrs) - return np.concatenate(cast_arrs) - - def flatten_trajectories( trajectories: Iterable[types.Trajectory], ) -> types.Transitions: diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index c10e48fec..35c34d875 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -34,8 +34,7 @@ @dataclasses.dataclass(frozen=True) class DictObs: - """ - Stores observations from an environment with a dictionary observation space. + """Stores observations from an environment with a dictionary observation space. Provides an interface that is similar to observations in a numpy array. 
Length, slicing, indexing, and iterating operations will operate on the first
     dimension of the constituent arrays, as they would for observations in a single
     array.
 
     There are also utility functions for mapping / stacking / concatenating
     lists of dictobs.
     """
 
     d: dict[str, np.ndarray]
 
     @classmethod
-    def from_obs_list(cls, obs_list: Iterable[Dict[str, np.ndarray]]):
+    def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]):
         return cls(
-            {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()}
+            {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()},
         )
 
     def __post_init__(self):
@@ -79,7 +78,7 @@ def __len__(self):
             raise ValueError("Length not defined as DictObs is empty")
         else:
             raise ValueError(
-                f"Length not defined; arrays have conflicting first dimensions: {lens}"
+                f"Length not defined; arrays have conflicting first dimensions: {lens}",
             )
 
     @property
@@ -87,17 +86,15 @@ def dict_len(self):
         return len(self.d)
 
     def __getitem__(self, key: Union[slice, int]) -> "DictObs":
-        """
-        Indexes or slices into the first element of every array.
+        """Indexes or slices into the first element of every array.
+
         Note that it will still return singleton values as np.arrays, not scalars.
         """
         # asarray to handle case where we slice to a single array element.
         return self.__class__({k: np.asarray(v[key]) for k, v in self.d.items()})
 
     def __iter__(self) -> Iterator["DictObs"]:
-        """
-        Iterates over the first dimension of each array.
-        """
+        """Iterates over the first dimension of each array."""
         return (self[i] for i in range(len(self)))
 
     def __eq__(self, other):
@@ -109,16 +106,12 @@ def __eq__(self, other):
 
     @property
     def shape(self) -> dict[str, tuple[int, ...]]:
-        """
-        Returns a dictionary with shape-tuples in place of the arrays.
-        """
+        """Returns a dictionary with shape-tuples in place of the arrays."""
         return {k: v.shape for k, v in self.d.items()}
 
     @property
     def dtype(self) -> dict[str, np.dtype]:
-        """
-        Returns a dictionary with shape-tuples in place of the arrays.
-        """
+        """Returns a dictionary with dtypes in place of the arrays."""
         return {k: v.dtype for k, v in self.d.items()}
 
     def unwrap(self) -> dict[str, np.ndarray]:
@@ -126,10 +119,10 @@ def unwrap(self) -> dict[str, np.ndarray]:
 
     @classmethod
     def maybe_wrap(
-        cls, obs: Union[dict[str, np.ndarray], np.ndarray, "DictObs"]
+        cls,
+        obs: Union[dict[str, np.ndarray], np.ndarray, "DictObs"],
     ) -> Union["DictObs", np.ndarray]:
-        """If `obs` is a dict, wraps in a dict obs.
-        If `obs` is an array or already an obsdict, returns it unchanged"""
+        """Converts an observation into a DictObs, if necessary."""
         if isinstance(obs, dict):
             return cls(obs)
         else:
@@ -169,13 +162,19 @@ def _unravel(dictobs_list: Iterable["DictObs"]) -> dict[str, list[np.ndarray]]:
 
     @classmethod
     def stack(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs":
         return cls(
-            {k: np.stack(arr_list) for k, arr_list in cls._unravel(dictobs_list)}
+            {
+                k: np.stack(arr_list)
+                for k, arr_list in cls._unravel(dictobs_list).items()
+            },
         )
 
     @classmethod
     def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs":
         return cls(
-            {k: np.concatenate(arr_list) for k, arr_list in cls._unravel(dictobs_list)}
+            {
+                k: np.concatenate(arr_list)
+                for k, arr_list in cls._unravel(dictobs_list).items()
+            },
         )
 
     # TODO: add keys, values, items?
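The `.items()` calls added above are load-bearing: iterating a dict directly yields only its keys, so the previous `for k, arr_list in cls._unravel(dictobs_list)` unpacking would operate on key strings rather than (key, list) pairs and fail at runtime. A sketch of the intended semantics (toy shapes):

    import numpy as np
    from imitation.data.types import DictObs

    a = DictObs({"x": np.zeros((2, 3))})
    b = DictObs({"x": np.ones((2, 3))})
    # stack adds a new leading axis; concatenate joins along axis 0.
    assert DictObs.stack([a, b]).shape == {"x": (2, 2, 3)}
    assert DictObs.concatenate([a, b]).shape == {"x": (4, 3)}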
@@ -241,9 +240,8 @@ def __eq__(self, other) -> bool: if not isinstance(other, Trajectory): return False - dict_self, dict_other = dataclass_quick_asdict(self), dataclass_quick_asdict( - other - ) + dict_self = dataclass_quick_asdict(self) + dict_other = dataclass_quick_asdict(other) # Trajectory objects may still have different keys if different subclasses if dict_self.keys() != dict_other.keys(): return False diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 60db89f50..7bebd8e2a 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -1,7 +1,7 @@ """Custom policy classes and convenience methods.""" import abc -from typing import Type +from typing import Dict, Type, Union import gym import numpy as np @@ -10,6 +10,7 @@ from stable_baselines3.sac import policies as sac_policies from torch import nn +from imitation.data import types from imitation.util import networks @@ -23,18 +24,33 @@ def __init__(self, observation_space: gym.Space, action_space: gym.Space): action_space=action_space, ) - def _predict(self, obs: th.Tensor, deterministic: bool = False): + # TODO: support + def _predict( + self, + obs: Union[th.Tensor, Dict[str, th.Tensor]], + deterministic: bool = False, + ): np_actions = [] - np_obs = obs.detach().cpu().numpy() + if isinstance(obs, dict): + np_obs = types.DictObs( + {k: v.detach().cpu().numpy() for k, v in obs.items()}, + ) + else: + np_obs = obs.detach().cpu().numpy() for np_ob in np_obs: - assert self.observation_space.contains(np_ob) - np_actions.append(self._choose_action(np_ob)) + np_ob_unwrapped = types.DictObs.maybe_unwrap(np_ob) + # print(np_ob_unwrapped, self.observation_space) + assert self.observation_space.contains(np_ob_unwrapped) + np_actions.append(self._choose_action(np_ob_unwrapped)) np_actions = np.stack(np_actions, axis=0) th_actions = th.as_tensor(np_actions, device=self.device) return th_actions @abc.abstractmethod - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: """Chooses an action, optionally based on observation obs.""" def forward(self, *args): @@ -46,14 +62,20 @@ def forward(self, *args): class RandomPolicy(NonTrainablePolicy): """Returns random actions.""" - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: return self.action_space.sample() class ZeroPolicy(NonTrainablePolicy): """Returns constant zero action.""" - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: return np.zeros(self.action_space.shape, dtype=self.action_space.dtype) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 51b70851c..64655b0e9 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -371,3 +371,39 @@ def test_rollout_normal_error_for_other_shape_mismatch(rng): rollout.make_sample_until(min_timesteps=None, min_episodes=2), rng=rng, ) + + +class DictObsWrapper(gym.ObservationWrapper): + """Simple wrapper that turns the observation into a dictionary.""" + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + self.observation_space = gym.spaces.Dict( + {"a": env.observation_space, "b": env.observation_space}, + ) + + def observation(self, observation): + return {"a": observation, "b": observation / 2} + + +def 
test_dictionary_observations(rng): + """Test we can generate a rollout for a dict-type observation environment. + + Args: + rng: Random state to use (with fixed seed). + """ + env = gym.make("CartPole-v1") + env = monitor.Monitor(env, None) + env = DictObsWrapper(env) + venv = vec_env.DummyVecEnv([lambda: env]) + + policy = serialize.load_policy("zero", venv) + trajs = rollout.generate_trajectories( + policy, + venv, + rollout.make_min_episodes(10), + rng=rng, + ) + for traj in trajs: + assert isinstance(traj.obs, types.DictObs) + np.testing.assert_allclose(traj.obs.d["a"] / 2, traj.obs.d["b"]) From fb9498bca95d93986e090015e4be0f0ec9dfbb1b Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:32:27 -0700 Subject: [PATCH 05/85] add bc.py test --- tests/algorithms/test_bc.py | 59 +++++++++++++++++++++++++++++++++++++ tests/data/test_buffer.py | 6 ++++ 2 files changed, 65 insertions(+) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 71a58b105..32556d483 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -4,10 +4,13 @@ import os from typing import Any, Callable, Optional, Sequence +import gym import hypothesis import hypothesis.strategies as st import numpy as np import pytest +import stable_baselines3.common.envs as sb_envs +import stable_baselines3.common.policies as sb_policies import torch as th from stable_baselines3.common import evaluation, vec_env @@ -371,3 +374,59 @@ def inc_batch_cnt(): # THEN assert batch_cnt == no_yield_after_iter + + +class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + +# TODO: make test nicer +def test_dict_space(): + # TODO: is sb_envs okay? 
+ def make_env(): + env = sb_envs.SimpleMultiObsEnv(channel_last=False) + return RolloutInfoWrapper(FloatReward(env)) + + env = vec_env.DummyVecEnv([make_env, make_env]) + env.observation_space["img"], env.observation_space["vec"] + env.observation_space.shape + + policy = sb_policies.MultiInputActorCriticPolicy( + env.observation_space, + env.action_space, + lambda: 0.001, + ) + rng = np.random.default_rng() + + def sample_expert_transitions(): + print("Sampling expert transitions.") + rollouts = rollout.rollout( + policy=None, + venv=env, + sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + unwrap=False, # TODO have rollout unwrap wrapper support dict + ) + return rollout.flatten_trajectories(rollouts) + + transitions = sample_expert_transitions() + + bc_trainer = bc.BC( + observation_space=env.observation_space, + policy=policy, + action_space=env.action_space, + rng=rng, + demonstrations=transitions, + ) + + bc_trainer.train(n_epochs=1) + + reward, _ = evaluation.evaluate_policy( + bc_trainer.policy, # type: ignore[arg-type] + env, + n_eval_episodes=3, + render=False, # comment out to speed up + ) diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index 205a43717..858fb5229 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -17,6 +17,7 @@ def _fill_chunk(start, chunk_len, sample_shape, dtype=float): def _get_fill_from_chunk(chunk): chunk_len, *sample_shape = chunk.shape sample_size = max(1, np.prod(sample_shape)) + types.assert_not_dictobs(chunk) return chunk.flatten()[::sample_size] @@ -134,6 +135,11 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): sample = buf.sample(100) info_vals = np.array([info["a"] for info in sample.infos]) + # dictobs not supported for buffers, or by current code in + # this test file (eg `_get_fill_from_chunk`) + types.assert_not_dictobs(sample.obs) + types.assert_not_dictobs(sample.next_obs) + assert sample.obs.shape == sample.next_obs.shape == (100,) + obs_shape assert sample.acts.shape == (100,) + act_shape assert sample.dones.shape == (100,) From e54c36ce0dbbaf9b33075c6c59720398cacdd98f Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:05:58 -0700 Subject: [PATCH 06/85] cleanup --- src/imitation/algorithms/bc.py | 19 ++++++++++++---- src/imitation/data/rollout.py | 2 +- src/imitation/data/types.py | 41 ++++++++++++++++++++++------------ src/imitation/policies/base.py | 2 -- src/imitation/util/util.py | 2 -- tests/data/test_buffer.py | 5 ++--- 6 files changed, 45 insertions(+), 26 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 08e2611b2..df1c5d14c 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -101,7 +101,12 @@ class BehaviorCloningLossCalculator: def __call__( self, policy: policies.ActorCriticPolicy, - obs: Union[th.Tensor, np.ndarray, types.DictObs], + obs: Union[ + th.Tensor, + np.ndarray, + types.DictObs, + dict[str, Union[np.ndarray, th.Tensor]], + ], acts: Union[th.Tensor, np.ndarray], ) -> BCTrainingMetrics: """Calculate the supervised learning loss used to train the behavioral clone. 
@@ -118,15 +123,20 @@ def __call__( tensor_obs: Union[th.Tensor, Dict[str, th.Tensor]] if isinstance(obs, types.DictObs): tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap().items()} + elif isinstance(obs, dict): + tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.items()} else: tensor_obs = util.safe_to_tensor(obs) acts = util.safe_to_tensor(acts) - # TODO: add check obs is proper type? + # policy.evaluate_actions's type signature seems wrong to me. # it declares it only takes a tensor but it calls # extract_features which is happy with Dict[str, tensor]. # In reality the required type of obs depends on the feature extractor. - _, log_prob, entropy = policy.evaluate_actions(tensor_obs, acts) # type: ignore + (_, log_prob, entropy) = policy.evaluate_actions( + tensor_obs, # type: ignore[arg-type] + acts, + ) prob_true_act = th.exp(log_prob).mean() log_prob = log_prob.mean() entropy = entropy.mean() if entropy is not None else None @@ -485,7 +495,8 @@ def process_batch(): } else: obs_tensor = util.safe_to_tensor( - batch["obs"], device=self.policy.device + batch["obs"], + device=self.policy.device, ) acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index d30efc0cd..2ff1bc30a 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -456,7 +456,7 @@ def generate_trajectories( obs, rews, dones, infos = venv.step(acts) assert isinstance( obs, - (np.ndarray, types.DictObs), + (np.ndarray, dict), ), "Tuple observations are not supported." wrapped_obs = types.DictObs.maybe_wrap(obs) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 35c34d875..fbda0b372 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -26,9 +26,6 @@ T = TypeVar("T") -TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor) - - AnyPath = Union[str, bytes, os.PathLike] @@ -45,7 +42,7 @@ class DictObs: lists of dictobs. """ - d: dict[str, np.ndarray] + d: Dict[str, np.ndarray] @classmethod def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]): @@ -68,8 +65,11 @@ def __len__(self): Use `dict_len` to get the number of entries in the dictionary. - raises: + Raises: ValueError: if the arrays have different lengths or there are no arrays. + + Returns: + The length (first dimension) of the constiuent arrays """ lens = set(len(v) for v in self.d.values()) if len(lens) == 1: @@ -88,9 +88,17 @@ def dict_len(self): def __getitem__(self, key: Union[slice, int]) -> "DictObs": """Indexes or slices into the first element of every array. - Note that it will still return singleton values as np.arrays, not scalars. + Note that it will still return singleton values as np.arrays, not scalars, + to be consistent with DictObs type signature. + Also note that we don't support multi-dimensional slicing. + + Args: + key: a single slice + + Returns: + A new DictObj object with each array indexed. """ - # asarray to handle case where we slice to a single array element. + # asarray handles case where we slice to a single array element. 
return self.__class__({k: np.asarray(v[key]) for k, v in self.d.items()}) def __iter__(self) -> Iterator["DictObs"]: @@ -105,22 +113,22 @@ def __eq__(self, other): return all(np.array_equal(self.d[k], other.d[k]) for k in self.d.keys()) @property - def shape(self) -> dict[str, tuple[int, ...]]: + def shape(self) -> Dict[str, tuple[int, ...]]: """Returns a dictionary with shape-tuples in place of the arrays.""" return {k: v.shape for k, v in self.d.items()} @property - def dtype(self) -> dict[str, np.dtype]: + def dtype(self) -> Dict[str, np.dtype]: """Returns a dictionary with shape-tuples in place of the arrays.""" return {k: v.dtype for k, v in self.d.items()} - def unwrap(self) -> dict[str, np.ndarray]: + def unwrap(self) -> Dict[str, np.ndarray]: return self.d @classmethod def maybe_wrap( cls, - obs: Union[dict[str, np.ndarray], np.ndarray, "DictObs"], + obs: Union[Dict[str, np.ndarray], np.ndarray, "DictObs"], ) -> Union["DictObs", np.ndarray]: """Converts an observation into a DictObs, if necessary.""" if isinstance(obs, dict): @@ -129,9 +137,11 @@ def maybe_wrap( assert isinstance(obs, (np.ndarray, cls)) return obs + TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor) + @overload @classmethod - def maybe_unwrap(cls, maybe_dictobs: "DictObs") -> dict[str, np.ndarray]: + def maybe_unwrap(cls, maybe_dictobs: "DictObs") -> Dict[str, np.ndarray]: ... @overload @@ -151,9 +161,9 @@ def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": return self.__class__({k: fn(v) for k, v in self.d.items()}) @staticmethod - def _unravel(dictobs_list: Iterable["DictObs"]) -> dict[str, list[np.ndarray]]: + def _unravel(dictobs_list: Iterable["DictObs"]) -> Dict[str, List[np.ndarray]]: """Converts a list of DictObs into a dictionary of lists of arrays.""" - unraveled: dict[str, list[np.ndarray]] = collections.defaultdict(list) + unraveled: Dict[str, List[np.ndarray]] = collections.defaultdict(list) for do in dictobs_list: for k, array in do.d.items(): unraveled[k].append(array) @@ -197,6 +207,9 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]: undocumentedly deep-copies every numpy array value. See https://stackoverflow.com/a/52229565/1091722. + This is also used to preserve DictObj objects, as `dataclasses.asdict` + unwraps them recursively. + Args: obj: A dataclass instance. diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 7bebd8e2a..7d9c1c741 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -24,7 +24,6 @@ def __init__(self, observation_space: gym.Space, action_space: gym.Space): action_space=action_space, ) - # TODO: support def _predict( self, obs: Union[th.Tensor, Dict[str, th.Tensor]], @@ -39,7 +38,6 @@ def _predict( np_obs = obs.detach().cpu().numpy() for np_ob in np_obs: np_ob_unwrapped = types.DictObs.maybe_unwrap(np_ob) - # print(np_ob_unwrapped, self.observation_space) assert self.observation_space.contains(np_ob_unwrapped) np_actions.append(self._choose_action(np_ob_unwrapped)) np_actions = np.stack(np_actions, axis=0) diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index 443333e81..cba83f504 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -240,8 +240,6 @@ def safe_to_tensor( you just use `th.as_tensor` for this, an ugly warning is logged and there's undefined behavior if you try to write to the tensor. - `array` can also be a dictionary, in which case all values will be converted. - Args: array: The array to convert to a PyTorch tensor. 
kwargs: Additional keyword arguments to pass to `th.as_tensor`. diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index 858fb5229..5acb6b344 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -17,7 +17,6 @@ def _fill_chunk(start, chunk_len, sample_shape, dtype=float): def _get_fill_from_chunk(chunk): chunk_len, *sample_shape = chunk.shape sample_size = max(1, np.prod(sample_shape)) - types.assert_not_dictobs(chunk) return chunk.flatten()[::sample_size] @@ -137,8 +136,8 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): # dictobs not supported for buffers, or by current code in # this test file (eg `_get_fill_from_chunk`) - types.assert_not_dictobs(sample.obs) - types.assert_not_dictobs(sample.next_obs) + sample.obs = types.assert_not_dictobs(sample.obs) + sample.next_obs = types.assert_not_dictobs(sample.next_obs) assert sample.obs.shape == sample.next_obs.shape == (100,) + obs_shape assert sample.acts.shape == (100,) + act_shape From ee043839d6ae42ef3ffd73d9eeadfdbd2e51b531 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:28:26 -0700 Subject: [PATCH 07/85] small fixes --- src/imitation/algorithms/bc.py | 2 +- src/imitation/data/rollout.py | 2 +- src/imitation/data/types.py | 2 +- tests/algorithms/test_bc.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index df1c5d14c..2ac516da4 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -105,7 +105,7 @@ def __call__( th.Tensor, np.ndarray, types.DictObs, - dict[str, Union[np.ndarray, th.Tensor]], + Dict[str, Union[np.ndarray, th.Tensor]], ], acts: Union[th.Tensor, np.ndarray], ) -> BCTrainingMetrics: diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 2ff1bc30a..c630394a5 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -123,7 +123,7 @@ def finish_trajectory( def add_steps_and_auto_finish( self, acts: np.ndarray, - obs: Union[np.ndarray, dict[str, np.ndarray], types.DictObs], + obs: Union[np.ndarray, Dict[str, np.ndarray], types.DictObs], rews: np.ndarray, dones: np.ndarray, infos: List[dict], diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index fbda0b372..5bf7f0047 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -113,7 +113,7 @@ def __eq__(self, other): return all(np.array_equal(self.d[k], other.d[k]) for k in self.d.keys()) @property - def shape(self) -> Dict[str, tuple[int, ...]]: + def shape(self) -> Dict[str, Tuple[int, ...]]: """Returns a dictionary with shape-tuples in place of the arrays.""" return {k: v.shape for k, v in self.d.items()} diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 32556d483..a166aebbd 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -397,7 +397,7 @@ def make_env(): policy = sb_policies.MultiInputActorCriticPolicy( env.observation_space, env.action_space, - lambda: 0.001, + lambda _: 0.001, ) rng = np.random.default_rng() From 6e2218ab1fa5081bda73943b02cdc8aa42025c94 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:36:27 -0700 Subject: [PATCH 08/85] small fixes --- src/imitation/data/types.py | 3 ++- tests/algorithms/test_bc.py | 7 ++++--- tests/data/test_buffer.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 
13 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 5bf7f0047..4ac2cb09e 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -154,7 +154,8 @@ def maybe_unwrap(cls, maybe_dictobs): if isinstance(maybe_dictobs, cls): return maybe_dictobs.unwrap() else: - assert isinstance(maybe_dictobs, (np.ndarray, th.Tensor)) + if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor)): + warnings.warn(f"trying to unwrap object of type {type(maybe_dictobs)}") return maybe_dictobs def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index a166aebbd..2423387e4 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -9,10 +9,11 @@ import hypothesis.strategies as st import numpy as np import pytest -import stable_baselines3.common.envs as sb_envs -import stable_baselines3.common.policies as sb_policies import torch as th -from stable_baselines3.common import evaluation, vec_env +from stable_baselines3.common import envs as sb_envs +from stable_baselines3.common import evaluation +from stable_baselines3.common import policies as sb_policies +from stable_baselines3.common import vec_env from imitation.algorithms import bc from imitation.data import rollout, types diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index 5acb6b344..9b53349a1 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -136,31 +136,31 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): # dictobs not supported for buffers, or by current code in # this test file (eg `_get_fill_from_chunk`) - sample.obs = types.assert_not_dictobs(sample.obs) - sample.next_obs = types.assert_not_dictobs(sample.next_obs) + obs = types.assert_not_dictobs(sample.obs) + next_obs = types.assert_not_dictobs(sample.next_obs) - assert sample.obs.shape == sample.next_obs.shape == (100,) + obs_shape + assert obs.shape == next_obs.shape == (100,) + obs_shape assert sample.acts.shape == (100,) + act_shape assert sample.dones.shape == (100,) assert info_vals.shape == (100,) # Are samples right data type? - assert sample.obs.dtype == dtype + assert obs.dtype == dtype assert sample.acts.dtype == dtype - assert sample.next_obs.dtype == dtype + assert next_obs.dtype == dtype assert info_vals.dtype == dtype assert sample.dones.dtype == bool assert sample.infos.dtype == object # Are samples in range? - _check_bound(i + chunk_len, capacity, sample.obs) - _check_bound(i + chunk_len, capacity, sample.next_obs, 3 * capacity) + _check_bound(i + chunk_len, capacity, obs) + _check_bound(i + chunk_len, capacity, next_obs, 3 * capacity) _check_bound(i + chunk_len, capacity, sample.acts, 6 * capacity) _check_bound(i + chunk_len, capacity, info_vals, 9 * capacity) # Are samples in-order? 
- obs_fill = _get_fill_from_chunk(sample.obs) - next_obs_fill = _get_fill_from_chunk(sample.next_obs) + obs_fill = _get_fill_from_chunk(obs) + next_obs_fill = _get_fill_from_chunk(next_obs) act_fill = _get_fill_from_chunk(sample.acts) info_vals_fill = _get_fill_from_chunk(info_vals) From 68fe666c9cbdf508542f568e990f320f855aabf1 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:42:45 -0700 Subject: [PATCH 09/85] fix type error in interactive.py --- src/imitation/policies/interactive.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py index 64be29b0f..f0c8b1210 100644 --- a/src/imitation/policies/interactive.py +++ b/src/imitation/policies/interactive.py @@ -2,7 +2,7 @@ import abc import collections -from typing import Optional, Union +from typing import Dict, Optional, Union import gym import matplotlib.pyplot as plt @@ -56,10 +56,16 @@ def __init__( } self.clear_screen_on_query = clear_screen_on_query - def _choose_action(self, obs: np.ndarray) -> np.ndarray: + def _choose_action( + self, + obs: Union[np.ndarray, Dict[str, np.ndarray]], + ) -> np.ndarray: if self.clear_screen_on_query: util.clear_screen() + if isinstance(obs, dict): + raise ValueError("Dictionary observations are not supported here") + context = self._render(obs) key = self._get_input_key() self._clean_up(context) From 9ad2aaf7b9ee9f8719754d250ccc1a3920a1151b Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:51:29 -0700 Subject: [PATCH 10/85] fix introduced error in mce_irl.py --- src/imitation/algorithms/mce_irl.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index fba45e0d3..0667928ba 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -347,11 +347,7 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None num_demos = 0 for traj in trajs: cum_discount = 1.0 - if isinstance(traj.obs, types.DictObs): - raise ValueError( - "Dictionary observations are not currently supported for mce_irl" - ) - for obs in traj.obs: + for obs in types.assert_not_dictobs(traj.obs): self.demo_state_om[obs] += cum_discount cum_discount *= self.discount num_demos += 1 @@ -421,7 +417,9 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: ) elif isinstance(demonstrations, types.TransitionsMinimal): self._set_demo_from_obs( - types.assert_not_dictobs(demonstrations.obs), None, None + types.assert_not_dictobs(demonstrations.obs), + None, + None, ) elif isinstance(demonstrations, Iterable): # Demonstrations are a Torch DataLoader or other Mapping iterable @@ -433,11 +431,7 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: assert isinstance(batch, Mapping) for k in ("obs", "dones", "next_obs"): if k in batch: - if isinstance(batch[k], types.DictObs): - raise ValueError( - "Dictionary observations are not currently supported for buffers" - ) - collated_list[k].append() + collated_list[k].append(batch[k]) collated = {k: np.concatenate(v) for k, v in collated_list.items()} assert "obs" in collated From 67341d5c6876a14719b3eb8fc4bc139c9e42f027 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 12:04:01 -0700 Subject: 
[PATCH 11/85] fix minor ci complaint --- src/imitation/algorithms/mce_irl.py | 23 +++++++++++++++++++++-- src/imitation/data/buffer.py | 12 ++++-------- tests/data/test_types.py | 8 +++++--- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index 0667928ba..19cf40bb3 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -7,7 +7,18 @@ """ import collections import warnings -from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import gym import numpy as np @@ -426,7 +437,15 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: # Collect them together into one big NumPy array. This is inefficient, # we could compute the running statistics instead, but in practice do # not expect large dataset sizes together with MCE IRL. - collated_list = collections.defaultdict(list) + collated_list: Dict[ + str, + List[ + Union[ + np.ndarray, + th.Tensor, + ] + ], + ] = collections.defaultdict(list) for batch in demonstrations: assert isinstance(batch, Mapping) for k in ("obs", "dones", "next_obs"): diff --git a/src/imitation/data/buffer.py b/src/imitation/data/buffer.py index c6a3783dc..2a6939408 100644 --- a/src/imitation/data/buffer.py +++ b/src/imitation/data/buffer.py @@ -345,20 +345,16 @@ def from_data( Returns: A new ReplayBuffer. """ - if isinstance(transitions.obs, types.DictObs): - raise ValueError( - "Dictionary observations are not currently supported for buffers" - ) - - obs_shape = transitions.obs.shape[1:] + obs = types.assert_not_dictobs(transitions.obs) + obs_shape = obs.shape[1:] act_shape = transitions.acts.shape[1:] if capacity is None: - capacity = transitions.obs.shape[0] + capacity = obs.shape[0] instance = cls( capacity=capacity, obs_shape=obs_shape, act_shape=act_shape, - obs_dtype=transitions.obs.dtype, + obs_dtype=obs.dtype, act_dtype=transitions.acts.dtype, ) instance.store(transitions, truncate_ok=truncate_ok) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index e44f795a9..cbc431471 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -23,7 +23,7 @@ gym.spaces.Box(-np.inf, np.inf, shape=(2,)), ] DICT_SPACE = gym.spaces.Dict( - {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))} + {"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))}, ) OBS_SPACES = SPACES + [DICT_SPACE] @@ -62,7 +62,8 @@ def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: """Like `trajectory` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(len(trajectory)) return types.TrajectoryWithRew( - **types.dataclass_quick_asdict(trajectory), rews=rews + **types.dataclass_quick_asdict(trajectory), + rews=rews, ) @@ -102,7 +103,8 @@ def transitions_rew( """Like `transitions` but with reward randomly sampled from a Gaussian.""" rews = np.random.randn(length) return types.TransitionsWithRew( - **types.dataclass_quick_asdict(transitions), rews=rews + **types.dataclass_quick_asdict(transitions), + rews=rews, ) From c497b561782b180b87b4e51733896435c677a8d5 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 14:41:42 -0700 Subject: [PATCH 12/85] add basic dictobs tests --- src/imitation/data/types.py | 17 +++++++++--- 
tests/data/test_types.py | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 4ac2cb09e..3157a9077 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -85,12 +85,14 @@ def __len__(self): def dict_len(self): return len(self.d) - def __getitem__(self, key: Union[slice, int]) -> "DictObs": - """Indexes or slices into the first element of every array. + def __getitem__( + self, + key: Union[int, slice, Tuple[Union[int, slice], ...]], + ) -> "DictObs": + """Indexes or slices each array. Note that it will still return singleton values as np.arrays, not scalars, to be consistent with DictObs type signature. - Also note that we don't support multi-dimensional slicing. Args: key: a single slice @@ -102,7 +104,14 @@ def __getitem__(self, key: Union[slice, int]) -> "DictObs": return self.__class__({k: np.asarray(v[key]) for k, v in self.d.items()}) def __iter__(self) -> Iterator["DictObs"]: - """Iterates over the first dimension of each array.""" + """Iterates over the first dimension of each array. + + Raises: + ValueError if len() is not defined. + + Returns: + Iterator of dictobjs by first dimension. + """ return (self[i] for i in range(len(self))) def __eq__(self, other): diff --git a/tests/data/test_types.py b/tests/data/test_types.py index cbc431471..af3cdb7bb 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -459,3 +459,58 @@ def test_parse_path(): # Parse optional path. Works the same way but passes None down the line. assert util.parse_optional_path(None) is None assert util.parse_optional_path("/foo/bar") == util.parse_path("/foo/bar") + + +def test_dict_obs(): + A = np.random.rand(3, 4) + B = np.random.rand(3, 7, 1) + C = np.random.rand(4) + + ab = types.DictObs({"a": A, "b": B}) + abc = types.DictObs({"a": A, "b": B, "c": C}) + + # len + assert len(ab) == 3 + with pytest.raises(ValueError): + len(abc) + with pytest.raises(ValueError): + len(types.DictObs({})) + + # slicing + np.testing.assert_equal(abc[0].d["a"], A[0]) + np.testing.assert_equal(abc[0].d["c"], np.array(C[0])) + np.testing.assert_equal(abc[0:2].d["a"], np.array(A[0:2])) + np.testing.assert_equal(ab[:, 0].d["a"], np.array(A[:, 0])) + with pytest.raises(IndexError): + abc[:, 0] + + # iter + for i, a_row in enumerate(A): + np.testing.assert_equal(a_row, ab[i].d["a"]) + assert ab[0] == next(iter(ab)) + + # eq + assert abc == types.DictObs({"a": A, "b": B, "c": C}) + assert abc != types.DictObs({"a": A, "c": B, "b": C}) # diff keys + assert abc != types.DictObs({"a": A, "b": B + 1, "c": C}) # diff values + + # shape / dtype + assert abc.shape == {"a": A.shape, "b": B.shape, "c": C.shape} + assert abc.dtype == {"a": A.dtype, "b": B.dtype, "c": C.dtype} + + # wrap + assert types.DictObs.maybe_wrap({"a": A, "b": B, "c": C}) == abc + assert abc.unwrap() == {"a": A, "b": B, "c": C} + + # map, stack, concat + assert abc.map_arrays(lambda arr: arr + 1) == types.DictObs( + {"a": A + 1, "b": B + 1, "c": C + 1}, + ) + assert types.DictObs.stack(list(iter(ab))) == ab + np.testing.assert_equal( + types.DictObs.concatenate([abc, abc]).d["a"], + np.concatenate([A, A]), + ) + + with pytest.raises(ValueError): + types.assert_not_dictobs(abc) From d3f79bf29de401605eb160451e146be6206d479f Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:40:14 -0700 Subject: [PATCH 13/85] change default bc policy for dict obs 
space --- src/imitation/algorithms/bc.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 2ac516da4..9e32958bf 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -23,7 +23,7 @@ import numpy as np import torch as th import tqdm -from stable_baselines3.common import policies, utils, vec_env +from stable_baselines3.common import policies, torch_layers, utils, vec_env from imitation.algorithms import base as algo_base from imitation.data import rollout, types @@ -345,13 +345,18 @@ def __init__( self.rng = rng if policy is None: - # TODO: maybe default to comb. dict when dict obs space? + extractor = ( + torch_layers.CombinedExtractor + if isinstance(observation_space, gym.spaces.Dict) + else torch_layers.FlattenExtractor + ) policy = policy_base.FeedForward32Policy( observation_space=observation_space, action_space=action_space, # Set lr_schedule to max value to force error if policy.optimizer # is used by mistake (should use self.optimizer instead). lr_schedule=lambda _: th.finfo(th.float32).max, + features_extractor_class=extractor, ) self._policy = policy.to(utils.get_device(device)) # TODO(adam): make policy mandatory and delete observation/action space params? From 2de9e493da7db31b60cec8e1f747e2c22ed55250 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:42:42 -0700 Subject: [PATCH 14/85] refine rollout.py typechecks, comments --- src/imitation/data/rollout.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index c630394a5..6cc9c8489 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -109,11 +109,10 @@ def finish_trajectory( for k, array in part_dict.items(): out_dict_unstacked[k].append(array) - # TODO: what about infos? Does this actually handle them well? traj = types.TrajectoryWithRew( obs=types.stack_maybe_dictobs(out_dict_unstacked["obs"]), acts=np.stack(out_dict_unstacked["acts"], axis=0), - infos=np.stack(out_dict_unstacked["infos"], axis=0), # TODO: confused + infos=np.stack(out_dict_unstacked["infos"], axis=0), # array of dict objs rews=np.stack(out_dict_unstacked["rews"], axis=0), terminal=terminal, ) @@ -430,9 +429,6 @@ def generate_trajectories( # need to wrap here to iterate over envs properly wrapped_obs = types.DictObs.maybe_wrap(obs) - # TODO: make this nicer, it's currently non-mypy compliant - # probably want helper - for env_idx, ob in enumerate(wrapped_obs): # Seed with first obs only. Inside loop, we'll only add second obs from # each (s,a,r,s') tuple, under the same "obs" key again. That way we still @@ -572,13 +568,22 @@ def flatten_trajectories( Returns: The trajectories flattened into a single batch of Transitions. """ + all_of_type = lambda key, t: all( + isinstance(getattr(traj, key), t) for traj in trajectories + ) + assert all_of_type("obs", types.DictObs) or all_of_type("obs", np.ndarray) + assert all_of_type("acts", np.ndarray) + assert all_of_type("dones", np.ndarray) + + # sad to use Any here, but mypy struggles otherwise. + # we enforce type constraints in asserts above and below. 
keys = ["obs", "next_obs", "acts", "dones", "infos"] - # TODO: sad to use Any here parts: Mapping[str, List[Any]] = {key: [] for key in keys} for traj in trajectories: parts["acts"].append(traj.acts) obs = traj.obs + parts["obs"].append(obs[:-1]) parts["next_obs"].append(obs[1:]) @@ -599,13 +604,6 @@ def flatten_trajectories( lengths = set(map(len, cat_parts.values())) assert len(lengths) == 1, f"expected one length, got {lengths}" return types.Transitions(**cat_parts) - # TODO: clean - # cat_parts["obs"], - # types.assert_not_dictobs(cat_parts["acts"]), - # types.assert_not_dictobs(cat_parts["infos"]), - # cat_parts["next_obs"], - # types.assert_not_dictobs(cat_parts["done"]), - # ) def flatten_trajectories_with_rew( From c47cca64736d6e49c8bf0419af5b055f667b0e0b Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:53:41 -0700 Subject: [PATCH 15/85] check rollout produces dictobs of correct shape --- src/imitation/data/rollout.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 6cc9c8489..d11d2cd88 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -489,10 +489,14 @@ def generate_trajectories( for trajectory in trajectories: n_steps = len(trajectory.acts) # extra 1 for the end - if not isinstance(venv.observation_space, spaces.dict.Dict): + if isinstance(venv.observation_space, spaces.dict.Dict): + exp_obs = { + k: (n_steps + 1,) + v.shape for k, v in venv.observation_space.items() + } + else: exp_obs = (n_steps + 1,) + venv.observation_space.shape - real_obs = types.assert_not_dictobs(trajectory.obs).shape - assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" + real_obs = trajectory.obs.shape + assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" exp_act = (n_steps,) + venv.action_space.shape real_act = trajectory.acts.shape assert real_act == exp_act, f"expected shape {exp_act}, got {real_act}" @@ -568,9 +572,12 @@ def flatten_trajectories( Returns: The trajectories flattened into a single batch of Transitions. 
""" - all_of_type = lambda key, t: all( - isinstance(getattr(traj, key), t) for traj in trajectories - ) + + def all_of_type(key, desired_type): + return all( + isinstance(getattr(traj, key), desired_type) for traj in trajectories + ) + assert all_of_type("obs", types.DictObs) or all_of_type("obs", np.ndarray) assert all_of_type("acts", np.ndarray) assert all_of_type("dones", np.ndarray) From 276294ba8fd670a150d0e0f6b1cf07261e0eaeb2 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:33:14 -0700 Subject: [PATCH 16/85] cleanup types and dictobs helpers --- src/imitation/algorithms/bc.py | 22 ++--- src/imitation/algorithms/mce_irl.py | 7 +- src/imitation/data/rollout.py | 3 +- src/imitation/data/types.py | 128 +++++++++++++++------------- src/imitation/policies/base.py | 2 +- 5 files changed, 80 insertions(+), 82 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 9e32958bf..ac7c5849b 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -102,10 +102,10 @@ def __call__( self, policy: policies.ActorCriticPolicy, obs: Union[ - th.Tensor, - np.ndarray, + types.AnyTensor, types.DictObs, - Dict[str, Union[np.ndarray, th.Tensor]], + Dict[str, np.ndarray], + Dict[str, th.Tensor], ], acts: Union[th.Tensor, np.ndarray], ) -> BCTrainingMetrics: @@ -492,17 +492,11 @@ def process_batch(): num_samples_so_far, ), batch in batches_with_stats: obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]] - if isinstance(batch["obs"], types.DictObs): - obs_dict = batch["obs"].unwrap() - obs_tensor = { - k: util.safe_to_tensor(v, device=self.policy.device) - for k, v in obs_dict.items() - } - else: - obs_tensor = util.safe_to_tensor( - batch["obs"], - device=self.policy.device, - ) + # unwraps the observation if it's a dictobs and converts arrays to tensors + obs_tensor = types.map_maybe_dict( + lambda x: util.safe_to_tensor(x, device=self.policy.device), + types.maybe_unwrap_dictobs(batch["obs"]), + ) acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index 19cf40bb3..038116044 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -439,12 +439,7 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: # not expect large dataset sizes together with MCE IRL. 
collated_list: Dict[ str, - List[ - Union[ - np.ndarray, - th.Tensor, - ] - ], + List[types.AnyTensor], ] = collections.defaultdict(list) for batch in demonstrations: assert isinstance(batch, Mapping) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index d11d2cd88..7d147f209 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -321,7 +321,7 @@ def get_actions( # pytype doesn't seem to understand that policy is a BaseAlgorithm # or BasePolicy here, rather than a Callable (acts, states) = policy.predict( # pytype: disable=attribute-error - types.DictObs.maybe_unwrap(observations), + types.maybe_unwrap_dictobs(observations), state=states, episode_start=episode_starts, deterministic=deterministic_policy, @@ -580,7 +580,6 @@ def all_of_type(key, desired_type): assert all_of_type("obs", types.DictObs) or all_of_type("obs", np.ndarray) assert all_of_type("acts", np.ndarray) - assert all_of_type("dones", np.ndarray) # sad to use Any here, but mypy struggles otherwise. # we enforce type constraints in asserts above and below. diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 3157a9077..652673795 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -27,6 +27,8 @@ T = TypeVar("T") AnyPath = Union[str, bytes, os.PathLike] +AnyTensor = Union[np.ndarray, th.Tensor] +TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor) @dataclasses.dataclass(frozen=True) @@ -146,27 +148,6 @@ def maybe_wrap( assert isinstance(obs, (np.ndarray, cls)) return obs - TensorVar = TypeVar("TensorVar", np.ndarray, th.Tensor) - - @overload - @classmethod - def maybe_unwrap(cls, maybe_dictobs: "DictObs") -> Dict[str, np.ndarray]: - ... - - @overload - @classmethod - def maybe_unwrap(cls, maybe_dictobs: TensorVar) -> TensorVar: - ... - - @classmethod - def maybe_unwrap(cls, maybe_dictobs): - if isinstance(maybe_dictobs, cls): - return maybe_dictobs.unwrap() - else: - if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor)): - warnings.warn(f"trying to unwrap object of type {type(maybe_dictobs)}") - return maybe_dictobs - def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": return self.__class__({k: fn(v) for k, v in self.d.items()}) @@ -200,14 +181,65 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": # TODO: add keys, values, items? -def assert_not_dictobs(x: Union[np.ndarray, DictObs]) -> np.ndarray: +# DicObs utilities + + +Observation = Union[np.ndarray, DictObs] +ObsVar = TypeVar("ObsVar", np.ndarray, DictObs) + + +def assert_not_dictobs(x: Observation) -> np.ndarray: if isinstance(x, DictObs): raise ValueError("Dictionary observations are not supported here.") return x +def concatenate_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.concatenate(arrs) + else: + return np.concatenate(arrs) + + +def stack_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + assert len(arrs) > 0 + if isinstance(arrs[0], DictObs): + return DictObs.stack(arrs) + else: + return np.stack(arrs) + + +@overload +def maybe_unwrap_dictobs(maybe_dictobs: DictObs) -> Dict[str, np.ndarray]: + ... + + +@overload +def maybe_unwrap_dictobs(maybe_dictobs: TensorVar) -> TensorVar: + ... 
+ + +def maybe_unwrap_dictobs(maybe_dictobs): + """Unwraps if a DictObs, otherwise returns the object.""" + if isinstance(maybe_dictobs, DictObs): + return maybe_dictobs.unwrap() + else: + if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor)): + warnings.warn(f"trying to unwrap object of type {type(maybe_dictobs)}") + return maybe_dictobs + + +def map_maybe_dict(fn, maybe_dict): + """Applies fn to all values a dictionary, or to the value itself if not a dict.""" + if isinstance(maybe_dict, dict): + return {k: fn(v) for k, v in maybe_dict.items()} + else: + return fn(maybe_dict) + + # TODO: maybe should support DictObs? -TransitionMapping = Mapping[str, Union[np.ndarray, th.Tensor]] +TransitionMapping = Mapping[str, AnyTensor] def dataclass_quick_asdict(obj) -> Dict[str, Any]: @@ -234,7 +266,7 @@ def dataclass_quick_asdict(obj) -> Dict[str, Any]: class Trajectory: """A trajectory, e.g. a one episode rollout from an expert policy.""" - obs: Union[np.ndarray, DictObs] + obs: Observation """Observations, shape (trajectory_len + 1, ) + observation_shape.""" acts: np.ndarray @@ -344,7 +376,7 @@ def __post_init__(self): def transitions_collate_fn( batch: Sequence[Mapping[str, np.ndarray]], -) -> Mapping[str, Union[np.ndarray, th.Tensor]]: +) -> Mapping[str, AnyTensor]: """Custom `torch.utils.data.DataLoader` collate_fn for `TransitionsMinimal`. Use this as the `collate_fn` argument to `DataLoader` if using an instance of @@ -392,7 +424,7 @@ class TransitionsMinimal(th_data.Dataset, Sequence[Mapping[str, np.ndarray]]): field has been sliced. """ - obs: Union[np.ndarray, DictObs] + obs: Observation """ Previous observations. Shape: (batch_size, ) + observation_shape. @@ -479,7 +511,7 @@ def __getitem__(self, key): class Transitions(TransitionsMinimal): """A batch of obs-act-obs-done transitions.""" - next_obs: Union[np.ndarray, DictObs] + next_obs: Observation """New observation. Shape: (batch_size, ) + observation_shape. 
The i'th observation `next_obs[i]` in this array is the observation @@ -500,19 +532,16 @@ class Transitions(TransitionsMinimal): def __post_init__(self): """Performs input validation: check shapes & dtypes match docstring.""" super().__post_init__() - # TODO: could add support for checking dictobs shape/dtype of each array - # would be nice for debugging to have a dictobs.shapes -> dict{str: shape} - if isinstance(self.obs, np.ndarray): - if self.obs.shape != self.next_obs.shape: - raise ValueError( - "obs and next_obs must have same shape: " - f"{self.obs.shape} != {self.next_obs.shape}", - ) - if self.obs.dtype != self.next_obs.dtype: - raise ValueError( - "obs and next_obs must have the same dtype: " - f"{self.obs.dtype} != {self.next_obs.dtype}", - ) + if self.obs.shape != self.next_obs.shape: + raise ValueError( + "obs and next_obs must have same shape: " + f"{self.obs.shape} != {self.next_obs.shape}", + ) + if self.obs.dtype != self.next_obs.dtype: + raise ValueError( + "obs and next_obs must have the same dtype: " + f"{self.obs.dtype} != {self.next_obs.dtype}", + ) if self.dones.shape != (len(self.acts),): raise ValueError( "dones must be 1D array, one entry for each timestep: " @@ -538,22 +567,3 @@ def __post_init__(self): """Performs input validation, including for rews.""" super().__post_init__() _rews_validation(self.rews, self.acts) - - -ObsType = TypeVar("ObsType", np.ndarray, DictObs) - - -def concatenate_maybe_dictobs(arrs: List[ObsType]) -> ObsType: - assert len(arrs) > 0 - if isinstance(arrs[0], DictObs): - return DictObs.concatenate(arrs) - else: - return np.concatenate(arrs) - - -def stack_maybe_dictobs(arrs: List[ObsType]) -> ObsType: - assert len(arrs) > 0 - if isinstance(arrs[0], DictObs): - return DictObs.stack(arrs) - else: - return np.stack(arrs) diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 7d9c1c741..5f13154ed 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -37,7 +37,7 @@ def _predict( else: np_obs = obs.detach().cpu().numpy() for np_ob in np_obs: - np_ob_unwrapped = types.DictObs.maybe_unwrap(np_ob) + np_ob_unwrapped = types.maybe_unwrap_dictobs(np_ob) assert self.observation_space.contains(np_ob_unwrapped) np_actions.append(self._choose_action(np_ob_unwrapped)) np_actions = np.stack(np_actions, axis=0) From 071d2a73d3dbaf99f80c5f790d1f593f221e32b1 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:48:42 -0700 Subject: [PATCH 17/85] clean useless lines --- tests/algorithms/test_bc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 2423387e4..7de910555 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -392,8 +392,6 @@ def make_env(): return RolloutInfoWrapper(FloatReward(env)) env = vec_env.DummyVecEnv([make_env, make_env]) - env.observation_space["img"], env.observation_space["vec"] - env.observation_space.shape policy = sb_policies.MultiInputActorCriticPolicy( env.observation_space, From a2ccd7ef61fc6db4e1abc44089194a11999c75b3 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:53:16 -0700 Subject: [PATCH 18/85] clean up print statements --- tests/data/test_types.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index af3cdb7bb..71c50eeea 100644 --- a/tests/data/test_types.py +++ 
b/tests/data/test_types.py @@ -209,14 +209,8 @@ def test_traj_unequal_to_perturbations( perturbed["infos"] = [{"foo": 42}] * len(as_dict["infos"]) elif isinstance(as_dict[k], types.DictObs): perturbed[k] = as_dict[k].map_arrays(lambda x: x + 1) - # print(getattr(t, k), perturbed[k]) else: perturbed[k] = as_dict[k] + 1 - if t == type(t)(**perturbed): - print("\n\n\n") - print(t) - print(perturbed) - print(k) assert t != type(t)(**perturbed) @pytest.mark.parametrize("type_safe", [False, True]) From 93baa2d715d41cbdc7b2c774aa0b278c897ea4c2 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 17:48:41 -0700 Subject: [PATCH 19/85] fix typos Co-authored-by: Adam Gleave --- src/imitation/data/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 652673795..320dca98f 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -57,7 +57,7 @@ def __post_init__(self): raise ValueError("keys must by numpy arrays") def __len__(self): - """Returns the first dimension of constiuent arrays. + """Returns the first dimension of constituent arrays. Only defined if there is at least one array, and all arrays have the same length of first dimension. Otherwise raises ValueError. @@ -130,7 +130,7 @@ def shape(self) -> Dict[str, Tuple[int, ...]]: @property def dtype(self) -> Dict[str, np.dtype]: - """Returns a dictionary with shape-tuples in place of the arrays.""" + """Returns a dictionary with dtype-tuples in place of the arrays.""" return {k: v.dtype for k, v in self.d.items()} def unwrap(self) -> Dict[str, np.ndarray]: From 54f33af88d07cf511ec575454d627a7c822659ad Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Thu, 14 Sep 2023 17:50:07 -0700 Subject: [PATCH 20/85] assert matching keys in from_obs_list --- src/imitation/data/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 320dca98f..9e1c88e14 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -48,6 +48,7 @@ class DictObs: @classmethod def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]): + assert len(set(obs.keys() for obs in obs_list)) == 1 return cls( {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()}, ) From c711abf4fa3dee1b5aab30e7799162345f59577e Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:04:13 -0700 Subject: [PATCH 21/85] move maybe_wrap, clean rollout --- src/imitation/data/rollout.py | 31 ++++++++++++------------------- src/imitation/data/types.py | 33 +++++++++++++++++++++------------ tests/data/test_types.py | 2 +- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 7d147f209..61b48f84d 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -70,7 +70,7 @@ def __init__(self): def add_step( self, - step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any], types.DictObs]], + step_dict: Mapping[str, Union[types.Observation, Mapping[str, Any]]], key: Hashable = None, ) -> None: """Add a single step to the partial trajectory identified by `key`. 
@@ -122,7 +122,7 @@ def finish_trajectory( def add_steps_and_auto_finish( self, acts: np.ndarray, - obs: Union[np.ndarray, Dict[str, np.ndarray], types.DictObs], + obs: Union[types.Observation, Dict[str, np.ndarray]], rews: np.ndarray, dones: np.ndarray, infos: List[dict], @@ -147,9 +147,9 @@ def add_steps_and_auto_finish( each `True` in the `dones` argument. """ trajs: List[types.TrajectoryWithRew] = [] - wrapped_obs = types.DictObs.maybe_wrap(obs) + wrapped_obs = types.maybe_wrap_in_dictobs(obs) - # len of dictobs is the shape[0] of each value array - which here is # of envs + # iterate through environments for env_idx in range(len(wrapped_obs)): assert env_idx in self.partial_trajectories assert list(self.partial_trajectories[env_idx][0].keys()) == ["obs"], ( @@ -157,16 +157,14 @@ def add_steps_and_auto_finish( "self._traj_accum.add_step({'obs': ob}, key=env_idx)" ) + # iterate through steps zip_iter = enumerate(zip(acts, wrapped_obs, rews, dones, infos)) for env_idx, (act, ob, rew, done, info) in zip_iter: if done: # When dones[i] from VecEnv.step() is True, obs[i] is the first # observation following reset() of the ith VecEnv, and # infos[i]["terminal_observation"] is the actual final observation. - real_ob = info["terminal_observation"] - if isinstance(real_ob, dict): - # TODO: does this need to be unsqueezed or something? - real_ob = types.DictObs(real_ob) + real_ob = types.maybe_wrap_in_dictobs(info["terminal_observation"]) else: real_ob = ob @@ -280,7 +278,7 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool: # corresponding actions. PolicyCallable = Callable[ [ - Union[np.ndarray, types.DictObs], + types.Observation, Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray], ], @@ -299,7 +297,7 @@ def policy_to_callable( if policy is None: def get_actions( - observations: Union[np.ndarray, types.DictObs], + observations: types.Observation, states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -313,7 +311,7 @@ def get_actions( # (which would call .forward()). So this elif clause must come first! def get_actions( - observations: Union[np.ndarray, types.DictObs], + observations: types.Observation, states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -418,17 +416,12 @@ def generate_trajectories( # accumulator for incomplete trajectories trajectories_accum = TrajectoryAccumulator() obs = venv.reset() - assert isinstance( obs, - ( - np.ndarray, - dict, - ), + (np.ndarray, dict), ), "Tuple observations are not supported." + wrapped_obs = types.maybe_wrap_in_dictobs(obs) - # need to wrap here to iterate over envs properly - wrapped_obs = types.DictObs.maybe_wrap(obs) for env_idx, ob in enumerate(wrapped_obs): # Seed with first obs only. Inside loop, we'll only add second obs from # each (s,a,r,s') tuple, under the same "obs" key again. That way we still @@ -454,7 +447,7 @@ def generate_trajectories( obs, (np.ndarray, dict), ), "Tuple observations are not supported." - wrapped_obs = types.DictObs.maybe_wrap(obs) + wrapped_obs = types.maybe_wrap_in_dictobs(obs) # If an environment is inactive, i.e. 
the episode completed for that # environment after `sample_until(trajectories)` was true, then we do diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 9e1c88e14..74a37b805 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -137,18 +137,6 @@ def dtype(self) -> Dict[str, np.dtype]: def unwrap(self) -> Dict[str, np.ndarray]: return self.d - @classmethod - def maybe_wrap( - cls, - obs: Union[Dict[str, np.ndarray], np.ndarray, "DictObs"], - ) -> Union["DictObs", np.ndarray]: - """Converts an observation into a DictObs, if necessary.""" - if isinstance(obs, dict): - return cls(obs) - else: - assert isinstance(obs, (np.ndarray, cls)) - return obs - def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": return self.__class__({k: fn(v) for k, v in self.d.items()}) @@ -231,6 +219,27 @@ def maybe_unwrap_dictobs(maybe_dictobs): return maybe_dictobs +@overload +def maybe_wrap_in_dictobs(obs: Union[Dict[str, np.ndarray], DictObs]) -> DictObs: + ... + + +@overload +def maybe_wrap_in_dictobs(obs: np.ndarray) -> np.ndarray: + ... + + +def maybe_wrap_in_dictobs( + obs: Union[Dict[str, np.ndarray], np.ndarray, DictObs], +) -> Observation: + """Converts an observation into a DictObs, if necessary.""" + if isinstance(obs, dict): + return DictObs(obs) + else: + assert isinstance(obs, (np.ndarray, DictObs)) + return obs + + def map_maybe_dict(fn, maybe_dict): """Applies fn to all values a dictionary, or to the value itself if not a dict.""" if isinstance(maybe_dict, dict): diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 71c50eeea..82c0665ba 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -493,7 +493,7 @@ def test_dict_obs(): assert abc.dtype == {"a": A.dtype, "b": B.dtype, "c": C.dtype} # wrap - assert types.DictObs.maybe_wrap({"a": A, "b": B, "c": C}) == abc + assert types.maybe_wrap_in_dictobs({"a": A, "b": B, "c": C}) == abc assert abc.unwrap() == {"a": A, "b": B, "c": C} # map, stack, concat From 58a0d709cf14959dcf38235b156df4780199c0d2 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:16:47 -0700 Subject: [PATCH 22/85] change policy callable to take dict[str, np.ndarray] not dictobs --- src/imitation/data/rollout.py | 12 +++++++----- src/imitation/policies/exploration_wrapper.py | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 61b48f84d..e8722bd7c 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -278,7 +278,7 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool: # corresponding actions. PolicyCallable = Callable[ [ - types.Observation, + Union[np.ndarray, Dict[str, np.ndarray]], Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray], ], @@ -297,7 +297,7 @@ def policy_to_callable( if policy is None: def get_actions( - observations: types.Observation, + observations: Union[np.ndarray, Dict[str, np.ndarray]], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -311,7 +311,7 @@ def get_actions( # (which would call .forward()). So this elif clause must come first! 
def get_actions( - observations: types.Observation, + observations: Union[np.ndarray, Dict[str, np.ndarray]], states: Optional[Tuple[np.ndarray, ...]], episode_starts: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -319,7 +319,7 @@ def get_actions( # pytype doesn't seem to understand that policy is a BaseAlgorithm # or BasePolicy here, rather than a Callable (acts, states) = policy.predict( # pytype: disable=attribute-error - types.maybe_unwrap_dictobs(observations), + observations, state=states, episode_start=episode_starts, deterministic=deterministic_policy, @@ -422,6 +422,7 @@ def generate_trajectories( ), "Tuple observations are not supported." wrapped_obs = types.maybe_wrap_in_dictobs(obs) + # we use dictobs to iterate over the envs in a vecenv for env_idx, ob in enumerate(wrapped_obs): # Seed with first obs only. Inside loop, we'll only add second obs from # each (s,a,r,s') tuple, under the same "obs" key again. That way we still @@ -441,7 +442,8 @@ def generate_trajectories( state = None dones = np.zeros(venv.num_envs, dtype=bool) while np.any(active): - acts, state = get_actions(wrapped_obs, state, dones) + # policy gets unwrapped observations (eg as dict, not dictobs) + acts, state = get_actions(obs, state, dones) obs, rews, dones, infos = venv.step(acts) assert isinstance( obs, diff --git a/src/imitation/policies/exploration_wrapper.py b/src/imitation/policies/exploration_wrapper.py index 447708745..cde576466 100644 --- a/src/imitation/policies/exploration_wrapper.py +++ b/src/imitation/policies/exploration_wrapper.py @@ -1,11 +1,11 @@ """Wrapper to turn a policy into a more exploratory version.""" -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np from stable_baselines3.common import vec_env -from imitation.data import rollout, types +from imitation.data import rollout from imitation.util import util @@ -57,7 +57,7 @@ def __init__( def _random_policy( self, - obs: Union[np.ndarray, types.DictObs], + obs: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: @@ -74,7 +74,7 @@ def _switch(self) -> None: def __call__( self, - observation: Union[np.ndarray, types.DictObs], + observation: Union[np.ndarray, Dict[str, np.ndarray]], input_state: Optional[Tuple[np.ndarray, ...]], episode_start: Optional[np.ndarray], ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: From 0f080d470cc509e57ec3affe05c95a5a656500b3 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 09:36:51 -0700 Subject: [PATCH 23/85] rollout info wrapper supports dictobs --- src/imitation/data/wrappers.py | 6 +++--- tests/algorithms/test_bc.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 11f22775d..0a35cd02a 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -187,19 +187,19 @@ def __init__(self, env: gym.Env): def reset(self, **kwargs): new_obs = super().reset(**kwargs) - self._obs = [new_obs] + self._obs = [types.maybe_wrap_in_dictobs(new_obs)] self._rews = [] return new_obs def step(self, action): obs, rew, done, info = self.env.step(action) - self._obs.append(obs) + self._obs.append(types.maybe_wrap_in_dictobs(obs)) self._rews.append(rew) if done: assert "rollout" not in info info["rollout"] = { - "obs": 
np.stack(self._obs), + "obs": types.stack_maybe_dictobs(self._obs), "rews": np.stack(self._rews), } return obs, rew, done, info diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 7de910555..fc43a6041 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -407,7 +407,7 @@ def sample_expert_transitions(): venv=env, sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), rng=rng, - unwrap=False, # TODO have rollout unwrap wrapper support dict + unwrap=True, ) return rollout.flatten_trajectories(rollouts) From c4d3e11ff8be3d57c55151866a235221e91e04cb Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:02:12 -0700 Subject: [PATCH 24/85] fix from_obs_list key consistency check --- src/imitation/data/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 74a37b805..4319707a6 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -48,7 +48,8 @@ class DictObs: @classmethod def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]): - assert len(set(obs.keys() for obs in obs_list)) == 1 + # assert all have same keys + assert len(set(frozenset(obs.keys()) for obs in obs_list)) == 1 return cls( {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()}, ) From b93294a3cb76828c75980fca3c17b76fb434ee02 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:02:41 -0700 Subject: [PATCH 25/85] xfail save/load tests with dictobs --- tests/data/test_types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 82c0665ba..25dc3f8f4 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -227,6 +227,9 @@ def test_save_trajectories( use_rewards, type_safe, ): + if isinstance(trajectory.obs, types.DictObs): + pytest.xfail("Saving/loading dictobs trajectories not yet supported") + chdir_context: contextlib.AbstractContextManager """Check that trajectories are properly saved.""" if use_chdir: From 3f17ff2aa3ec2d57a80e1b6c3100eb24c2b04a67 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:03:02 -0700 Subject: [PATCH 26/85] doc for dictobs wrapper --- tests/data/test_rollout.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 64655b0e9..0408b516e 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -374,9 +374,17 @@ def test_rollout_normal_error_for_other_shape_mismatch(rng): class DictObsWrapper(gym.ObservationWrapper): - """Simple wrapper that turns the observation into a dictionary.""" + """Simple wrapper that turns the observation into a dictionary. - def __init__(self, env: gym.Env) -> None: + The observation is duplicated and returned under two keys, one of them divided + in half.""" + + def __init__(self, env: gym.Env): + """Builds DictObsWrapper. + + Args: + venv: The wrapped VecEnv. 
+ """ super().__init__(env) self.observation_space = gym.spaces.Dict( {"a": env.observation_space, "b": env.observation_space}, From 0212e0edd7063d612b01b2ed4d17b02ff8f61a44 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:21:41 -0700 Subject: [PATCH 27/85] don't error on int observations --- src/imitation/data/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 4319707a6..4d45b6d6b 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -215,7 +215,7 @@ def maybe_unwrap_dictobs(maybe_dictobs): if isinstance(maybe_dictobs, DictObs): return maybe_dictobs.unwrap() else: - if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor)): + if not isinstance(maybe_dictobs, (np.ndarray, th.Tensor, int)): warnings.warn(f"trying to unwrap object of type {type(maybe_dictobs)}") return maybe_dictobs @@ -237,7 +237,8 @@ def maybe_wrap_in_dictobs( if isinstance(obs, dict): return DictObs(obs) else: - assert isinstance(obs, (np.ndarray, DictObs)) + if not isinstance(obs, (np.ndarray, DictObs, float, int)): + warnings.warn(f"tried to wrap {type(obs)} as an observation") return obs From 070ebf9f30b2ad83436de648833620e4796fa67c Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:27:12 -0700 Subject: [PATCH 28/85] lint fixes --- tests/algorithms/test_preference_comparisons.py | 3 ++- tests/data/test_rollout.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index 6baeb90f8..3921c7a52 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -95,7 +95,8 @@ def _check_trajs_equal( assert len(trajs1) == len(trajs2) for traj1, traj2 in zip(trajs1, trajs2): assert np.array_equal( - types.assert_not_dictobs(traj1.obs), types.assert_not_dictobs(traj2.obs) + types.assert_not_dictobs(traj1.obs), + types.assert_not_dictobs(traj2.obs), ) assert np.array_equal(traj1.acts, traj2.acts) assert np.array_equal(traj1.rews, traj2.rews) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 0408b516e..8d96cc9f1 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -376,14 +376,14 @@ def test_rollout_normal_error_for_other_shape_mismatch(rng): class DictObsWrapper(gym.ObservationWrapper): """Simple wrapper that turns the observation into a dictionary. - The observation is duplicated and returned under two keys, one of them divided - in half.""" + The observation is duplicated, with "b" rescaled. + """ def __init__(self, env: gym.Env): """Builds DictObsWrapper. Args: - venv: The wrapped VecEnv. + env: The wrapped Env. 
""" super().__init__(env) self.observation_space = gym.spaces.Dict( From 657e17ee198232d88a77d0e499b41455dcbdf962 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:52:07 -0700 Subject: [PATCH 29/85] cleanup bc test for dict obs --- tests/algorithms/test_bc.py | 97 ++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 54 deletions(-) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index fc43a6041..16359f6f8 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -291,6 +291,49 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) +class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + +def test_dict_space(): + def make_env(): + env = sb_envs.SimpleMultiObsEnv(channel_last=False) + env = FloatReward(env) + return RolloutInfoWrapper(env) + + env = vec_env.DummyVecEnv([make_env, make_env]) + + # multi-input policy to accept dict observations + policy = sb_policies.MultiInputActorCriticPolicy( + env.observation_space, + env.action_space, + lambda _: 0.001, + ) + rng = np.random.default_rng() + + # sample random transitions + rollouts = rollout.rollout( + policy=None, + venv=env, + sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + unwrap=True, + ) + transitions = rollout.flatten_trajectories(rollouts) + bc_trainer = bc.BC( + observation_space=env.observation_space, + policy=policy, + action_space=env.action_space, + rng=rng, + demonstrations=transitions, + ) + # confirm that training works + bc_trainer.train(n_epochs=1) + + ############################################# # ENSURE EXCEPTIONS ARE THROWN WHEN EXPECTED ############################################# @@ -375,57 +418,3 @@ def inc_batch_cnt(): # THEN assert batch_cnt == no_yield_after_iter - - -class FloatReward(gym.RewardWrapper): - """Typecasts reward to a float.""" - - def reward(self, reward): - return float(reward) - - -# TODO: make test nicer -def test_dict_space(): - # TODO: is sb_envs okay? 
- def make_env(): - env = sb_envs.SimpleMultiObsEnv(channel_last=False) - return RolloutInfoWrapper(FloatReward(env)) - - env = vec_env.DummyVecEnv([make_env, make_env]) - - policy = sb_policies.MultiInputActorCriticPolicy( - env.observation_space, - env.action_space, - lambda _: 0.001, - ) - rng = np.random.default_rng() - - def sample_expert_transitions(): - print("Sampling expert transitions.") - rollouts = rollout.rollout( - policy=None, - venv=env, - sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), - rng=rng, - unwrap=True, - ) - return rollout.flatten_trajectories(rollouts) - - transitions = sample_expert_transitions() - - bc_trainer = bc.BC( - observation_space=env.observation_space, - policy=policy, - action_space=env.action_space, - rng=rng, - demonstrations=transitions, - ) - - bc_trainer.train(n_epochs=1) - - reward, _ = evaluation.evaluate_policy( - bc_trainer.policy, # type: ignore[arg-type] - env, - n_eval_episodes=3, - render=False, # comment out to speed up - ) From 1f8c12ab39ec953686cc1ac66fb68636107debf9 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:24:05 -0700 Subject: [PATCH 30/85] cleanup bc.py unwrapping --- src/imitation/algorithms/bc.py | 11 ++++------- src/imitation/data/types.py | 9 +++++++-- tests/algorithms/test_bc.py | 1 + 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index ac7c5849b..75f7c4463 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -120,13 +120,10 @@ def __call__( A BCTrainingMetrics object with the loss and all the components it consists of. """ - tensor_obs: Union[th.Tensor, Dict[str, th.Tensor]] - if isinstance(obs, types.DictObs): - tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap().items()} - elif isinstance(obs, dict): - tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.items()} - else: - tensor_obs = util.safe_to_tensor(obs) + tensor_obs = types.map_maybe_dict( + util.safe_to_tensor, + types.maybe_unwrap_dictobs(obs), + ) acts = util.safe_to_tensor(acts) # policy.evaluate_actions's type signature seems wrong to me. diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 4d45b6d6b..323e080d2 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -200,13 +200,18 @@ def stack_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: return np.stack(arrs) +# the following overloads have a type error as a DictObs matches both definitions, but +# the return types are incompatible. Ideally T would exclude DictObs but that's not +# possible. @overload -def maybe_unwrap_dictobs(maybe_dictobs: DictObs) -> Dict[str, np.ndarray]: +def maybe_unwrap_dictobs( # type: ignore[misc] + maybe_dictobs: DictObs, +) -> Dict[str, np.ndarray]: ... @overload -def maybe_unwrap_dictobs(maybe_dictobs: TensorVar) -> TensorVar: +def maybe_unwrap_dictobs(maybe_dictobs: T) -> T: ... 
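The unwrap-then-map pattern above keeps a single code path for both observation
types. A minimal sketch of its behaviour, with invented keys and array shapes
(illustrative only, not part of the patch):

import numpy as np
import torch as th

from imitation.data import types
from imitation.util import util

dict_obs = types.DictObs({"pos": np.zeros((4, 2)), "vel": np.ones((4, 2))})
array_obs = np.zeros((4, 3))

# a DictObs unwraps to Dict[str, np.ndarray]; each value then becomes a tensor
tensor_dict = types.map_maybe_dict(
    util.safe_to_tensor,
    types.maybe_unwrap_dictobs(dict_obs),
)
assert isinstance(tensor_dict["pos"], th.Tensor)

# a plain array passes through unwrapping unchanged and becomes one tensor
tensor_arr = types.map_maybe_dict(
    util.safe_to_tensor,
    types.maybe_unwrap_dictobs(array_obs),
)
assert isinstance(tensor_arr, th.Tensor)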
diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 16359f6f8..2a6038188 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -291,6 +291,7 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) +# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged class FloatReward(gym.RewardWrapper): """Typecasts reward to a float.""" From bd70ecd53203a30ee93d67f4d6dddb5b8b488899 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:38:10 -0700 Subject: [PATCH 31/85] cleanup rollout.py --- src/imitation/data/rollout.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index e8722bd7c..bfe10e305 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -109,13 +109,11 @@ def finish_trajectory( for k, array in part_dict.items(): out_dict_unstacked[k].append(array) - traj = types.TrajectoryWithRew( - obs=types.stack_maybe_dictobs(out_dict_unstacked["obs"]), - acts=np.stack(out_dict_unstacked["acts"], axis=0), - infos=np.stack(out_dict_unstacked["infos"], axis=0), # array of dict objs - rews=np.stack(out_dict_unstacked["rews"], axis=0), - terminal=terminal, - ) + out_dict_stacked = { + k: types.stack_maybe_dictobs(arr_list) + for k, arr_list in out_dict_unstacked.items() + } + traj = types.TrajectoryWithRew(**out_dict_stacked, terminal=terminal) assert traj.rews.shape[0] == traj.acts.shape[0] == len(traj.obs) - 1 return traj @@ -278,11 +276,11 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool: # corresponding actions. PolicyCallable = Callable[ [ - Union[np.ndarray, Dict[str, np.ndarray]], - Optional[Tuple[np.ndarray, ...]], - Optional[np.ndarray], + Union[np.ndarray, Dict[str, np.ndarray]], # observations + Optional[Tuple[np.ndarray, ...]], # states + Optional[np.ndarray], # episode_starts ], - Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]], + Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]], # actions, states ] AnyPolicy = Union[BaseAlgorithm, BasePolicy, PolicyCallable, None] @@ -576,15 +574,14 @@ def all_of_type(key, desired_type): assert all_of_type("obs", types.DictObs) or all_of_type("obs", np.ndarray) assert all_of_type("acts", np.ndarray) - # sad to use Any here, but mypy struggles otherwise. - # we enforce type constraints in asserts above and below. + # mypy struggles without Any annotation here. + # The necessary constraints are enforced above. keys = ["obs", "next_obs", "acts", "dones", "infos"] parts: Mapping[str, List[Any]] = {key: [] for key in keys} for traj in trajectories: parts["acts"].append(traj.acts) obs = traj.obs - parts["obs"].append(obs[:-1]) parts["next_obs"].append(obs[1:]) From bec464cedc97efca7ea658e525610d27d98934b8 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 12:09:12 -0700 Subject: [PATCH 32/85] cleanup dictobs interface --- src/imitation/data/types.py | 74 ++++++++++++++++++++++++------------- tests/data/test_rollout.py | 2 +- tests/data/test_types.py | 14 +++---- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 323e080d2..da9145c2e 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -44,18 +44,15 @@ class DictObs: lists of dictobs. 
""" - d: Dict[str, np.ndarray] + _d: Dict[str, np.ndarray] @classmethod - def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]): - # assert all have same keys - assert len(set(frozenset(obs.keys()) for obs in obs_list)) == 1 - return cls( - {k: np.stack([obs[k] for obs in obs_list]) for k in obs_list[0].keys()}, - ) + def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]) -> "DictObs": + """Stacks the observation list into a single DictObs.""" + return cls.stack(map(cls, obs_list)) def __post_init__(self): - if not all((isinstance(v, np.ndarray) for v in self.d.values())): + if not all((isinstance(v, np.ndarray) for v in self._d.values())): raise ValueError("keys must by numpy arrays") def __len__(self): @@ -75,7 +72,7 @@ def __len__(self): Returns: The length (first dimension) of the constiuent arrays """ - lens = set(len(v) for v in self.d.values()) + lens = set(len(v) for v in self._d.values()) if len(lens) == 1: return lens.pop() elif len(lens) == 0: @@ -87,7 +84,8 @@ def __len__(self): @property def dict_len(self): - return len(self.d) + """Returns the number of arrays in the DictObs.""" + return len(self._d) def __getitem__( self, @@ -95,6 +93,8 @@ def __getitem__( ) -> "DictObs": """Indexes or slices each array. + See `.get` for accessing a value from the underlying dictionary. + Note that it will still return singleton values as np.arrays, not scalars, to be consistent with DictObs type signature. @@ -105,7 +105,7 @@ def __getitem__( A new DictObj object with each array indexed. """ # asarray handles case where we slice to a single array element. - return self.__class__({k: np.asarray(v[key]) for k, v in self.d.items()}) + return self.__class__({k: np.asarray(v[key]) for k, v in self._d.items()}) def __iter__(self) -> Iterator["DictObs"]: """Iterates over the first dimension of each array. 
@@ -121,55 +121,80 @@ def __iter__(self) -> Iterator["DictObs"]: def __eq__(self, other): if not isinstance(other, self.__class__): return False - if not self.d.keys() == other.d.keys(): + if not self.keys() == other.keys(): return False - return all(np.array_equal(self.d[k], other.d[k]) for k in self.d.keys()) + return all(np.array_equal(self.get(k), other.get(k)) for k in self.keys()) @property def shape(self) -> Dict[str, Tuple[int, ...]]: """Returns a dictionary with shape-tuples in place of the arrays.""" - return {k: v.shape for k, v in self.d.items()} + return {k: v.shape for k, v in self.items()} @property def dtype(self) -> Dict[str, np.dtype]: """Returns a dictionary with dtype-tuples in place of the arrays.""" - return {k: v.dtype for k, v in self.d.items()} + return {k: v.dtype for k, v in self.items()} + + def keys(self): + return self._d.keys() + + def values(self): + return self._d.values() + + def items(self): + return self._d.items() + + def __contains__(self, key): + return key in self._d + + def get(self, key: str) -> np.ndarray: + """Returns the array for the given key, or raises KeyError.""" + return self._d[key] def unwrap(self) -> Dict[str, np.ndarray]: - return self.d + """Returns a copy of the underlying dictionary (arrays are not copied).""" + return {k: v for k, v in self._d.items()} def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": - return self.__class__({k: fn(v) for k, v in self.d.items()}) + """Returns a new DictObs with `fn` applied to every array.""" + return self.__class__({k: fn(v) for k, v in self.items()}) @staticmethod def _unravel(dictobs_list: Iterable["DictObs"]) -> Dict[str, List[np.ndarray]]: """Converts a list of DictObs into a dictionary of lists of arrays.""" + # assert all have same keys + key_set = set(frozenset(obs.keys()) for obs in dictobs_list) + if len(key_set) == 0: + raise ValueError("Empty list of DictObs") + if not len(key_set) == 1: + raise ValueError(f"Inconsistent keys: {key_set}") + unraveled: Dict[str, List[np.ndarray]] = collections.defaultdict(list) for do in dictobs_list: - for k, array in do.d.items(): + for k, array in do._d.items(): unraveled[k].append(array) return unraveled @classmethod - def stack(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + def stack(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": + """Returns a single dictobs stacking the arrays by key.""" return cls( { - k: np.stack(arr_list) + k: np.stack(arr_list, axis=axis) for k, arr_list in cls._unravel(dictobs_list).items() }, ) @classmethod - def concatenate(cls, dictobs_list: Iterable["DictObs"]) -> "DictObs": + def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": + """Returns a single dictobs concatenating the arrays by key.""" return cls( { - k: np.concatenate(arr_list) + k: np.concatenate(arr_list, axis=axis) for k, arr_list in cls._unravel(dictobs_list).items() }, ) - # TODO: add keys, values, items? 
- # DicObs utilities @@ -418,7 +443,6 @@ def transitions_collate_fn( result["infos"] = [sample["infos"] for sample in batch] result["obs"] = stack_maybe_dictobs([sample["obs"] for sample in batch]) result["next_obs"] = stack_maybe_dictobs([sample["next_obs"] for sample in batch]) - # TODO: clean names, docs return result diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 8d96cc9f1..0855efa01 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -414,4 +414,4 @@ def test_dictionary_observations(rng): ) for traj in trajs: assert isinstance(traj.obs, types.DictObs) - np.testing.assert_allclose(traj.obs.d["a"] / 2, traj.obs.d["b"]) + np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b")) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 25dc3f8f4..9c528fc4b 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -49,7 +49,7 @@ def trajectory( raw_obs = [obs_space.sample() for _ in range(length + 1)] if isinstance(obs_space, gym.spaces.Dict): - obs = types.DictObs.from_obs_list(raw_obs) + obs: types.Observation = types.DictObs.from_obs_list(raw_obs) else: obs = np.array(raw_obs) acts = np.array([act_space.sample() for _ in range(length)]) @@ -474,16 +474,16 @@ def test_dict_obs(): len(types.DictObs({})) # slicing - np.testing.assert_equal(abc[0].d["a"], A[0]) - np.testing.assert_equal(abc[0].d["c"], np.array(C[0])) - np.testing.assert_equal(abc[0:2].d["a"], np.array(A[0:2])) - np.testing.assert_equal(ab[:, 0].d["a"], np.array(A[:, 0])) + np.testing.assert_equal(abc[0].get("a"), A[0]) + np.testing.assert_equal(abc[0].get("c"), np.array(C[0])) + np.testing.assert_equal(abc[0:2].get("a"), np.array(A[0:2])) + np.testing.assert_equal(ab[:, 0].get("a"), np.array(A[:, 0])) with pytest.raises(IndexError): abc[:, 0] # iter for i, a_row in enumerate(A): - np.testing.assert_equal(a_row, ab[i].d["a"]) + np.testing.assert_equal(a_row, ab[i].get("a")) assert ab[0] == next(iter(ab)) # eq @@ -505,7 +505,7 @@ def test_dict_obs(): ) assert types.DictObs.stack(list(iter(ab))) == ab np.testing.assert_equal( - types.DictObs.concatenate([abc, abc]).d["a"], + types.DictObs.concatenate([abc, abc]).get("a"), np.concatenate([A, A]), ) From bef19e699caf4e2e1dcd252c5d5e531c91791ff3 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:32:46 -0700 Subject: [PATCH 33/85] small cleanups --- src/imitation/data/buffer.py | 3 +-- src/imitation/data/rollout.py | 7 +++++-- src/imitation/data/types.py | 17 +++++++++++------ src/imitation/util/util.py | 5 +---- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/imitation/data/buffer.py b/src/imitation/data/buffer.py index 2a6939408..aa4ecb36e 100644 --- a/src/imitation/data/buffer.py +++ b/src/imitation/data/buffer.py @@ -1,6 +1,5 @@ """Buffers to store NumPy arrays and transitions in.""" -import dataclasses from typing import Any, Mapping, Optional, Tuple import numpy as np @@ -384,7 +383,7 @@ def store(self, transitions: types.Transitions, truncate_ok: bool = True) -> Non Raises: ValueError: The arguments didn't have the same length. 
""" # noqa: DAR402 - trans_dict = dataclasses.asdict(transitions) + trans_dict = types.dataclass_quick_asdict(transitions) # Remove unnecessary fields trans_dict = {k: trans_dict[k] for k in self._buffer.sample_shapes.keys()} self._buffer.store(trans_dict, truncate_ok=truncate_ok) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index bfe10e305..5593964ce 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -609,7 +609,10 @@ def flatten_trajectories_with_rew( ) -> types.TransitionsWithRew: transitions = flatten_trajectories(trajectories) rews = np.concatenate([traj.rews for traj in trajectories]) - return types.TransitionsWithRew(**dataclasses.asdict(transitions), rews=rews) + return types.TransitionsWithRew( + **types.dataclass_quick_asdict(transitions), + rews=rews, + ) def generate_transitions( @@ -650,7 +653,7 @@ def generate_transitions( ) transitions = flatten_trajectories_with_rew(traj) if truncate and n_timesteps is not None: - as_dict = dataclasses.asdict(transitions) + as_dict = types.dataclass_quick_asdict(transitions) truncated = {k: arr[:n_timesteps] for k, arr in as_dict.items()} transitions = types.TransitionsWithRew(**truncated) return transitions diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index da9145c2e..bcd9e7282 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -2,6 +2,8 @@ import collections import dataclasses +import itertools +import numbers import os import warnings from typing import ( @@ -52,7 +54,9 @@ def from_obs_list(cls, obs_list: List[Dict[str, np.ndarray]]) -> "DictObs": return cls.stack(map(cls, obs_list)) def __post_init__(self): - if not all((isinstance(v, np.ndarray) for v in self._d.values())): + if not all( + isinstance(v, (np.ndarray, numbers.Number)) for v in self._d.values() + ): raise ValueError("keys must by numpy arrays") def __len__(self): @@ -162,15 +166,16 @@ def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": @staticmethod def _unravel(dictobs_list: Iterable["DictObs"]) -> Dict[str, List[np.ndarray]]: """Converts a list of DictObs into a dictionary of lists of arrays.""" + it1, it2 = itertools.tee(dictobs_list) # assert all have same keys - key_set = set(frozenset(obs.keys()) for obs in dictobs_list) + key_set = set(frozenset(obs.keys()) for obs in it1) if len(key_set) == 0: raise ValueError("Empty list of DictObs") if not len(key_set) == 1: raise ValueError(f"Inconsistent keys: {key_set}") unraveled: Dict[str, List[np.ndarray]] = collections.defaultdict(list) - for do in dictobs_list: + for do in it2: for k, array in do._d.items(): unraveled[k].append(array) return unraveled @@ -205,7 +210,7 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": def assert_not_dictobs(x: Observation) -> np.ndarray: if isinstance(x, DictObs): - raise ValueError("Dictionary observations are not supported here.") + assert False, "Dictionary observations are not supported here." return x @@ -433,12 +438,12 @@ def transitions_collate_fn( list of dicts. (The default behavior would recursively collate every info dict into a single dict, which is incorrect.) 
""" - batch_no_infos = [ + batch_acts_and_dones = [ {k: np.array(v) for k, v in sample.items() if k in ["acts", "dones"]} for sample in batch ] - result = th_data.dataloader.default_collate(batch_no_infos) + result = th_data.dataloader.default_collate(batch_acts_and_dones) assert isinstance(result, dict) result["infos"] = [sample["infos"] for sample in batch] result["obs"] = stack_maybe_dictobs([sample["obs"] for sample in batch]) diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index cba83f504..2abae1605 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -230,10 +230,7 @@ def endless_iter(iterable: Iterable[T]) -> Iterator[T]: return itertools.chain.from_iterable(itertools.repeat(iterable)) -def safe_to_tensor( - array: Union[np.ndarray, th.Tensor], - **kwargs, -) -> th.Tensor: +def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor: """Converts a NumPy array to a PyTorch tensor. The data is copied in the case where the array is non-writable. Unfortunately if From 9aaf73f98ecbb55e3dbd4c8475e39821b9344180 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:44:28 -0700 Subject: [PATCH 34/85] coverage fixes, test fix --- tests/data/test_types.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index 9c528fc4b..f4aa634ad 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -473,6 +473,8 @@ def test_dict_obs(): with pytest.raises(ValueError): len(types.DictObs({})) + assert abc.dict_len == 3 + # slicing np.testing.assert_equal(abc[0].get("a"), A[0]) np.testing.assert_equal(abc[0].get("c"), np.array(C[0])) @@ -490,6 +492,8 @@ def test_dict_obs(): assert abc == types.DictObs({"a": A, "b": B, "c": C}) assert abc != types.DictObs({"a": A, "c": B, "b": C}) # diff keys assert abc != types.DictObs({"a": A, "b": B + 1, "c": C}) # diff values + assert abc != {"a": A, "b": B + 1, "c": C} # diff type + assert abc != ab # diff keys # shape / dtype assert abc.shape == {"a": A.shape, "b": B.shape, "c": C.shape} @@ -509,5 +513,8 @@ def test_dict_obs(): np.concatenate([A, A]), ) - with pytest.raises(ValueError): + with pytest.raises(AssertionError): types.assert_not_dictobs(abc) + + with pytest.raises(ValueError): + types.DictObs({"a": "not an array"}) # type: ignore[wrong-arg-types] From 5d6aa7785e847a89aba71d945ac3fed63e633f6e Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:48:47 -0700 Subject: [PATCH 35/85] adjust error types --- src/imitation/data/types.py | 8 ++++---- tests/data/test_types.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index bcd9e7282..2b0d050f0 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -57,7 +57,7 @@ def __post_init__(self): if not all( isinstance(v, (np.ndarray, numbers.Number)) for v in self._d.values() ): - raise ValueError("keys must by numpy arrays") + raise TypeError("keys must by numpy arrays") def __len__(self): """Returns the first dimension of constituent arrays. @@ -71,7 +71,7 @@ def __len__(self): Use `dict_len` to get the number of entries in the dictionary. Raises: - ValueError: if the arrays have different lengths or there are no arrays. + RuntimeError: if the arrays have different lengths or there are no arrays. 
Returns: The length (first dimension) of the constiuent arrays @@ -80,9 +80,9 @@ def __len__(self): if len(lens) == 1: return lens.pop() elif len(lens) == 0: - raise ValueError("Length not defined as DictObs is empty") + raise RuntimeError("Length not defined as DictObs is empty") else: - raise ValueError( + raise RuntimeError( f"Length not defined; arrays have conflicting first dimensions: {lens}", ) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index f4aa634ad..d32e06562 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -468,9 +468,9 @@ def test_dict_obs(): # len assert len(ab) == 3 - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): len(abc) - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): len(types.DictObs({})) assert abc.dict_len == 3 @@ -516,5 +516,5 @@ def test_dict_obs(): with pytest.raises(AssertionError): types.assert_not_dictobs(abc) - with pytest.raises(ValueError): + with pytest.raises(TypeError): types.DictObs({"a": "not an array"}) # type: ignore[wrong-arg-types] From 86fbcf153cc76df5145bf23f5f919bfcfe455ef9 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 14:04:59 -0700 Subject: [PATCH 36/85] docstrings for type helpers --- src/imitation/data/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 2b0d050f0..a010b345e 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -209,12 +209,14 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": def assert_not_dictobs(x: Observation) -> np.ndarray: + """Typeguard to assert `x` is an array, not a DictObs.""" if isinstance(x, DictObs): assert False, "Dictionary observations are not supported here." 
return x def concatenate_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + """Concatenates a list of observations appropriately (depending on type).""" assert len(arrs) > 0 if isinstance(arrs[0], DictObs): return DictObs.concatenate(arrs) @@ -223,6 +225,7 @@ def concatenate_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: def stack_maybe_dictobs(arrs: List[ObsVar]) -> ObsVar: + """Stacks a list of observations appropriately (depending on type).""" assert len(arrs) > 0 if isinstance(arrs[0], DictObs): return DictObs.stack(arrs) From 8d1e0d66dba180c9fddc7a860ca1184f7f521332 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:56:59 -0700 Subject: [PATCH 37/85] add dict obs space support for density --- src/imitation/algorithms/density.py | 63 ++++++++++++---------- src/imitation/algorithms/mce_irl.py | 6 ++- src/imitation/data/types.py | 17 +++++- tests/algorithms/test_base.py | 5 +- tests/algorithms/test_density_baselines.py | 6 +-- 5 files changed, 59 insertions(+), 38 deletions(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index 378c6bf80..931d24f88 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -134,9 +134,9 @@ def __init__( def _get_demo_from_batch( self, - obs_b: np.ndarray, + obs_b: types.Observation, act_b: np.ndarray, - next_obs_b: Optional[np.ndarray], + next_obs_b: Optional[types.Observation], ) -> Dict[Optional[int], List[np.ndarray]]: if next_obs_b is None and self.density_type == DensityType.STATE_STATE_DENSITY: raise ValueError( @@ -145,11 +145,16 @@ def _get_demo_from_batch( ) assert act_b.shape[1:] == self.venv.action_space.shape - assert obs_b.shape[1:] == self.venv.observation_space.shape + + if isinstance(obs_b, types.DictObs): + exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} + obs_shape = {k: v.shape[1:] for k, v in obs_b.items()} + assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}" + else: + assert obs_b.shape[1:] == self.venv.observation_space.shape assert len(act_b) == len(obs_b) if next_obs_b is not None: - assert next_obs_b.shape[1:] == self.venv.observation_space.shape - assert len(next_obs_b) == len(obs_b) + assert next_obs_b.shape == obs_b.shape if next_obs_b is not None: next_obs_b_iterator: Iterable = next_obs_b @@ -168,11 +173,9 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: if isinstance(demonstrations, types.TransitionsMinimal): next_obs_b = getattr(demonstrations, "next_obs", None) - if next_obs_b is not None: - next_obs_b = types.assert_not_dictobs(next_obs_b) transitions.update( self._get_demo_from_batch( - types.assert_not_dictobs(demonstrations.obs), + demonstrations.obs, demonstrations.acts, next_obs_b, ), @@ -193,9 +196,8 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: demonstrations = cast(Iterable[types.Trajectory], demonstrations) for traj in demonstrations: - traj_obs = types.assert_not_dictobs(traj.obs) for i, (obs, act, next_obs) in enumerate( - zip(traj_obs[:-1], traj.acts, traj_obs[1:]), + zip(traj.obs[:-1], traj.acts, traj.obs[1:]), ): flat_trans = self._preprocess_transition(obs, act, next_obs) transitions.setdefault(i, []).append(flat_trans) @@ -203,14 +205,17 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: # analogous to cast above. 
             demonstrations = cast(Iterable[types.TransitionMapping], demonstrations)
 
+            def to_np_maybe_dictobs(x):
+                if isinstance(x, types.DictObs):
+                    return x
+                else:
+                    return util.safe_to_numpy(x, warn=True)
+
             for batch in demonstrations:
-                transitions.update(
-                    self._get_demo_from_batch(
-                        util.safe_to_numpy(batch["obs"], warn=True),
-                        util.safe_to_numpy(batch["acts"], warn=True),
-                        util.safe_to_numpy(batch.get("next_obs"), warn=True),
-                    ),
-                )
+                obs = to_np_maybe_dictobs(batch["obs"])
+                acts = util.safe_to_numpy(batch["acts"], warn=True)
+                next_obs = to_np_maybe_dictobs(batch.get("next_obs"))
+                transitions.update(self._get_demo_from_batch(obs, acts, next_obs))
         else:
             raise TypeError(
                 f"Unsupported demonstration type {type(demonstrations)}",
@@ -256,28 +261,28 @@ def _fit_density(self, transitions: np.ndarray) -> neighbors.KernelDensity:
 
     def _preprocess_transition(
         self,
-        obs: np.ndarray,
+        obs: types.Observation,
         act: np.ndarray,
-        next_obs: Optional[np.ndarray],
+        next_obs: Optional[types.Observation],
     ) -> np.ndarray:
         """Compute flattened transition on subset specified by `self.density_type`."""
+        flattened_obs = flatten(
+            self.venv.observation_space,
+            types.maybe_unwrap_dictobs(obs),
+        )
         if self.density_type == DensityType.STATE_DENSITY:
-            return flatten(self.venv.observation_space, obs)
+            return flattened_obs
         elif self.density_type == DensityType.STATE_ACTION_DENSITY:
             return np.concatenate(
-                [
-                    flatten(self.venv.observation_space, obs),
-                    flatten(self.venv.action_space, act),
-                ],
+                [flattened_obs, flatten(self.venv.action_space, act)],
             )
         elif self.density_type == DensityType.STATE_STATE_DENSITY:
             assert next_obs is not None
-            return np.concatenate(
-                [
-                    flatten(self.venv.observation_space, obs),
-                    flatten(self.venv.observation_space, next_obs),
-                ],
+            flattened_next_obs = flatten(
+                self.venv.observation_space,
+                types.maybe_unwrap_dictobs(next_obs),
             )
+            return np.concatenate([flattened_obs, flattened_next_obs])
         else:
             raise ValueError(f"Unknown density type {self.density_type}")
 
diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py
index 038116044..96c5b7a44 100644
--- a/src/imitation/algorithms/mce_irl.py
+++ b/src/imitation/algorithms/mce_irl.py
@@ -444,8 +444,10 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
             for batch in demonstrations:
                 assert isinstance(batch, Mapping)
                 for k in ("obs", "dones", "next_obs"):
-                    if k in batch:
-                        collated_list[k].append(batch[k])
+                    x = batch.get(k)
+                    if x is not None:
+                        assert isinstance(x, (np.ndarray, th.Tensor))
+                        collated_list[k].append(x)
 
         collated = {k: np.concatenate(v) for k, v in collated_list.items()}
         assert "obs" in collated
diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py
index a010b345e..e22cea825 100644
--- a/src/imitation/data/types.py
+++ b/src/imitation/data/types.py
@@ -17,6 +17,7 @@
     Optional,
     Sequence,
     Tuple,
+    TypedDict,
     TypeVar,
     Union,
     overload,
@@ -288,8 +289,20 @@ def map_maybe_dict(fn, maybe_dict):
         return fn(maybe_dict)
 
 
-# TODO: maybe should support DictObs? 
-TransitionMapping = Mapping[str, AnyTensor] +class TransitionMappingNoNextObs(TypedDict): + """Dictionary with `obs` and `acts`.""" + + obs: Union[Observation, th.Tensor] + acts: AnyTensor + + +# inheritance with total=False so these are not required +class TransitionMapping(TransitionMappingNoNextObs, total=False): + """Dictionary with `obs` and `acts`, maybe also `next_obs`, `dones`, `rew`.""" + + next_obs: AnyTensor + dones: AnyTensor + rew: AnyTensor def dataclass_quick_asdict(obj) -> Dict[str, Any]: diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 802ac0d7f..1654a1482 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -41,7 +41,10 @@ def test_check_fixed_horizon_flag(custom_logger): def _make_and_iterate_loader(*args, **kwargs): - loader = base.make_data_loader(*args, **kwargs) + # our pytype version doesn't understand optional arguments in TypedDict + # this is fixed in 2023.04.11, but we require 2022.7.26 + # See https://github.com/google/pytype/issues/1195 + loader = base.make_data_loader(*args, **kwargs) # pytype: disable=wrong-arg-types for batch in loader: pass diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index a7288e1bf..6e9ca8fd7 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -1,7 +1,7 @@ """Tests for `imitation.algorithms.density_baselines`.""" from dataclasses import asdict -from typing import Sequence +from typing import Sequence, cast import numpy as np import pytest @@ -119,9 +119,7 @@ def test_density_with_other_trajectory_types( ) rollouts = pendulum_expert_trajectories[:2] transitions = rollout.flatten_trajectories_with_rew(rollouts) - transitions_mappings = [ - asdict(transitions), - ] + transitions_mappings = [cast(types.TransitionMapping, asdict(transitions))] minimal_transitions = types.TransitionsMinimal( obs=transitions.obs, From 96978d54cab6c505305cbc02e3a37a7bc6d38eac Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 15:59:27 -0700 Subject: [PATCH 38/85] fix typos Co-authored-by: Adam Gleave --- src/imitation/data/types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index e22cea825..63d1107d4 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -58,7 +58,7 @@ def __post_init__(self): if not all( isinstance(v, (np.ndarray, numbers.Number)) for v in self._d.values() ): - raise TypeError("keys must by numpy arrays") + raise TypeError("Values must be NumPy arrays") def __len__(self): """Returns the first dimension of constituent arrays. @@ -75,7 +75,7 @@ def __len__(self): RuntimeError: if the arrays have different lengths or there are no arrays. 
Returns: - The length (first dimension) of the constiuent arrays + The length (first dimension) of the constituent arrays """ lens = set(len(v) for v in self._d.values()) if len(lens) == 1: @@ -158,7 +158,7 @@ def get(self, key: str) -> np.ndarray: def unwrap(self) -> Dict[str, np.ndarray]: """Returns a copy of the underlying dictionary (arrays are not copied).""" - return {k: v for k, v in self._d.items()} + return dict(self._d) def map_arrays(self, fn: Callable[[np.ndarray], np.ndarray]) -> "DictObs": """Returns a new DictObs with `fn` applied to every array.""" @@ -202,7 +202,7 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": ) -# DicObs utilities +# DictObs utilities Observation = Union[np.ndarray, DictObs] From e95df9dafabeb1b2f07b447e8a09c9866f039774 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:18:17 -0700 Subject: [PATCH 39/85] Adam suggestions from code review Co-authored-by: Adam Gleave --- src/imitation/data/types.py | 2 +- tests/data/test_types.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 63d1107d4..515303792 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -300,7 +300,7 @@ class TransitionMappingNoNextObs(TypedDict): class TransitionMapping(TransitionMappingNoNextObs, total=False): """Dictionary with `obs` and `acts`, maybe also `next_obs`, `dones`, `rew`.""" - next_obs: AnyTensor + next_obs: Union[Observation, th.Tensor] dones: AnyTensor rew: AnyTensor diff --git a/tests/data/test_types.py b/tests/data/test_types.py index d32e06562..58c3b9527 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -490,6 +490,7 @@ def test_dict_obs(): # eq assert abc == types.DictObs({"a": A, "b": B, "c": C}) + assert abc == types.DictObs({"a": np.array(A), "b": np.array(B), "c": np.array(C)}) assert abc != types.DictObs({"a": A, "c": B, "b": C}) # diff keys assert abc != types.DictObs({"a": A, "b": B + 1, "c": C}) # diff values assert abc != {"a": A, "b": B + 1, "c": C} # diff type From 161ec9537183e4494e00c99fb2cd761213388a8d Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:31:24 -0700 Subject: [PATCH 40/85] small changes for code review --- src/imitation/algorithms/bc.py | 6 ++---- src/imitation/data/types.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 75f7c4463..3a5421b5d 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -126,10 +126,8 @@ def __call__( ) acts = util.safe_to_tensor(acts) - # policy.evaluate_actions's type signature seems wrong to me. - # it declares it only takes a tensor but it calls - # extract_features which is happy with Dict[str, tensor]. - # In reality the required type of obs depends on the feature extractor. + # policy.evaluate_actions's type signatures are incorrect. 
+ # See https://github.com/DLR-RM/stable-baselines3/issues/1679 (_, log_prob, entropy) = policy.evaluate_actions( tensor_obs, # type: ignore[arg-type] acts, diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 515303792..ea38a6c27 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -176,8 +176,8 @@ def _unravel(dictobs_list: Iterable["DictObs"]) -> Dict[str, List[np.ndarray]]: raise ValueError(f"Inconsistent keys: {key_set}") unraveled: Dict[str, List[np.ndarray]] = collections.defaultdict(list) - for do in it2: - for k, array in do._d.items(): + for ob_dict in it2: + for k, array in ob_dict._d.items(): unraveled[k].append(array) return unraveled @@ -211,8 +211,7 @@ def concatenate(cls, dictobs_list: Iterable["DictObs"], axis=0) -> "DictObs": def assert_not_dictobs(x: Observation) -> np.ndarray: """Typeguard to assert `x` is an array, not a DictObs.""" - if isinstance(x, DictObs): - assert False, "Dictionary observations are not supported here." + assert not isinstance(x, DictObs), "Dictionary observations are not supported here." return x @@ -282,7 +281,16 @@ def maybe_wrap_in_dictobs( def map_maybe_dict(fn, maybe_dict): - """Applies fn to all values a dictionary, or to the value itself if not a dict.""" + """Either maps fn over the values of maybe_dict (if it is a dict), or applies fn + to `maybe dict` itself (if it's not a dict). + + Args: + fn: function to apply. Must take a single argument. + maybe_dict: either a dict or a value that can be passed to fn. + + Returns: + Either a dict (if maybe_dict was a dict) or `fn(maybe_dict)`. + """ if isinstance(maybe_dict, dict): return {k: fn(v) for k, v in maybe_dict.items()} else: From 90bdf57dcc77c3a879a506a787da15ef9d70d336 Mon Sep 17 00:00:00 2001 From: Nicholas Goldowsky-Dill <8730377+NixGD@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:32:11 -0700 Subject: [PATCH 41/85] fix docstring --- src/imitation/data/types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index ea38a6c27..573176ffe 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -281,8 +281,7 @@ def maybe_wrap_in_dictobs( def map_maybe_dict(fn, maybe_dict): - """Either maps fn over the values of maybe_dict (if it is a dict), or applies fn - to `maybe dict` itself (if it's not a dict). + """Either maps fn over dictionary values or applies fn to `maybe_dict`. Args: fn: function to apply. Must take a single argument. 
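For context on the helpers polished in the last few patches, a hedged sketch of
how stacking, concatenation, and map_maybe_dict behave on DictObs values (keys
and shapes invented for illustration):

import numpy as np

from imitation.data import types

a = types.DictObs({"x": np.zeros((2, 3))})
b = types.DictObs({"x": np.ones((2, 3))})

# stack adds a new leading axis per key
assert types.DictObs.stack([a, b]).shape == {"x": (2, 2, 3)}

# concatenate joins along the existing leading axis per key
assert types.concatenate_maybe_dictobs([a, b]).shape == {"x": (4, 3)}

# map_maybe_dict applies a function per dict value, or directly otherwise
sums = types.map_maybe_dict(np.sum, {"a": np.ones(2), "b": np.ones(3)})
assert sums == {"a": 2.0, "b": 3.0}
assert types.map_maybe_dict(np.sum, np.ones(4)) == 4.0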
From 6aa25ff3091a065c64838945bceaab8aaad918cb Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 11:24:43 -0700 Subject: [PATCH 42/85] remove FloatReward --- tests/algorithms/test_bc.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 2a6038188..2120b0da0 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -291,18 +291,9 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) -# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged -class FloatReward(gym.RewardWrapper): - """Typecasts reward to a float.""" - - def reward(self, reward): - return float(reward) - - def test_dict_space(): def make_env(): env = sb_envs.SimpleMultiObsEnv(channel_last=False) - env = FloatReward(env) return RolloutInfoWrapper(env) env = vec_env.DummyVecEnv([make_env, make_env]) From 4ce1b57090c1c01db3929c5797ceaae5e5683fcf Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 13:05:07 -0700 Subject: [PATCH 43/85] Fix test_bc --- src/imitation/data/rollout.py | 10 ++-------- tests/algorithms/test_bc.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 69251bd7a..b63d1db5d 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -18,7 +18,7 @@ ) import numpy as np -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.utils import check_for_correct_spaces @@ -413,12 +413,6 @@ def generate_trajectories( ValueError: If the observation or action space has no shape or the observations are not a numpy array. """ - if venv.observation_space.shape is None: - raise ValueError("Observation space must have a shape.") - - if venv.action_space.shape is None: - raise ValueError("Action space must have a shape.") - get_actions = policy_to_callable(policy, venv, deterministic_policy) # Collect rollout tuples. 
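The shape check exercised in the hunk below builds one expected shape per key
when the observation space is a gymnasium spaces.Dict. A minimal sketch of that
bookkeeping, with invented spaces and step count:

from gymnasium import spaces

obs_space = spaces.Dict(
    {
        "a": spaces.Box(-1, 1, shape=(4,)),
        "b": spaces.Box(-1, 1, shape=(2, 2)),
    },
)
n_steps = 5
exp_obs = {k: (n_steps + 1,) + v.shape for k, v in obs_space.items()}
assert exp_obs == {"a": (6, 4), "b": (6, 2, 2)}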
@@ -494,7 +488,7 @@ def generate_trajectories( for trajectory in trajectories: n_steps = len(trajectory.acts) # extra 1 for the end - if isinstance(venv.observation_space, spaces.dict.Dict): + if isinstance(venv.observation_space, spaces.Dict): exp_obs = { k: (n_steps + 1,) + v.shape for k, v in venv.observation_space.items() } diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index da8616259..43898b008 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -4,7 +4,7 @@ import os from typing import Any, Callable, Optional, Sequence -import gym +import gymnasium as gym import hypothesis import hypothesis.strategies as st import numpy as np @@ -291,9 +291,18 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) +# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged +class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + def test_dict_space(): def make_env(): env = sb_envs.SimpleMultiObsEnv(channel_last=False) + env = FloatReward(env) return RolloutInfoWrapper(env) env = vec_env.DummyVecEnv([make_env, make_env]) From de1b1c8a93c584589e77021968bda8460af89aa6 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 13:48:55 -0700 Subject: [PATCH 44/85] Turn off GPU finding to avoid using gpu device --- tests/algorithms/test_dagger.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/algorithms/test_dagger.py b/tests/algorithms/test_dagger.py index 92aaf2da2..c4d9b69c6 100644 --- a/tests/algorithms/test_dagger.py +++ b/tests/algorithms/test_dagger.py @@ -20,6 +20,8 @@ from imitation.testing import reward_improvement from imitation.util import util +os.environ["CUDA_VISIBLE_DEVICES"] = "" + @pytest.fixture(params=[True, False]) def maybe_pendulum_expert_trajectories( @@ -114,7 +116,7 @@ def get_random_acts(obs): for info in infos: assert isinstance(info, dict) # roll out 5 * venv.num_envs episodes (Pendulum-v1 has 200 timestep episodes) - for i in range(1000): + for _ in range(1000): _, _, dones, _ = collector.step(zero_acts) num_episodes += np.sum(dones) @@ -163,7 +165,7 @@ def test_traj_collector_reproducible(tmpdir, pendulum_venv): (pendulum_venv.num_envs,) + pendulum_venv.action_space.shape, dtype=pendulum_venv.action_space.dtype, ) - for i in range(1000): + for _ in range(1000): _, _, dones, _ = collector.step(zero_acts) # Get the observations from all the collected trajectories. @@ -377,7 +379,7 @@ def test_trainer_makes_progress(init_trainer_fn, pendulum_venv, pendulum_expert_ assert np.mean(novice_rewards) < -1000 # Train for 5 iterations. (4 or fewer causes test to fail on some configs.) 
# see https://github.com/HumanCompatibleAI/imitation/issues/580 for details - for i in range(5): + for _ in range(5): # roll out a few trajectories for dataset, then train for a few steps collector = trainer.create_trajectory_collector() for _ in range(4): @@ -447,7 +449,7 @@ def test_trainer_reproducible( rng, ) - for i in range(2): + for _ in range(2): collector = trainer.create_trajectory_collector() obs = collector.reset() dones = [False] * pendulum_venv.num_envs From 1a1a45896b4f941e5d51e2435a62f71fc2b77ec7 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 14:43:45 -0700 Subject: [PATCH 45/85] Check None to ensure __add__ can work --- src/imitation/data/rollout.py | 13 ++++++++++--- tests/algorithms/test_mce_irl.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index b63d1db5d..032202814 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -410,8 +410,7 @@ def generate_trajectories( should truncate if required. Raises: - ValueError: If the observation or action space has no shape or the observations - are not a numpy array. + ValueError: If the environment's observation space is not a tuple or spaces.Dict. """ get_actions = policy_to_callable(policy, venv, deterministic_policy) @@ -489,13 +488,21 @@ def generate_trajectories( n_steps = len(trajectory.acts) # extra 1 for the end if isinstance(venv.observation_space, spaces.Dict): + for v in venv.observation_space.values(): + assert v.shape is not None exp_obs = { k: (n_steps + 1,) + v.shape for k, v in venv.observation_space.items() } - else: + elif isinstance(venv.observation_space.shape, tuple): exp_obs = (n_steps + 1,) + venv.observation_space.shape + else: + raise ValueError( + "Observation space has unexpected shape type:" + f"{type(venv.observation_space.shape)}." + ) real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" + assert venv.action_space.shape is not None exp_act = (n_steps,) + venv.action_space.shape real_act = trajectory.acts.shape assert real_act == exp_act, f"expected shape {exp_act}, got {real_act}" diff --git a/tests/algorithms/test_mce_irl.py b/tests/algorithms/test_mce_irl.py index 5f9549ea4..e8e99edbe 100644 --- a/tests/algorithms/test_mce_irl.py +++ b/tests/algorithms/test_mce_irl.py @@ -24,7 +24,7 @@ def rollouts(env, n=10, seed=None): rv = [] - for i in range(n): + for _ in range(n): done = False # if a seed is given, then we use the same seed each time (should # give same trajectory each time) From f7866f4c2e03efdbe1b2fcb9cd46a90abe072ce7 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 14:52:44 -0700 Subject: [PATCH 46/85] fix docstring --- src/imitation/data/rollout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 032202814..62fa908b9 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -410,7 +410,7 @@ def generate_trajectories( should truncate if required. Raises: - ValueError: If the environment's observation space is not a tuple or spaces.Dict. + ValueError: If the environment's observation space is not tuple or spaces.Dict. """ get_actions = policy_to_callable(policy, venv, deterministic_policy) @@ -498,7 +498,7 @@ def generate_trajectories( else: raise ValueError( "Observation space has unexpected shape type:" - f"{type(venv.observation_space.shape)}." 
+                f"{type(venv.observation_space.shape)}",
             )
         real_obs = trajectory.obs.shape
         assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}"

From daa838da1d874216f832918ded3486ced7e102e5 Mon Sep 17 00:00:00 2001
From: ZiyueWang25
Date: Mon, 2 Oct 2023 15:44:22 -0700
Subject: [PATCH 47/85] bypass pytype and lint test

---
 src/imitation/algorithms/density.py | 21 +++++++++++----------
 src/imitation/data/rollout.py | 19 ++++++-------------
 2 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py
index 29a8ed234..c93e74a74 100644
--- a/src/imitation/algorithms/density.py
+++ b/src/imitation/algorithms/density.py
@@ -147,6 +147,7 @@ def _get_demo_from_batch(
         assert act_b.shape[1:] == self.venv.action_space.shape
 
         if isinstance(obs_b, types.DictObs):
+            # type: ignore[attr-defined]
             exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()}
             obs_shape = {k: v.shape[1:] for k, v in obs_b.items()}
             assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}"
@@ -270,12 +271,12 @@ def _preprocess_transition(
             self.venv.observation_space,
             types.maybe_unwrap_dictobs(obs),
         )
-        _check_data_is_np_array(flattened_obs, "observation")
+        flattened_obs = _check_data_is_np_array(flattened_obs, "observation")
         if self.density_type == DensityType.STATE_DENSITY:
             return flattened_obs
         elif self.density_type == DensityType.STATE_ACTION_DENSITY:
             flattened_action = space_utils.flatten(self.venv.action_space, act)
-            _check_data_is_np_array(flattened_action, "action")
+            flattened_action = _check_data_is_np_array(flattened_action, "action")
             return np.concatenate([flattened_obs, flattened_action])
         elif self.density_type == DensityType.STATE_STATE_DENSITY:
             assert next_obs is not None
@@ -283,7 +284,7 @@ def _preprocess_transition(
             self.venv.observation_space,
             types.maybe_unwrap_dictobs(next_obs),
         )
-        _check_data_is_np_array(flat_next_obs, "observation")
+        flat_next_obs = _check_data_is_np_array(flat_next_obs, "observation")
         assert type(flattened_obs) is type(flat_next_obs)
 
         return np.concatenate([flattened_obs, flat_next_obs])
@@ -409,11 +410,11 @@ def policy(self) -> base_class.BasePolicy:
         return self.rl_algo.policy
 
 
-def _check_data_is_np_array(data: space_utils.FlatType, name: str) -> None:
+def _check_data_is_np_array(data: space_utils.FlatType, name: str) -> np.ndarray:
     """Raises error if the flattened data is not a numpy array."""
-    if not isinstance(data, np.ndarray):
-        raise ValueError(
-            "The density estimator only supports spaces that "
-            f"flatten to a numpy array but the {name} space "
-            f"flattens to {type(data)}",
-        )
+    assert isinstance(data, np.ndarray), (
+        "The density estimator only supports spaces that "
+        f"flatten to a numpy array but the {name} space "
+        f"flattens to {type(data)}"
+    )
+    return data
diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py
index 62fa908b9..222f14cb8 100644
--- a/src/imitation/data/rollout.py
+++ b/src/imitation/data/rollout.py
@@ -408,9 +408,6 @@ def generate_trajectories(
         Sequence of trajectories, satisfying `sample_until`. Additional trajectories
         may be collected to avoid biasing process towards short episodes; the user
         should truncate if required.
-
-    Raises:
-        ValueError: If the environment's observation space is not tuple or spaces.Dict.
""" get_actions = policy_to_callable(policy, venv, deterministic_policy) @@ -488,18 +485,14 @@ def generate_trajectories( n_steps = len(trajectory.acts) # extra 1 for the end if isinstance(venv.observation_space, spaces.Dict): - for v in venv.observation_space.values(): + exp_obs = {} + for k, v in venv.observation_space.items(): assert v.shape is not None - exp_obs = { - k: (n_steps + 1,) + v.shape for k, v in venv.observation_space.items() - } - elif isinstance(venv.observation_space.shape, tuple): - exp_obs = (n_steps + 1,) + venv.observation_space.shape + exp_obs[k] = (n_steps + 1,) + v.shape else: - raise ValueError( - "Observation space has unexpected shape type:" - f"{type(venv.observation_space.shape)}", - ) + assert venv.observation_space.shape is not None + # type: ignore[assignment] + exp_obs = (n_steps + 1,) + venv.observation_space.shape real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" assert venv.action_space.shape is not None From 803eab0e929a6be1a552d6d30fba6904983939c5 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 16:10:52 -0700 Subject: [PATCH 48/85] format with black --- src/imitation/algorithms/density.py | 7 +++++-- src/imitation/data/rollout.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index c93e74a74..4f1cdbdca 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -147,8 +147,11 @@ def _get_demo_from_batch( assert act_b.shape[1:] == self.venv.action_space.shape if isinstance(obs_b, types.DictObs): - # type: ignore[attr-defined] - exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} + exp_shape = { + k: v.shape + # type: ignore[attr-defined] + for k, v in self.venv.observation_space.items() + } obs_shape = {k: v.shape[1:] for k, v in obs_b.items()} assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}" else: diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 222f14cb8..d7e1a70f8 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -491,8 +491,10 @@ def generate_trajectories( exp_obs[k] = (n_steps + 1,) + v.shape else: assert venv.observation_space.shape is not None - # type: ignore[assignment] - exp_obs = (n_steps + 1,) + venv.observation_space.shape + exp_obs = ( + (n_steps + 1,) + + venv.observation_space.shape, # type: ignore[assignment] + ) real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" assert venv.action_space.shape is not None From 0ac6f548943f66352198a63d3bbaff04ae87ffc6 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 16:13:05 -0700 Subject: [PATCH 49/85] Test dict space in density algo --- tests/algorithms/test_density_baselines.py | 62 +++++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 6e9ca8fd7..16048b651 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -3,14 +3,17 @@ from dataclasses import asdict from typing import Sequence, cast +import gymnasium as gym import numpy as np import pytest import stable_baselines3 -from stable_baselines3.common import policies +from stable_baselines3.common import envs as sb_envs +from stable_baselines3.common import policies, vec_env from 
imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types from imitation.data.types import TrajectoryWithRew +from imitation.data.wrappers import RolloutInfoWrapper from imitation.policies.base import RandomPolicy from imitation.testing import reward_improvement @@ -76,7 +79,7 @@ def test_density_reward( sample_until=sample_until, rng=rng, ) - expert_trajectories_test = pendulum_expert_trajectories[n_experts // 2 :] + expert_trajectories_test = pendulum_expert_trajectories[n_experts // 2:] random_returns = score_trajectories(random_trajectories, reward_fn) expert_returns = score_trajectories(expert_trajectories_test, reward_fn) assert reward_improvement.is_significant_reward_improvement( @@ -120,12 +123,6 @@ def test_density_with_other_trajectory_types( rollouts = pendulum_expert_trajectories[:2] transitions = rollout.flatten_trajectories_with_rew(rollouts) transitions_mappings = [cast(types.TransitionMapping, asdict(transitions))] - - minimal_transitions = types.TransitionsMinimal( - obs=transitions.obs, - acts=transitions.acts, - infos=transitions.infos, - ) d = DensityAlgorithm( demonstrations=transitions_mappings, venv=pendulum_venv, @@ -136,6 +133,11 @@ def test_density_with_other_trajectory_types( d.train_policy(n_timesteps=2) d.test_policy(n_trajectories=2) + minimal_transitions = types.TransitionsMinimal( + obs=transitions.obs, + acts=transitions.acts, + infos=transitions.infos, + ) d = DensityAlgorithm( demonstrations=minimal_transitions, venv=pendulum_venv, @@ -168,3 +170,47 @@ def test_density_trainer_raises( with pytest.raises(TypeError, match="Unsupported demonstration type"): density_trainer.set_demonstrations("foo") # type: ignore[arg-type] + + +# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged +class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + +@parametrize_density_stationary +def test_dict_space(density_type, is_stationary): + def make_env(): + env = sb_envs.SimpleMultiObsEnv(channel_last=False) + env = FloatReward(env) + return RolloutInfoWrapper(env) + + venv = vec_env.DummyVecEnv([make_env, make_env]) + + # multi-input policy to accept dict observations + rl_algo = stable_baselines3.PPO(policies.MultiInputActorCriticPolicy, venv) + rng = np.random.default_rng() + + # sample random transitions + rollouts = rollout.rollout( + policy=None, + venv=venv, + sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, + unwrap=True, + ) + density_trainer = DensityAlgorithm( + demonstrations=rollouts, + density_type=density_type, + kernel="gaussian", + venv=venv, + is_stationary=is_stationary, + rl_algo=rl_algo, + kernel_bandwidth=0.2, + standardise_inputs=True, + rng=rng, + ) + # confirm that training works + density_trainer.train() From be9798b99fe683a5c0dfb847c2d8ea1a92f42cfa Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 16:13:35 -0700 Subject: [PATCH 50/85] black format --- tests/algorithms/test_density_baselines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 16048b651..ec417722a 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -79,7 +79,7 @@ def test_density_reward( sample_until=sample_until, rng=rng, ) - expert_trajectories_test = pendulum_expert_trajectories[n_experts // 2:] + 
expert_trajectories_test = pendulum_expert_trajectories[n_experts // 2 :] random_returns = score_trajectories(random_trajectories, reward_fn) expert_returns = score_trajectories(expert_trajectories_test, reward_fn) assert reward_improvement.is_significant_reward_improvement( From c7e680927905495cf0346f1080a0d8fe73d2fe23 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 16:35:22 -0700 Subject: [PATCH 51/85] small fix --- src/imitation/data/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index d7e1a70f8..ff3ccd759 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -493,7 +493,7 @@ def generate_trajectories( assert venv.observation_space.shape is not None exp_obs = ( (n_steps + 1,) - + venv.observation_space.shape, # type: ignore[assignment] + + venv.observation_space.shape # type: ignore[assignment] ) real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" From 82fb558f764fc38de17ef1480b19ca73e6088c05 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 21:27:40 -0700 Subject: [PATCH 52/85] Add DictObs into test_wrappers --- src/imitation/data/wrappers.py | 2 + src/imitation/rewards/reward_wrapper.py | 11 +++- tests/data/test_wrappers.py | 77 +++++++++++++++++++------ 3 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 5cd72c984..94c88111d 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -53,9 +53,11 @@ def reset(self, **kwargs): self.n_transitions = 0 obs = self.venv.reset(**kwargs) self._traj_accum = rollout.TrajectoryAccumulator() + obs = types.maybe_wrap_in_dictobs(obs) for i, ob in enumerate(obs): self._traj_accum.add_step({"obs": ob}, key=i) self._timesteps = np.zeros((len(obs),), dtype=int) + obs = types.maybe_unwrap_dictobs(obs) return obs def step_async(self, actions): diff --git a/src/imitation/rewards/reward_wrapper.py b/src/imitation/rewards/reward_wrapper.py index 7afa551b3..37415e916 100644 --- a/src/imitation/rewards/reward_wrapper.py +++ b/src/imitation/rewards/reward_wrapper.py @@ -8,6 +8,7 @@ from stable_baselines3.common import logger as sb_logger from stable_baselines3.common import vec_env +from imitation.data import types from imitation.rewards import reward_function @@ -95,14 +96,17 @@ def step_wait(self): # encounter a `done`, in which case the last observation corresponding to # the `done` is dropped. We're going to pull it back out of the info dict! 
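+        # With a Dict observation space, `obs` arrives here as a mapping of
+        # arrays; wrapping it in DictObs below lets us iterate per-environment
+        # entries and re-stack them uniformly afterwards.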
obs_fixed = []
+        obs = types.maybe_wrap_in_dictobs(obs)
         for single_obs, single_done, single_infos in zip(obs, dones, infos):
             if single_done:
                 single_obs = single_infos["terminal_observation"]

-            obs_fixed.append(single_obs)
-        obs_fixed = np.stack(obs_fixed)
+            obs_fixed.append(types.maybe_wrap_in_dictobs(single_obs))
+        obs_fixed = types.DictObs.stack(obs_fixed) if isinstance(
+            obs, types.DictObs) else np.stack(obs_fixed)

-        rews = self.reward_fn(self._old_obs, self._actions, obs_fixed, np.array(dones))
+        rews = self.reward_fn(self._old_obs, self._actions,
+                              types.maybe_unwrap_dictobs(obs_fixed), np.array(dones))
         assert len(rews) == len(obs), "must return one rew for each env"

         done_mask = np.asarray(dones, dtype="bool").reshape((len(dones),))
@@ -116,6 +120,7 @@ def step_wait(self):
         # we can just use obs instead of obs_fixed because on the next iteration
         # after a reset we DO want to access the first observation of the new
         # trajectory, not the last observation of the old trajectory
+        obs = types.maybe_unwrap_dictobs(obs)
         self._old_obs = obs
         for info_dict, old_rew in zip(infos, old_rews):
             info_dict["original_env_rew"] = old_rew
diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py
index b83339167..d51f81b69 100644
--- a/tests/data/test_wrappers.py
+++ b/tests/data/test_wrappers.py
@@ -1,6 +1,6 @@
 """Tests for `imitation.data.wrappers`."""

-from typing import List, Sequence
+from typing import List, Sequence, Type

 import gymnasium as gym
 import numpy as np
@@ -48,10 +48,40 @@ def step(self, action):
         return t, t * 10, done, False, {}


+class _CountingDictEnv(_CountingEnv):  # pragma: no cover
+    """Similar to _CountingEnv, but with Dict observation."""
+
+    def __init__(self, episode_length=5):
+        super().__init__(episode_length)
+        self.observation_space = gym.spaces.Dict(
+            spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}
+        )
+
+    def reset(self, seed=None):
+        t, self.timestep = 0.0, 1.0
+        return {"t": t}, {}
+
+    def step(self, action):
+        if self.timestep is None:
+            raise RuntimeError("Need to reset before first step().")
+        if self.timestep > self.episode_length:
+            raise RuntimeError("Episode is over. Need to reset().")
+        if np.array(action) not in self.action_space:
+            raise ValueError(f"Invalid action {action}")
+
+        t, self.timestep = self.timestep, self.timestep + 1
+        done = t == self.episode_length
+        return {"t": t}, t * 10, done, False, {}
+
+
+Envs = [_CountingEnv, _CountingDictEnv]
+
+
 def _make_buffering_venv(
+    Env: Type[gym.Env],
     error_on_premature_reset: bool,
 ) -> BufferingWrapper:
-    venv = DummyVecEnv([_CountingEnv] * 2)
+    venv = DummyVecEnv([Env] * 2)
     wrapped_venv = BufferingWrapper(venv, error_on_premature_reset)
     wrapped_venv.reset()
     return wrapped_venv
@@ -86,10 +116,12 @@ def concat(x):
     )


+@pytest.mark.parametrize("Env", Envs)
 @pytest.mark.parametrize("episode_lengths", [(1,), (6, 5, 1, 2), (2, 2)])
 @pytest.mark.parametrize("n_steps", [1, 2, 20, 21])
 @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)])
 def test_pop(
+    Env: Type[gym.Env],
     episode_lengths: Sequence[int],
     n_steps: int,
     extra_pop_timesteps: Sequence[int],
@@ -141,7 +173,7 @@ def test_pop(
         pytest.skip("pop timesteps out of bounds for this test case")

     def make_env(ep_len):
-        return lambda: _CountingEnv(episode_length=ep_len)
+        return lambda: Env(episode_length=ep_len)

     venv = DummyVecEnv([make_env(ep_len) for ep_len in episode_lengths])
     venv_buffer = BufferingWrapper(venv)
@@ -153,10 +185,13 @@ def make_env(ep_len):

     # Initial observation (only matters for pop_transitions()).
obs = venv_buffer.reset() - np.testing.assert_array_equal(obs, [0] * venv.num_envs) + if Env == _CountingEnv: + np.testing.assert_array_equal(obs, [0] * venv.num_envs) + else: + np.testing.assert_array_equal(obs["t"], [0] * venv.num_envs) for t in range(1, n_steps + 1): - acts = obs * 2.1 + acts = obs * 2.1 if Env == _CountingEnv else obs["t"] * 2.1 venv_buffer.step_async(acts) obs, *_ = venv_buffer.step_wait() @@ -179,32 +214,35 @@ def make_env(ep_len): # Check `pop_transitions()` trans = _join_transitions(transitions_list) - - _assert_equal_scrambled_vectors(types.assert_not_dictobs(trans.obs), expect_obs) - _assert_equal_scrambled_vectors( - types.assert_not_dictobs(trans.next_obs), - expect_next_obs, - ) + if Env == _CountingEnv: + actual_obs = types.assert_not_dictobs(trans.obs) + actual_next_obs = types.assert_not_dictobs(trans.next_obs) + else: + actual_obs = types.DictObs.stack(trans.obs).get("t") + actual_next_obs = types.DictObs.stack(trans.next_obs).get("t") + _assert_equal_scrambled_vectors(actual_obs, expect_obs) + _assert_equal_scrambled_vectors(actual_next_obs, expect_next_obs) _assert_equal_scrambled_vectors(trans.acts, expect_acts) _assert_equal_scrambled_vectors(trans.rews, expect_rews) -def test_reset_error(): +@pytest.mark.parametrize("Env", Envs) +def test_reset_error(Env: Type[gym.Env]): # Resetting before a `step()` is okay. for flag in [True, False]: - venv = _make_buffering_venv(flag) + venv = _make_buffering_venv(Env, flag) for _ in range(10): venv.reset() # Resetting after a `step()` is not okay if error flag is True. - venv = _make_buffering_venv(True) + venv = _make_buffering_venv(Env, True) zeros = np.array([0.0, 0.0], dtype=venv.action_space.dtype) venv.step(zeros) with pytest.raises(RuntimeError, match="before samples were accessed"): venv.reset() # Same as previous case, but insert a `pop_transitions()` in between. - venv = _make_buffering_venv(True) + venv = _make_buffering_venv(Env, True) venv.step(zeros) venv.pop_transitions() venv.step(zeros) @@ -212,20 +250,21 @@ def test_reset_error(): venv.reset() # Resetting after a `step()` is ok if error flag is False. - venv = _make_buffering_venv(False) + venv = _make_buffering_venv(Env, False) venv.step(zeros) venv.reset() # Resetting after a `step()` is ok if transitions are first collected. 
for flag in [True, False]: - venv = _make_buffering_venv(flag) + venv = _make_buffering_venv(Env, flag) venv.step(zeros) venv.pop_transitions() venv.reset() -def test_n_transitions_and_empty_error(): - venv = _make_buffering_venv(True) +@pytest.mark.parametrize("Env", Envs) +def test_n_transitions_and_empty_error(Env: Type[gym.Env]): + venv = _make_buffering_venv(Env, True) trajs, ep_lens = venv.pop_trajectories() assert trajs == [] assert ep_lens == [] From 03714cc33e9d9304bfd301ab1639eece9971bc98 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 22:46:34 -0700 Subject: [PATCH 53/85] fix format --- src/imitation/data/rollout.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index ff3ccd759..9d989a830 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -492,9 +492,8 @@ def generate_trajectories( else: assert venv.observation_space.shape is not None exp_obs = ( - (n_steps + 1,) - + venv.observation_space.shape # type: ignore[assignment] - ) + n_steps + 1, + ) + venv.observation_space.shape # type: ignore[assignment] real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" assert venv.action_space.shape is not None From 187e88166c9aa0478db134e945570b417393337f Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 22:58:59 -0700 Subject: [PATCH 54/85] minor fix --- src/imitation/algorithms/density.py | 6 ++++-- tests/algorithms/test_bc.py | 3 ++- tests/algorithms/test_dagger.py | 2 -- tests/algorithms/test_density_baselines.py | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index 4f1cdbdca..3c00f68f7 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -296,9 +296,9 @@ def _preprocess_transition( def __call__( self, - state: np.ndarray, + state: types.Observation, action: np.ndarray, - next_state: np.ndarray, + next_state: types.Observation, done: np.ndarray, steps: Optional[np.ndarray] = None, ) -> np.ndarray: @@ -334,6 +334,8 @@ def __call__( rew_list = [] assert len(state) == len(action) and len(state) == len(next_state) + state = types.maybe_wrap_in_dictobs(state) + next_state = types.maybe_wrap_in_dictobs(next_state) for idx, (obs, act, next_obs) in enumerate(zip(state, action, next_state)): flat_trans = self._preprocess_transition(obs, act, next_obs) assert self._scaler is not None diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 43898b008..ed9c5a347 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -291,7 +291,8 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) -# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged +# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 +# merged and released. 
class FloatReward(gym.RewardWrapper): """Typecasts reward to a float.""" diff --git a/tests/algorithms/test_dagger.py b/tests/algorithms/test_dagger.py index c4d9b69c6..01c1c5088 100644 --- a/tests/algorithms/test_dagger.py +++ b/tests/algorithms/test_dagger.py @@ -20,8 +20,6 @@ from imitation.testing import reward_improvement from imitation.util import util -os.environ["CUDA_VISIBLE_DEVICES"] = "" - @pytest.fixture(params=[True, False]) def maybe_pendulum_expert_trajectories( diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index ec417722a..a35755965 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -172,7 +172,8 @@ def test_density_trainer_raises( density_trainer.set_demonstrations("foo") # type: ignore[arg-type] -# TODO: remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 merged +# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 +# merged and released. class FloatReward(gym.RewardWrapper): """Typecasts reward to a float.""" From ae965210f940e1195c16bc2d1e426f78d9b0b5c0 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Mon, 2 Oct 2023 23:29:33 -0700 Subject: [PATCH 55/85] type and lint fix --- src/imitation/algorithms/density.py | 6 +----- tests/algorithms/test_density_baselines.py | 7 ++++++- tests/data/test_wrappers.py | 3 ++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index 3c00f68f7..9ee2fb495 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -147,11 +147,7 @@ def _get_demo_from_batch( assert act_b.shape[1:] == self.venv.action_space.shape if isinstance(obs_b, types.DictObs): - exp_shape = { - k: v.shape - # type: ignore[attr-defined] - for k, v in self.venv.observation_space.items() - } + exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} # type: ignore[attr-defined] obs_shape = {k: v.shape[1:] for k, v in obs_b.items()} assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}" else: diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index a35755965..3eed1f040 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -191,7 +191,12 @@ def make_env(): venv = vec_env.DummyVecEnv([make_env, make_env]) # multi-input policy to accept dict observations - rl_algo = stable_baselines3.PPO(policies.MultiInputActorCriticPolicy, venv) + rl_algo = stable_baselines3.PPO( + policies.MultiInputActorCriticPolicy, + venv, + n_steps=10, # small value to make test faster + n_epochs=2, # small value to make test faster + ) rng = np.random.default_rng() # sample random transitions diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index d51f81b69..33677c68f 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -54,7 +54,7 @@ class _CountingDictEnv(_CountingEnv): # pragma: no cover def __init__(self, episode_length=5): super().__init__(episode_length) self.observation_space = gym.spaces.Dict( - spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())} + spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, ) def reset(self, seed=None): @@ -151,6 +151,7 @@ def test_pop( ``` Args: + Env: Environment class type. episode_lengths: The number of timesteps before episode end in each dummy environment. 
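+# (Context: the envs used in the dict-obs tests in this file may return
+# integer rewards; casting to float keeps reward handling consistent until
+# the upstream fix is released.)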
n_steps: Number of times to call `step()` on the dummy environment. From 535a986be1ee578737e20d76b6f43a4b68563ab6 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 00:22:01 -0700 Subject: [PATCH 56/85] Add policy training test --- src/imitation/rewards/reward_wrapper.py | 16 +++++++++++----- tests/algorithms/test_density_baselines.py | 13 +++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/imitation/rewards/reward_wrapper.py b/src/imitation/rewards/reward_wrapper.py index 37415e916..b7db34a5c 100644 --- a/src/imitation/rewards/reward_wrapper.py +++ b/src/imitation/rewards/reward_wrapper.py @@ -102,11 +102,17 @@ def step_wait(self): single_obs = single_infos["terminal_observation"] obs_fixed.append(types.maybe_wrap_in_dictobs(single_obs)) - obs_fixed = types.DictObs.stack(obs_fixed) if isinstance( - obs, types.DictObs) else np.stack(obs_fixed) - - rews = self.reward_fn(self._old_obs, self._actions, - types.maybe_unwrap_dictobs(obs_fixed), np.array(dones)) + obs_fixed = ( + types.DictObs.stack(obs_fixed) + if isinstance(obs, types.DictObs) + else np.stack(obs_fixed) + ) + rews = self.reward_fn( + self._old_obs, + self._actions, + types.maybe_unwrap_dictobs(obs_fixed), + np.array(dones), + ) assert len(rews) == len(obs), "must return one rew for each env" done_mask = np.asarray(dones, dtype="bool").reshape((len(dones),)) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 3eed1f040..b72970a79 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -181,8 +181,7 @@ def reward(self, reward): return float(reward) -@parametrize_density_stationary -def test_dict_space(density_type, is_stationary): +def test_dict_space(): def make_env(): env = sb_envs.SimpleMultiObsEnv(channel_last=False) env = FloatReward(env) @@ -200,23 +199,25 @@ def make_env(): rng = np.random.default_rng() # sample random transitions + sample_until = rollout.make_min_episodes(15) rollouts = rollout.rollout( policy=None, venv=venv, - sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), + sample_until=sample_until, rng=rng, - unwrap=True, ) density_trainer = DensityAlgorithm( demonstrations=rollouts, - density_type=density_type, kernel="gaussian", venv=venv, - is_stationary=is_stationary, rl_algo=rl_algo, kernel_bandwidth=0.2, standardise_inputs=True, rng=rng, + # SimpleMultiObsEnv has early stop (issue #40) + allow_variable_horizon=True, ) # confirm that training works density_trainer.train() + density_trainer.train_policy(n_timesteps=2) + density_trainer.test_policy(n_trajectories=2) From de027c4b7b596938bdad7999c530c5be9b008446 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 07:36:25 -0700 Subject: [PATCH 57/85] suppress line too long lint check on a line --- src/imitation/algorithms/density.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index 9ee2fb495..a8bfceb7b 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -147,7 +147,8 @@ def _get_demo_from_batch( assert act_b.shape[1:] == self.venv.action_space.shape if isinstance(obs_b, types.DictObs): - exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} # type: ignore[attr-defined] + exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} # type: ignore[attr-defined] # noqa: E501 + obs_shape = {k: v.shape[1:] for 
k, v in obs_b.items()}
         assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}"
     else:

From be79cf5614548e1e780dc06fa1648732bb44fb9d Mon Sep 17 00:00:00 2001
From: ZiyueWang25
Date: Tue, 3 Oct 2023 10:31:50 -0700
Subject: [PATCH 58/85] acts to obs for clarity

---
 src/imitation/algorithms/dagger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py
index fb68713e6..67bb382ca 100644
--- a/src/imitation/algorithms/dagger.py
+++ b/src/imitation/algorithms/dagger.py
@@ -508,7 +508,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector:
         beta = self.beta_schedule(self.round_num)
         collector = InteractiveTrajectoryCollector(
             venv=self.venv,
-            get_robot_acts=lambda acts: self.bc_trainer.policy.predict(acts)[0],
+            get_robot_acts=lambda obs: self.bc_trainer.policy.predict(obs)[0],
             beta=beta,
             save_dir=save_dir,
             rng=self.rng,

From 6e5c3e8d695c0c29b721672dbaf3dabae628b23f Mon Sep 17 00:00:00 2001
From: ZiyueWang25
Date: Tue, 3 Oct 2023 12:18:06 -0700
Subject: [PATCH 59/85] Add HumanReadableWrapper

---
 src/imitation/data/wrappers.py | 61 ++++++++++++++++++++++++++++++++--
 tests/data/test_wrappers.py    | 43 ++++++++++++++++++++----
 2 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py
index 94c88111d..83e4dd0cc 100644
--- a/src/imitation/data/wrappers.py
+++ b/src/imitation/data/wrappers.py
@@ -1,14 +1,18 @@
 """Environment wrappers for collecting rollouts."""

-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Dict, Union

 import gymnasium as gym
+from gymnasium.core import Env
 import numpy as np
 import numpy.typing as npt
 from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper

 from imitation.data import rollout, types

+# The key for human readable data in the observation.
+HR_OBS_KEY = "HR_OBS"
+

 class BufferingWrapper(VecEnvWrapper):
     """Saves transitions of underlying VecEnv.
@@ -170,7 +174,7 @@ def pop_transitions(self) -> types.TransitionsWithRew:


 class RolloutInfoWrapper(gym.Wrapper):
-    """Add the entire episode's rewards and observations to `info` at episode end.
+    """Adds the entire episode's rewards and observations to `info` at episode end.

     Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose
     corresponding values hold the NumPy arrays containing the raw observations and
@@ -206,3 +210,56 @@ def step(self, action):
                 "rews": np.stack(self._rews),
             }
         return obs, rew, terminated, truncated, info
+
+
+class HumanReadableWrapper(gym.Wrapper):
+    """Adds human-readable observation to `obs` at every step."""
+
+    def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
+        """Builds HumanReadableWrapper.
+
+        Args:
+            env: Environment to wrap.
+            original_obs_key: The key for original observation if the original
+                observation is not in dict format.
+        """
+        env.render_mode = "rgb_array"
+        self._original_obs_key = original_obs_key
+        super().__init__(env)
+
+    def _add_hr_obs(
+        self, obs: Union[np.ndarray, Dict[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Adds human-readable observation to obs.
+
+        Transforms obs into a dictionary if it is not one already, and adds the
+        human-readable observation from `env.render()` under the key HR_OBS_KEY.
+
+        Args:
+            obs: Observation from environment.
+ + Returns: + Observation dictionary with the human-readable data + + Raises: + KeyError: When the key HR_OBS_KEY already exists in the observation + dictionary. + """ + assert self.env.render_mode is not None + assert self.env.render_mode == "rgb_array" + hr_obs = self.env.render() + if not isinstance(obs, Dict): + obs = {self._original_obs_key: obs} + + if HR_OBS_KEY in obs: + raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict") + obs[HR_OBS_KEY] = hr_obs + return obs + + def reset(self, **kwargs): + obs, info = super().reset(**kwargs) + return self._add_hr_obs(obs), info + + def step(self, action): + obs, rew, terminated, truncated, info = self.env.step(action) + return self._add_hr_obs(obs), rew, terminated, truncated, info diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 33677c68f..cfde9dbcc 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,6 +1,6 @@ """Tests for `imitation.data.wrappers`.""" -from typing import List, Sequence, Type +from typing import List, Sequence, Type, Dict, Optional import gymnasium as gym import numpy as np @@ -8,7 +8,11 @@ from stable_baselines3.common.vec_env import DummyVecEnv from imitation.data import types -from imitation.data.wrappers import BufferingWrapper +from imitation.data.wrappers import ( + BufferingWrapper, + HumanReadableWrapper, + HR_OBS_KEY, +) class _CountingEnv(gym.Env): # pragma: no cover @@ -31,7 +35,7 @@ def __init__(self, episode_length=5): self.episode_length = episode_length self.timestep = None - def reset(self, seed=None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): t, self.timestep = 0, 1 return t, {} @@ -47,6 +51,9 @@ def step(self, action): done = t == self.episode_length return t, t * 10, done, False, {} + def render(self) -> np.ndarray: + return np.array([self.timestep] * 10) + class _CountingDictEnv(_CountingEnv): # pragma: no cover """Similar to _CountingEnv, but with Dict observation.""" @@ -57,9 +64,9 @@ def __init__(self, episode_length=5): spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, ) - def reset(self, seed=None): - t, self.timestep = 0.0, 1.0 - return {"t": t}, {} + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + t, self.timestep = 0, 1 + return {"t": t, "2t": 2 * t}, {} def step(self, action): if self.timestep is None: @@ -71,7 +78,7 @@ def step(self, action): t, self.timestep = self.timestep, self.timestep + 1 done = t == self.episode_length - return {"t": t}, t * 10, done, False, {} + return {"t": t, "2t": 2 * t}, t * 10, done, False, {} Envs = [_CountingEnv, _CountingDictEnv] @@ -278,3 +285,25 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): assert venv.n_transitions == 0 with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() + + +@pytest.mark.parametrize("Env", Envs) +@pytest.mark.parametrize("original_obs_key", ["k1", "k2"]) +def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str): + num_obs_key_expected = 2 if Env == _CountingEnv else 3 + origin_obs_key = original_obs_key if Env == _CountingEnv else "t" + env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key) + + obs, _ = env.reset() + assert isinstance(obs, Dict) + assert HR_OBS_KEY in obs + assert len(obs) == num_obs_key_expected + assert obs[origin_obs_key] == 0 + _assert_equal_scrambled_vectors(obs[HR_OBS_KEY], np.array([1] * 10)) + + next_obs, *_ = env.step(env.action_space.sample()) + assert isinstance(next_obs, Dict) + assert 
HR_OBS_KEY in next_obs + assert len(next_obs) == num_obs_key_expected + assert next_obs[origin_obs_key] == 1 + _assert_equal_scrambled_vectors(next_obs[HR_OBS_KEY], np.array([2] * 10)) From ba6a6a7eb5191bac4c5f70f12313185b15e3c350 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 13:08:01 -0700 Subject: [PATCH 60/85] fix dict env observation space --- tests/data/test_wrappers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index cfde9dbcc..279c72079 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -61,7 +61,10 @@ class _CountingDictEnv(_CountingEnv): # pragma: no cover def __init__(self, episode_length=5): super().__init__(episode_length) self.observation_space = gym.spaces.Dict( - spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, + spaces={ + "t": gym.spaces.Box(low=0, high=np.inf, shape=()), + "2t": gym.spaces.Box(low=0, high=np.inf, shape=()), + }, ) def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): From a9b32bdc3fdd8be9f0b50411cfcd1107316983c8 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 14:43:12 -0700 Subject: [PATCH 61/85] adjust wrapper and not set render_mode inside --- src/imitation/data/wrappers.py | 12 +++++++++--- tests/data/test_wrappers.py | 29 +++++++++++++++++++---------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 83e4dd0cc..beeb75a29 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -222,8 +222,16 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): env: Environment to wrap. original_obs_key: The key for original observation if the original observation is not in dict format. + + Raises: + ValueError: If `env.render_mode` is not "rgb_array". + """ - env.render_mode = "rgb_array" + if env.render_mode != "rgb_array": + raise ValueError( + "HumanReadableWrapper requires render_mode='rgb_array', " + f"got {env.render_mode!r}" + ) self._original_obs_key = original_obs_key super().__init__(env) @@ -245,8 +253,6 @@ def _add_hr_obs( KeyError: When the key HR_OBS_KEY already exists in the observation dictionary. 
""" - assert self.env.render_mode is not None - assert self.env.render_mode == "rgb_array" hr_obs = self.env.render() if not isinstance(obs, Dict): obs = {self._original_obs_key: obs} diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 279c72079..3b2f4a8ac 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -28,12 +28,17 @@ class _CountingEnv(gym.Env): # pragma: no cover ``` """ - def __init__(self, episode_length=5): + def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None): assert episode_length >= 1 self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=()) self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=()) self.episode_length = episode_length - self.timestep = None + self.timestep: Optional[int] = None + self._render_mode = render_mode + + @property + def render_mode(self): + return self._render_mode def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): t, self.timestep = 0, 1 @@ -52,14 +57,16 @@ def step(self, action): return t, t * 10, done, False, {} def render(self) -> np.ndarray: + if self._render_mode != "rgb_array": + raise ValueError(f"Invalid render mode {self._render_mode}") return np.array([self.timestep] * 10) class _CountingDictEnv(_CountingEnv): # pragma: no cover """Similar to _CountingEnv, but with Dict observation.""" - def __init__(self, episode_length=5): - super().__init__(episode_length) + def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None): + super().__init__(episode_length, render_mode) self.observation_space = gym.spaces.Dict( spaces={ "t": gym.spaces.Box(low=0, high=np.inf, shape=()), @@ -88,7 +95,7 @@ def step(self, action): def _make_buffering_venv( - Env: Type[gym.Env], + Env: _CountingEnv, error_on_premature_reset: bool, ) -> BufferingWrapper: venv = DummyVecEnv([Env] * 2) @@ -131,7 +138,7 @@ def concat(x): @pytest.mark.parametrize("n_steps", [1, 2, 20, 21]) @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)]) def test_pop( - Env: Type[gym.Env], + Env: _CountingEnv, episode_lengths: Sequence[int], n_steps: int, extra_pop_timesteps: Sequence[int], @@ -238,7 +245,7 @@ def make_env(ep_len): @pytest.mark.parametrize("Env", Envs) -def test_reset_error(Env: Type[gym.Env]): +def test_reset_error(Env: _CountingEnv): # Resetting before a `step()` is okay. 
for flag in [True, False]: venv = _make_buffering_venv(Env, flag) @@ -274,7 +281,7 @@ def test_reset_error(Env: Type[gym.Env]): @pytest.mark.parametrize("Env", Envs) -def test_n_transitions_and_empty_error(Env: Type[gym.Env]): +def test_n_transitions_and_empty_error(Env: _CountingEnv): venv = _make_buffering_venv(Env, True) trajs, ep_lens = venv.pop_trajectories() assert trajs == [] @@ -292,10 +299,12 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): @pytest.mark.parametrize("Env", Envs) @pytest.mark.parametrize("original_obs_key", ["k1", "k2"]) -def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str): +def test_human_readable_wrapper(Env: _CountingEnv, original_obs_key: str): num_obs_key_expected = 2 if Env == _CountingEnv else 3 origin_obs_key = original_obs_key if Env == _CountingEnv else "t" - env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key) + env = HumanReadableWrapper( + Env(render_mode="rgb_array"), original_obs_key=original_obs_key + ) obs, _ = env.reset() assert isinstance(obs, Dict) From 77eab66549f028435bd7a64ec93859ea6a817283 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 3 Oct 2023 18:12:49 -0700 Subject: [PATCH 62/85] Add additional obs check --- tests/data/test_rollout.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index f323446ba..5ba70855d 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -423,4 +423,5 @@ def test_dictionary_observations(rng): ) for traj in trajs: assert isinstance(traj.obs, types.DictObs) + assert venv.observation_space.contains(obs) np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b")) From 194ec1ac336d5ecdb36a5df92d53e056458291be Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 3 Oct 2023 18:14:43 -0700 Subject: [PATCH 63/85] Upgrade pytype and remove workaround for old versions --- setup.py | 2 +- tests/algorithms/test_base.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 5fc3354ad..a461ced36 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ATARI_REQUIRE = [ "seals[atari]~=0.2.1", ] -PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] +PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else [] # Note: the versions of the test and doc requirements should be tightly pinned to known # working versions to make our CI/CD pipeline as stable as possible. 
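As background for the observation-space membership assertion added in the test
above (and tightened in a later patch below), here is a minimal sketch of how a
Gymnasium `Dict` space validates observations; the subspace bounds and sample
values below are invented purely for illustration:

    import gymnasium as gym
    import numpy as np

    # Hypothetical Dict space, loosely mirroring the paired "a"/"b"
    # observations checked in test_dictionary_observations.
    space = gym.spaces.Dict(
        {
            "a": gym.spaces.Box(low=0.0, high=10.0, shape=(2,)),
            "b": gym.spaces.Box(low=0.0, high=10.0, shape=(2,)),
        }
    )
    obs = {
        "a": np.array([2.0, 4.0], dtype=np.float32),
        "b": np.array([1.0, 2.0], dtype=np.float32),
    }
    # Plain mappings whose keys match the space pass the check...
    assert space.contains(obs)
    # ...whereas a DictObs must be converted to a plain dict first,
    # e.g. space.contains(dict(dict_obs.items())).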
diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 1654a1482..802ac0d7f 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -41,10 +41,7 @@ def test_check_fixed_horizon_flag(custom_logger): def _make_and_iterate_loader(*args, **kwargs): - # our pytype version doesn't understand optional arguments in TypedDict - # this is fixed in 2023.04.11, but we require 2022.7.26 - # See https://github.com/google/pytype/issues/1195 - loader = base.make_data_loader(*args, **kwargs) # pytype: disable=wrong-arg-types + loader = base.make_data_loader(*args, **kwargs) for batch in loader: pass From 44b357e4cd900fb7d1eaeea490a05b05ec2ba74b Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 3 Oct 2023 19:07:42 -0700 Subject: [PATCH 64/85] Fix test_rollout test --- tests/data/test_rollout.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 5ba70855d..c8c8cd021 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -423,5 +423,6 @@ def test_dictionary_observations(rng): ) for traj in trajs: assert isinstance(traj.obs, types.DictObs) - assert venv.observation_space.contains(obs) + for obs in traj.obs: + assert venv.observation_space.contains(dict(obs.items())) np.testing.assert_allclose(traj.obs.get("a") / 2, traj.obs.get("b")) From ee83ec54419bea4508ab2fbf0d6c28e8f63f7e8f Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:00:38 -0700 Subject: [PATCH 65/85] add RemoveHumanReadableWrapper and update ob space --- .../train_dagger_atari_interactive_policy.py | 16 ++-- src/imitation/algorithms/bc.py | 3 +- src/imitation/algorithms/dagger.py | 34 +++++++-- src/imitation/data/wrappers.py | 73 +++++++++++++++++++ src/imitation/policies/interactive.py | 14 ++-- 5 files changed, 123 insertions(+), 17 deletions(-) diff --git a/examples/train_dagger_atari_interactive_policy.py b/examples/train_dagger_atari_interactive_policy.py index bb32f7194..f30b6e08b 100644 --- a/examples/train_dagger_atari_interactive_policy.py +++ b/examples/train_dagger_atari_interactive_policy.py @@ -10,25 +10,29 @@ from stable_baselines3.common import vec_env from imitation.algorithms import bc, dagger +from imitation.data import wrappers from imitation.policies import interactive if __name__ == "__main__": rng = np.random.default_rng(0) - env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)]) - env.seed(0) + env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array") + env = wrappers.HumanReadableWrapper(env) + venv = vec_env.DummyVecEnv([lambda: env]) + venv.seed(0) - expert = interactive.AtariInteractivePolicy(env) + expert = interactive.AtariInteractivePolicy(venv) + venv_with_no_rgb = wrappers.RemoveHumanReadableWrapper(venv) bc_trainer = bc.BC( - observation_space=env.observation_space, - action_space=env.action_space, + observation_space=venv_with_no_rgb.observation_space, + action_space=venv_with_no_rgb.action_space, rng=rng, ) with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir: dagger_trainer = dagger.SimpleDAggerTrainer( - venv=env, + venv=venv, scratch_dir=tmpdir, expert_policy=expert, bc_trainer=bc_trainer, diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 29f12588d..71d093d21 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -26,7 +26,7 @@ from stable_baselines3.common import policies, torch_layers, utils, vec_env from imitation.algorithms 
import base as algo_base -from imitation.data import rollout, types +from imitation.data import rollout, types, wrappers from imitation.policies import base as policy_base from imitation.util import logger as imit_logger from imitation.util import util @@ -491,6 +491,7 @@ def process_batch(): lambda x: util.safe_to_tensor(x, device=self.policy.device), types.maybe_unwrap_dictobs(batch["obs"]), ) + obs_tensor = wrappers.remove_rgb_obs(obs_tensor) acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py index 67bb382ca..0e55738ad 100644 --- a/src/imitation/algorithms/dagger.py +++ b/src/imitation/algorithms/dagger.py @@ -11,11 +11,13 @@ import os import pathlib import uuid -from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch as th +from gymnasium import spaces from stable_baselines3.common import policies, utils, vec_env +from stable_baselines3.common.type_aliases import GymEnv from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn from torch.utils import data as th_data @@ -23,6 +25,7 @@ from imitation.data import rollout, serialize, types from imitation.util import logger as imit_logger from imitation.util import util +from imitation.data import wrappers class BetaSchedule(abc.ABC): @@ -213,7 +216,7 @@ def seed(self, seed: Optional[int] = None) -> List[Optional[int]]: self.rng = np.random.default_rng(seed=seed) return list(self.venv.seed(seed)) - def reset(self) -> np.ndarray: + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """Resets the environment. Returns: @@ -221,12 +224,14 @@ def reset(self) -> np.ndarray: """ self.traj_accum = rollout.TrajectoryAccumulator() obs = self.venv.reset() - assert isinstance(obs, np.ndarray) + obs = types.maybe_wrap_in_dictobs(obs) + assert isinstance(obs, types.DictObs) for i, ob in enumerate(obs): self.traj_accum.add_step({"obs": ob}, key=i) self._last_obs = obs self._is_reset = True self._last_user_actions = None + obs = types.maybe_unwrap_dictobs(obs) return obs def step_async(self, actions: np.ndarray) -> None: @@ -270,7 +275,6 @@ def step_wait(self) -> VecEnvStepReturn: Observation, reward, dones (is terminal?) and info dict. 
""" next_obs, rews, dones, infos = self.venv.step_wait() - assert isinstance(next_obs, np.ndarray) assert self.traj_accum is not None assert self._last_user_actions is not None self._last_obs = next_obs @@ -291,6 +295,26 @@ class NeedsDemosException(Exception): """Signals demos need to be collected for current round before continuing.""" +def _check_for_correct_spaces_with_rgb_env( + env_might_with_rgb: GymEnv, + obs_space: spaces.Space, + action_space: spaces.Space, +) -> None: + """Checks that whether an environment has the same spaces as provided ones.""" + if isinstance(obs_space, spaces.Dict): + assert wrappers.HR_OBS_KEY not in obs_space.spaces + env_obs_space = wrappers.remove_rgb_obs_space(env_might_with_rgb.observation_space) + if obs_space != env_obs_space: + raise ValueError( + f"Observation spaces do not match: obs {obs_space} != env {env_obs_space}" + ) + env_action_space = env_might_with_rgb.action_space + if action_space != env_action_space: + raise ValueError( + f"Action spaces do not match: obs {action_space} != env {env_action_space}" + ) + + class DAggerTrainer(base.BaseImitationAlgorithm): """DAgger training class with low-level API suitable for interactive human feedback. @@ -361,7 +385,7 @@ def __init__( self._all_demos = [] self.rng = rng - utils.check_for_correct_spaces( + _check_for_correct_spaces_with_rgb_env( self.venv, bc_trainer.observation_space, bc_trainer.action_space, diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index beeb75a29..40d16e702 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -7,6 +7,7 @@ import numpy as np import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper +import torch as th from imitation.data import rollout, types @@ -234,6 +235,25 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): ) self._original_obs_key = original_obs_key super().__init__(env) + self._update_obs_space() + + def _update_obs_space(self): + # need to reset before render. 
+        self.env.reset()
+        example_rgb_obs = self.env.render()
+        new_rgb_space = gym.spaces.Box(
+            low=0, high=255, shape=example_rgb_obs.shape, dtype=np.uint8
+        )
+        curr_space = self.observation_space
+        if isinstance(curr_space, gym.spaces.Dict):
+            curr_space.spaces[HR_OBS_KEY] = new_rgb_space
+        else:
+            self.observation_space = gym.spaces.Dict(
+                {
+                    HR_OBS_KEY: new_rgb_space,
+                    self._original_obs_key: curr_space,
+                }
+            )

     def _add_hr_obs(
         self, obs: Union[np.ndarray, Dict[str, np.ndarray]]
@@ -269,3 +289,56 @@ def reset(self, **kwargs):
     def step(self, action):
         obs, rew, terminated, truncated, info = self.env.step(action)
         return self._add_hr_obs(obs), rew, terminated, truncated, info
+
+
+def remove_rgb_obs_space(obs_space: gym.Space) -> gym.Space:
+    """Removes rgb observation space from the observation space."""
+    if not isinstance(obs_space, gym.spaces.Dict):
+        return obs_space
+    if HR_OBS_KEY not in obs_space.spaces:
+        return obs_space
+    new_obs_space = gym.spaces.Dict(obs_space.spaces.copy())
+    del new_obs_space.spaces[HR_OBS_KEY]
+    if len(new_obs_space.spaces) == 1:
+        # unwrap dictionary structure
+        return list(new_obs_space.values())[0]
+    return new_obs_space
+
+
+def remove_rgb_obs(
+    obs: Union[Dict[str, np.ndarray | th.Tensor], np.ndarray, th.Tensor]
+) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    """Removes rgb observation from the observation."""
+    if not isinstance(obs, dict):
+        return obs
+    if HR_OBS_KEY not in obs:
+        return obs
+    del obs[HR_OBS_KEY]
+    if len(obs) == 1:
+        # unwrap dictionary structure
+        return list(obs.values())[0]
+    return obs
+
+
+class RemoveHumanReadableWrapper(VecEnvWrapper):
+    """A vectorized wrapper for removing human readable observations.
+
+    :param venv: The vectorized environment
+    """
+
+    def __init__(self, venv: VecEnv):
+        assert isinstance(
+            venv.observation_space, gym.spaces.Dict
+        ), "RemoveHumanReadableWrapper only works with gym.spaces.Dict space"
+        assert (
+            HR_OBS_KEY in venv.observation_space.spaces
+        ), f"Observation space must contain {HR_OBS_KEY!r}"
+        new_obs_space = remove_rgb_obs_space(venv.observation_space)
+        super().__init__(venv=venv, observation_space=new_obs_space)
+
+    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        return remove_rgb_obs(self.venv.reset())
+
+    def step_wait(self):
+        observations, rewards, dones, infos = self.venv.step_wait()
+        return remove_rgb_obs(observations), rewards, dones, infos
diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py
index d9934f9a0..947309793 100644
--- a/src/imitation/policies/interactive.py
+++ b/src/imitation/policies/interactive.py
@@ -11,6 +11,7 @@
 from stable_baselines3.common import vec_env

 import imitation.policies.base as base_policies
+from imitation.data import wrappers
 from imitation.util import util

@@ -64,9 +65,6 @@ def _choose_action(
         if self.clear_screen_on_query:
             util.clear_screen()

-        if isinstance(obs, dict):
-            raise ValueError("Dictionary observations are not supported here")
-
         context = self._render(obs)
         key = self._get_input_key()
         self._clean_up(context)
@@ -110,9 +108,15 @@ def _render(self, obs: np.ndarray) -> plt.Figure:
     def _clean_up(self, context: plt.Figure) -> None:
         plt.close(context)

-    def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray:
+    def _prepare_obs_image(
+        self, obs: Union[np.ndarray, Dict[str, np.ndarray]]
+    ) -> np.ndarray:
         """Applies any required observation processing to get an image to show."""
-        return obs
+        if not isinstance(obs, Dict):
+            return obs
+        if wrappers.HR_OBS_KEY not in obs:
+
raise KeyError(f"Observation does not contain {wrappers.HR_OBS_KEY!r}") + return obs[wrappers.HR_OBS_KEY] ATARI_ACTION_NAMES_TO_KEYS = { From 27f9dc8e61dc435e3a11e6347121b3dc8dffd6c0 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:30:07 -0700 Subject: [PATCH 66/85] Revert "add RemoveHumanReadableWrapper and update ob space" This reverts commit ee83ec54419bea4508ab2fbf0d6c28e8f63f7e8f. --- .../train_dagger_atari_interactive_policy.py | 16 ++-- src/imitation/algorithms/bc.py | 3 +- src/imitation/algorithms/dagger.py | 34 ++------- src/imitation/data/wrappers.py | 73 ------------------- src/imitation/policies/interactive.py | 14 ++-- 5 files changed, 17 insertions(+), 123 deletions(-) diff --git a/examples/train_dagger_atari_interactive_policy.py b/examples/train_dagger_atari_interactive_policy.py index f30b6e08b..bb32f7194 100644 --- a/examples/train_dagger_atari_interactive_policy.py +++ b/examples/train_dagger_atari_interactive_policy.py @@ -10,29 +10,25 @@ from stable_baselines3.common import vec_env from imitation.algorithms import bc, dagger -from imitation.data import wrappers from imitation.policies import interactive if __name__ == "__main__": rng = np.random.default_rng(0) - env = gym.make("PongNoFrameskip-v4", render_mode="rgb_array") - env = wrappers.HumanReadableWrapper(env) - venv = vec_env.DummyVecEnv([lambda: env]) - venv.seed(0) + env = vec_env.DummyVecEnv([lambda: gym.wrappers.TimeLimit(gym.make("Pong-v4"), 10)]) + env.seed(0) - expert = interactive.AtariInteractivePolicy(venv) + expert = interactive.AtariInteractivePolicy(env) - venv_with_no_rgb = wrappers.RemoveHumanReadableWrapper(venv) bc_trainer = bc.BC( - observation_space=venv_with_no_rgb.observation_space, - action_space=venv_with_no_rgb.action_space, + observation_space=env.observation_space, + action_space=env.action_space, rng=rng, ) with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir: dagger_trainer = dagger.SimpleDAggerTrainer( - venv=venv, + venv=env, scratch_dir=tmpdir, expert_policy=expert, bc_trainer=bc_trainer, diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index 71d093d21..29f12588d 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -26,7 +26,7 @@ from stable_baselines3.common import policies, torch_layers, utils, vec_env from imitation.algorithms import base as algo_base -from imitation.data import rollout, types, wrappers +from imitation.data import rollout, types from imitation.policies import base as policy_base from imitation.util import logger as imit_logger from imitation.util import util @@ -491,7 +491,6 @@ def process_batch(): lambda x: util.safe_to_tensor(x, device=self.policy.device), types.maybe_unwrap_dictobs(batch["obs"]), ) - obs_tensor = wrappers.remove_rgb_obs(obs_tensor) acts = util.safe_to_tensor(batch["acts"], device=self.policy.device) training_metrics = self.loss_calculator(self.policy, obs_tensor, acts) diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py index 0e55738ad..67bb382ca 100644 --- a/src/imitation/algorithms/dagger.py +++ b/src/imitation/algorithms/dagger.py @@ -11,13 +11,11 @@ import os import pathlib import uuid -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch as th -from gymnasium import spaces from stable_baselines3.common import policies, utils, vec_env -from 
stable_baselines3.common.type_aliases import GymEnv
 from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn
 from torch.utils import data as th_data

@@ -25,7 +23,6 @@
 from imitation.data import rollout, serialize, types
 from imitation.util import logger as imit_logger
 from imitation.util import util
-from imitation.data import wrappers


 class BetaSchedule(abc.ABC):
@@ -216,7 +213,7 @@ def seed(self, seed: Optional[int] = None) -> List[Optional[int]]:
         self.rng = np.random.default_rng(seed=seed)
         return list(self.venv.seed(seed))

-    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def reset(self) -> np.ndarray:
         """Resets the environment.

         Returns:
@@ -224,14 +221,12 @@ def reset(self) -> np.ndarray:
         """
         self.traj_accum = rollout.TrajectoryAccumulator()
         obs = self.venv.reset()
-        obs = types.maybe_wrap_in_dictobs(obs)
-        assert isinstance(obs, types.DictObs)
+        assert isinstance(obs, np.ndarray)
         for i, ob in enumerate(obs):
             self.traj_accum.add_step({"obs": ob}, key=i)
         self._last_obs = obs
         self._is_reset = True
         self._last_user_actions = None
-        obs = types.maybe_unwrap_dictobs(obs)
         return obs

     def step_async(self, actions: np.ndarray) -> None:
@@ -270,7 +270,6 @@ def step_wait(self) -> VecEnvStepReturn:
             Observation, reward, dones (is terminal?) and info dict.
         """
         next_obs, rews, dones, infos = self.venv.step_wait()
+        assert isinstance(next_obs, np.ndarray)
         assert self.traj_accum is not None
         assert self._last_user_actions is not None
         self._last_obs = next_obs
@@ -291,26 +291,6 @@ class NeedsDemosException(Exception):
     """Signals demos need to be collected for current round before continuing."""


-def _check_for_correct_spaces_with_rgb_env(
-    env_might_with_rgb: GymEnv,
-    obs_space: spaces.Space,
-    action_space: spaces.Space,
-) -> None:
-    """Checks whether an environment has the same spaces as the provided ones."""
-    if isinstance(obs_space, spaces.Dict):
-        assert wrappers.HR_OBS_KEY not in obs_space.spaces
-    env_obs_space = wrappers.remove_rgb_obs_space(env_might_with_rgb.observation_space)
-    if obs_space != env_obs_space:
-        raise ValueError(
-            f"Observation spaces do not match: obs {obs_space} != env {env_obs_space}"
-        )
-    env_action_space = env_might_with_rgb.action_space
-    if action_space != env_action_space:
-        raise ValueError(
-            f"Action spaces do not match: act {action_space} != env {env_action_space}"
-        )
-
-
 class DAggerTrainer(base.BaseImitationAlgorithm):
     """DAgger training class with low-level API suitable for interactive human feedback.

@@ -385,7 +361,7 @@ def __init__(
         self._all_demos = []
         self.rng = rng

-        _check_for_correct_spaces_with_rgb_env(
+        utils.check_for_correct_spaces(
             self.venv,
             bc_trainer.observation_space,
             bc_trainer.action_space,
diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py
index 40d16e702..beeb75a29 100644
--- a/src/imitation/data/wrappers.py
+++ b/src/imitation/data/wrappers.py
@@ -7,7 +7,6 @@
 import numpy as np
 import numpy.typing as npt
 from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper
-import torch as th

 from imitation.data import rollout, types

@@ -235,25 +234,6 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"):
         )
         self._original_obs_key = original_obs_key
         super().__init__(env)
-        self._update_obs_space()
-
-    def _update_obs_space(self):
-        # need to reset before render.
- self.env.reset() - example_rgb_obs = self.env.render() - new_rgb_space = gym.spaces.Box( - low=0, high=255, shape=example_rgb_obs.shape, dtype=np.uint8 - ) - curr_sapce = self.observation_space - if isinstance(curr_sapce, gym.spaces.Dict): - curr_sapce.spaces[HR_OBS_KEY] = new_rgb_space - else: - self.observation_space = gym.spaces.Dict( - { - HR_OBS_KEY: new_rgb_space, - self._original_obs_key: curr_sapce, - } - ) def _add_hr_obs( self, obs: Union[np.ndarray, Dict[str, np.ndarray]] @@ -289,56 +269,3 @@ def reset(self, **kwargs): def step(self, action): obs, rew, terminated, truncated, info = self.env.step(action) return self._add_hr_obs(obs), rew, terminated, truncated, info - - -def remove_rgb_obs_space(obs_space: gym.Space) -> gym.Space: - """Removes rgb observation space from the observation space.""" - if not isinstance(obs_space, gym.spaces.Dict): - return obs_space - if HR_OBS_KEY not in obs_space.spaces: - return obs_space - new_obs_space = gym.spaces.Dict(obs_space.spaces.copy()) - del new_obs_space.spaces[HR_OBS_KEY] - if len(new_obs_space.spaces) == 1: - # unwrap dictionary structure - return list(new_obs_space.values())[0] - return new_obs_space - - -def remove_rgb_obs( - obs: Union[Dict[str, np.ndarray | th.Tensor], np.ndarray, th.Tensor] -) -> Union[np.ndarray, Dict[str, np.ndarray]]: - """Removes rgb observation from the observation.""" - if not isinstance(obs, dict): - return obs - if HR_OBS_KEY not in obs: - return obs - del obs[HR_OBS_KEY] - if len(obs) == 1: - # unwrap dictionary structure - return list(obs.values())[0] - return obs - - -class RemoveHumanReadableWrapper(VecEnvWrapper): - """A vectorized wrapper for removing human readable observations. - - :param venv: The vectorized environment - """ - - def __init__(self, venv: VecEnv): - assert isinstance( - venv.observation_space, gym.spaces.Dict - ), "RemoveHumanReadableWrapper only works with gym.spaces.Dict space" - assert ( - HR_OBS_KEY in venv.observation_space.spaces - ), f"Observation space must contain {HR_OBS_KEY!r}" - new_obs_space = remove_rgb_obs_space(venv.observation_space) - super().__init__(venv=venv, observation_space=new_obs_space) - - def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: - return remove_rgb_obs(self.venv.reset()) - - def step_wait(self): - observations, rewards, dones, infos = self.venv.step_wait() - return remove_rgb_obs(observations), rewards, dones, infos diff --git a/src/imitation/policies/interactive.py b/src/imitation/policies/interactive.py index 947309793..d9934f9a0 100644 --- a/src/imitation/policies/interactive.py +++ b/src/imitation/policies/interactive.py @@ -11,7 +11,6 @@ from stable_baselines3.common import vec_env import imitation.policies.base as base_policies -from imitation.data import wrappers from imitation.util import util @@ -65,6 +64,9 @@ def _choose_action( if self.clear_screen_on_query: util.clear_screen() + if isinstance(obs, dict): + raise ValueError("Dictionary observations are not supported here") + context = self._render(obs) key = self._get_input_key() self._clean_up(context) @@ -108,15 +110,9 @@ def _render(self, obs: np.ndarray) -> plt.Figure: def _clean_up(self, context: plt.Figure) -> None: plt.close(context) - def _prepare_obs_image( - self, obs: Union[np.ndarray, Dict[str, np.ndarray]] - ) -> np.ndarray: + def _prepare_obs_image(self, obs: np.ndarray) -> np.ndarray: """Applies any required observation processing to get an image to show.""" - if not isinstance(obs, Dict): - return obs - if wrappers.HR_OBS_KEY not in obs: - raise 
KeyError(f"Observation does not contain {wrappers.HR_OBS_KEY!r}") - return obs[wrappers.HR_OBS_KEY] + return obs ATARI_ACTION_NAMES_TO_KEYS = { From d954fed193485fea69a37e96ec23b8c7568bb408 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:30:25 -0700 Subject: [PATCH 67/85] Revert "adjust wrapper and not set render_mode inside" This reverts commit a9b32bdc3fdd8be9f0b50411cfcd1107316983c8. --- src/imitation/data/wrappers.py | 12 +++--------- tests/data/test_wrappers.py | 29 ++++++++++------------------- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index beeb75a29..83e4dd0cc 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -222,16 +222,8 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): env: Environment to wrap. original_obs_key: The key for original observation if the original observation is not in dict format. - - Raises: - ValueError: If `env.render_mode` is not "rgb_array". - """ - if env.render_mode != "rgb_array": - raise ValueError( - "HumanReadableWrapper requires render_mode='rgb_array', " - f"got {env.render_mode!r}" - ) + env.render_mode = "rgb_array" self._original_obs_key = original_obs_key super().__init__(env) @@ -253,6 +245,8 @@ def _add_hr_obs( KeyError: When the key HR_OBS_KEY already exists in the observation dictionary. """ + assert self.env.render_mode is not None + assert self.env.render_mode == "rgb_array" hr_obs = self.env.render() if not isinstance(obs, Dict): obs = {self._original_obs_key: obs} diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 3b2f4a8ac..279c72079 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -28,17 +28,12 @@ class _CountingEnv(gym.Env): # pragma: no cover ``` """ - def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None): + def __init__(self, episode_length=5): assert episode_length >= 1 self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=()) self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=()) self.episode_length = episode_length - self.timestep: Optional[int] = None - self._render_mode = render_mode - - @property - def render_mode(self): - return self._render_mode + self.timestep = None def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): t, self.timestep = 0, 1 @@ -57,16 +52,14 @@ def step(self, action): return t, t * 10, done, False, {} def render(self) -> np.ndarray: - if self._render_mode != "rgb_array": - raise ValueError(f"Invalid render mode {self._render_mode}") return np.array([self.timestep] * 10) class _CountingDictEnv(_CountingEnv): # pragma: no cover """Similar to _CountingEnv, but with Dict observation.""" - def __init__(self, episode_length: int = 5, render_mode: Optional[str] = None): - super().__init__(episode_length, render_mode) + def __init__(self, episode_length=5): + super().__init__(episode_length) self.observation_space = gym.spaces.Dict( spaces={ "t": gym.spaces.Box(low=0, high=np.inf, shape=()), @@ -95,7 +88,7 @@ def step(self, action): def _make_buffering_venv( - Env: _CountingEnv, + Env: Type[gym.Env], error_on_premature_reset: bool, ) -> BufferingWrapper: venv = DummyVecEnv([Env] * 2) @@ -138,7 +131,7 @@ def concat(x): @pytest.mark.parametrize("n_steps", [1, 2, 20, 21]) @pytest.mark.parametrize("extra_pop_timesteps", [(), (1,), (4, 8)]) def test_pop( - Env: _CountingEnv, + Env: Type[gym.Env], episode_lengths: Sequence[int], n_steps: 
int, extra_pop_timesteps: Sequence[int], @@ -245,7 +238,7 @@ def make_env(ep_len): @pytest.mark.parametrize("Env", Envs) -def test_reset_error(Env: _CountingEnv): +def test_reset_error(Env: Type[gym.Env]): # Resetting before a `step()` is okay. for flag in [True, False]: venv = _make_buffering_venv(Env, flag) @@ -281,7 +274,7 @@ def test_reset_error(Env: _CountingEnv): @pytest.mark.parametrize("Env", Envs) -def test_n_transitions_and_empty_error(Env: _CountingEnv): +def test_n_transitions_and_empty_error(Env: Type[gym.Env]): venv = _make_buffering_venv(Env, True) trajs, ep_lens = venv.pop_trajectories() assert trajs == [] @@ -299,12 +292,10 @@ def test_n_transitions_and_empty_error(Env: _CountingEnv): @pytest.mark.parametrize("Env", Envs) @pytest.mark.parametrize("original_obs_key", ["k1", "k2"]) -def test_human_readable_wrapper(Env: _CountingEnv, original_obs_key: str): +def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str): num_obs_key_expected = 2 if Env == _CountingEnv else 3 origin_obs_key = original_obs_key if Env == _CountingEnv else "t" - env = HumanReadableWrapper( - Env(render_mode="rgb_array"), original_obs_key=original_obs_key - ) + env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key) obs, _ = env.reset() assert isinstance(obs, Dict) From d1131d0f9b03940c517fe299e03ec6eb2667da0d Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:30:37 -0700 Subject: [PATCH 68/85] Revert "fix dict env observation space" This reverts commit ba6a6a7eb5191bac4c5f70f12313185b15e3c350. --- tests/data/test_wrappers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 279c72079..cfde9dbcc 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -61,10 +61,7 @@ class _CountingDictEnv(_CountingEnv): # pragma: no cover def __init__(self, episode_length=5): super().__init__(episode_length) self.observation_space = gym.spaces.Dict( - spaces={ - "t": gym.spaces.Box(low=0, high=np.inf, shape=()), - "2t": gym.spaces.Box(low=0, high=np.inf, shape=()), - }, + spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, ) def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): From 31f88879e0d689fa75a4c101b7da7550f5bd9716 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:30:44 -0700 Subject: [PATCH 69/85] Revert "Add HumanReadableWrapper" This reverts commit 6e5c3e8d695c0c29b721672dbaf3dabae628b23f. --- src/imitation/data/wrappers.py | 61 ++-------------------------------- tests/data/test_wrappers.py | 43 ++++-------------------- 2 files changed, 9 insertions(+), 95 deletions(-) diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index 83e4dd0cc..94c88111d 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -1,18 +1,14 @@ """Environment wrappers for collecting rollouts.""" -from typing import List, Optional, Sequence, Tuple, Dict, Union +from typing import List, Optional, Sequence, Tuple import gymnasium as gym -from gymnasium.core import Env import numpy as np import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper from imitation.data import rollout, types -# The key for human readable data in the observation. -HR_OBS_KEY = "HR_OBS" - class BufferingWrapper(VecEnvWrapper): """Saves transitions of underlying VecEnv. 
@@ -174,7 +170,7 @@ def pop_transitions(self) -> types.TransitionsWithRew: class RolloutInfoWrapper(gym.Wrapper): - """Adds the entire episode's rewards and observations to `info` at episode end. + """Add the entire episode's rewards and observations to `info` at episode end. Whenever done=True, `info["rollouts"]` is a dict with keys "obs" and "rews", whose corresponding values hold the NumPy arrays containing the raw observations and @@ -210,56 +206,3 @@ def step(self, action): "rews": np.stack(self._rews), } return obs, rew, terminated, truncated, info - - -class HumanReadableWrapper(gym.Wrapper): - """Adds human-readable observation to `obs` at every step.""" - - def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): - """Builds HumanReadableWrapper - - Args: - env: Environment to wrap. - original_obs_key: The key for original observation if the original - observation is not in dict format. - """ - env.render_mode = "rgb_array" - self._original_obs_key = original_obs_key - super().__init__(env) - - def _add_hr_obs( - self, obs: Union[np.ndarray, Dict[str, np.ndarray]] - ) -> Dict[str, np.ndarray]: - """Adds human-readable observation to obs. - - Transforms obs into dictionary if it is not already, and adds the human-readable - observation from `env.render()` under the key HR_OBS_KEY. - - Args: - obs: Observation from environment. - - Returns: - Observation dictionary with the human-readable data - - Raises: - KeyError: When the key HR_OBS_KEY already exists in the observation - dictionary. - """ - assert self.env.render_mode is not None - assert self.env.render_mode == "rgb_array" - hr_obs = self.env.render() - if not isinstance(obs, Dict): - obs = {self._original_obs_key: obs} - - if HR_OBS_KEY in obs: - raise KeyError(f"{HR_OBS_KEY!r} already exists in observation dict") - obs[HR_OBS_KEY] = hr_obs - return obs - - def reset(self, **kwargs): - obs, info = super().reset(**kwargs) - return self._add_hr_obs(obs), info - - def step(self, action): - obs, rew, terminated, truncated, info = self.env.step(action) - return self._add_hr_obs(obs), rew, terminated, truncated, info diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index cfde9dbcc..33677c68f 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -1,6 +1,6 @@ """Tests for `imitation.data.wrappers`.""" -from typing import List, Sequence, Type, Dict, Optional +from typing import List, Sequence, Type import gymnasium as gym import numpy as np @@ -8,11 +8,7 @@ from stable_baselines3.common.vec_env import DummyVecEnv from imitation.data import types -from imitation.data.wrappers import ( - BufferingWrapper, - HumanReadableWrapper, - HR_OBS_KEY, -) +from imitation.data.wrappers import BufferingWrapper class _CountingEnv(gym.Env): # pragma: no cover @@ -35,7 +31,7 @@ def __init__(self, episode_length=5): self.episode_length = episode_length self.timestep = None - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + def reset(self, seed=None): t, self.timestep = 0, 1 return t, {} @@ -51,9 +47,6 @@ def step(self, action): done = t == self.episode_length return t, t * 10, done, False, {} - def render(self) -> np.ndarray: - return np.array([self.timestep] * 10) - class _CountingDictEnv(_CountingEnv): # pragma: no cover """Similar to _CountingEnv, but with Dict observation.""" @@ -64,9 +57,9 @@ def __init__(self, episode_length=5): spaces={"t": gym.spaces.Box(low=0, high=np.inf, shape=())}, ) - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] 
= None): - t, self.timestep = 0, 1 - return {"t": t, "2t": 2 * t}, {} + def reset(self, seed=None): + t, self.timestep = 0.0, 1.0 + return {"t": t}, {} def step(self, action): if self.timestep is None: @@ -78,7 +71,7 @@ def step(self, action): t, self.timestep = self.timestep, self.timestep + 1 done = t == self.episode_length - return {"t": t, "2t": 2 * t}, t * 10, done, False, {} + return {"t": t}, t * 10, done, False, {} Envs = [_CountingEnv, _CountingDictEnv] @@ -285,25 +278,3 @@ def test_n_transitions_and_empty_error(Env: Type[gym.Env]): assert venv.n_transitions == 0 with pytest.raises(RuntimeError, match=".* empty .*"): venv.pop_transitions() - - -@pytest.mark.parametrize("Env", Envs) -@pytest.mark.parametrize("original_obs_key", ["k1", "k2"]) -def test_human_readable_wrapper(Env: Type[gym.Env], original_obs_key: str): - num_obs_key_expected = 2 if Env == _CountingEnv else 3 - origin_obs_key = original_obs_key if Env == _CountingEnv else "t" - env = HumanReadableWrapper(Env(), original_obs_key=original_obs_key) - - obs, _ = env.reset() - assert isinstance(obs, Dict) - assert HR_OBS_KEY in obs - assert len(obs) == num_obs_key_expected - assert obs[origin_obs_key] == 0 - _assert_equal_scrambled_vectors(obs[HR_OBS_KEY], np.array([1] * 10)) - - next_obs, *_ = env.step(env.action_space.sample()) - assert isinstance(next_obs, Dict) - assert HR_OBS_KEY in next_obs - assert len(next_obs) == num_obs_key_expected - assert next_obs[origin_obs_key] == 1 - _assert_equal_scrambled_vectors(next_obs[HR_OBS_KEY], np.array([2] * 10)) From ae9fa64963be614b9e18061ca915ed146a7d3e06 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 22:31:03 -0700 Subject: [PATCH 70/85] Revert "acts to obs for clarity" This reverts commit be79cf5614548e1e780dc06fa1648732bb44fb9d. 
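The callable renamed by this revert, `get_robot_acts`, receives observations and returns the robot policy's actions; DAgger then mixes those with the expert's actions according to the schedule value `beta`. Below is a minimal, self-contained sketch of that mixing step; the helper name `mix_actions` and the toy policy are illustrative, not imitation's actual `InteractiveTrajectoryCollector` implementation.

import numpy as np

def mix_actions(obs, expert_acts, get_robot_acts, beta, rng):
    # With probability beta keep the expert's action; otherwise query the
    # robot policy on the corresponding observations.
    robot_mask = rng.uniform(size=len(obs)) >= beta
    acts = np.array(expert_acts, copy=True)
    if robot_mask.any():
        acts[robot_mask] = get_robot_acts(obs[robot_mask])
    return acts

rng = np.random.default_rng(0)
obs = np.arange(4.0).reshape(4, 1)       # four dummy observations
expert_acts = np.zeros(4)                # the expert always outputs 0
robot_acts = lambda o: np.ones(len(o))   # stand-in for policy.predict(o)[0]
mixed = mix_actions(obs, expert_acts, robot_acts, beta=0.5, rng=rng)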
--- src/imitation/algorithms/dagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py index 67bb382ca..fb68713e6 100644 --- a/src/imitation/algorithms/dagger.py +++ b/src/imitation/algorithms/dagger.py @@ -508,7 +508,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector: beta = self.beta_schedule(self.round_num) collector = InteractiveTrajectoryCollector( venv=self.venv, - get_robot_acts=lambda obs: self.bc_trainer.policy.predict(obs)[0], + get_robot_acts=lambda acts: self.bc_trainer.policy.predict(acts)[0], beta=beta, save_dir=save_dir, rng=self.rng, From 7a2b7ce70c94cc4c237d3b123df7ab2ae928ca61 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Tue, 3 Oct 2023 23:24:34 -0700 Subject: [PATCH 71/85] address comments --- src/imitation/algorithms/density.py | 9 ++++--- src/imitation/data/rollout.py | 7 +++-- tests/algorithms/conftest.py | 23 ++++++++++++++++- tests/algorithms/test_bc.py | 30 +++++----------------- tests/algorithms/test_density_baselines.py | 29 ++++----------------- 5 files changed, 41 insertions(+), 57 deletions(-) diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index a8bfceb7b..377467ce9 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -145,14 +145,15 @@ def _get_demo_from_batch( ) assert act_b.shape[1:] == self.venv.action_space.shape - + ob_space = self.venv.observation_space if isinstance(obs_b, types.DictObs): - exp_shape = {k: v.shape for k, v in self.venv.observation_space.items()} # type: ignore[attr-defined] # noqa: E501 - + exp_shape = { + k: v.shape for k, v in ob_space.items() # type: ignore[attr-defined] + } obs_shape = {k: v.shape[1:] for k, v in obs_b.items()} assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}" else: - assert obs_b.shape[1:] == self.venv.observation_space.shape + assert obs_b.shape[1:] == ob_space.shape assert len(act_b) == len(obs_b) if next_obs_b is not None: assert next_obs_b.shape == obs_b.shape diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 9d989a830..78007630e 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -490,10 +490,9 @@ def generate_trajectories( assert v.shape is not None exp_obs[k] = (n_steps + 1,) + v.shape else: - assert venv.observation_space.shape is not None - exp_obs = ( - n_steps + 1, - ) + venv.observation_space.shape # type: ignore[assignment] + obs_space_shape = venv.observation_space.shape + assert obs_space_shape is not None + exp_obs = (n_steps + 1,) + obs_space_shape # type: ignore[assignment] real_obs = trajectory.obs.shape assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}" assert venv.action_space.shape is not None diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index b3d4589fc..edb5a1b36 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -1,9 +1,11 @@ """Fixtures common across algorithm tests.""" from typing import Sequence +import gymnasium as gym import pytest from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common import envs +from stable_baselines3.common.vec_env import VecEnv, DummyVecEnv from imitation.algorithms import bc from imitation.data.types import TrajectoryWithRew @@ -109,3 +111,22 @@ def pendulum_single_venv(rng) -> VecEnv: post_wrappers=[lambda env, _: 
RolloutInfoWrapper(env)], rng=rng, ) + + +# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 +# merged and released. +class FloatReward(gym.RewardWrapper): + """Typecasts reward to a float.""" + + def reward(self, reward): + return float(reward) + + +@pytest.fixture +def multi_obs_venv(): + def make_env(): + env = envs.SimpleMultiObsEnv(channel_last=False) + env = FloatReward(env) + return RolloutInfoWrapper(env) + + return DummyVecEnv([make_env, make_env]) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index ed9c5a347..8be92368d 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -4,13 +4,11 @@ import os from typing import Any, Callable, Optional, Sequence -import gymnasium as gym import hypothesis import hypothesis.strategies as st import numpy as np import pytest import torch as th -from stable_baselines3.common import envs as sb_envs from stable_baselines3.common import evaluation from stable_baselines3.common import policies as sb_policies from stable_baselines3.common import vec_env @@ -291,27 +289,11 @@ def test_that_policy_reconstruction_preserves_parameters( th.testing.assert_close(original, reconstructed) -# TODO(GH#794): Remove after https://github.com/DLR-RM/stable-baselines3/pull/1676 -# merged and released. -class FloatReward(gym.RewardWrapper): - """Typecasts reward to a float.""" - - def reward(self, reward): - return float(reward) - - -def test_dict_space(): - def make_env(): - env = sb_envs.SimpleMultiObsEnv(channel_last=False) - env = FloatReward(env) - return RolloutInfoWrapper(env) - - env = vec_env.DummyVecEnv([make_env, make_env]) - +def test_dict_space(multi_obs_venv: vec_env.VecEnv): # multi-input policy to accept dict observations policy = sb_policies.MultiInputActorCriticPolicy( - env.observation_space, - env.action_space, + multi_obs_venv.observation_space, + multi_obs_venv.action_space, lambda _: 0.001, ) rng = np.random.default_rng() @@ -319,16 +301,16 @@ def make_env(): # sample random transitions rollouts = rollout.rollout( policy=None, - venv=env, + venv=multi_obs_venv, sample_until=rollout.make_sample_until(min_timesteps=None, min_episodes=50), rng=rng, unwrap=True, ) transitions = rollout.flatten_trajectories(rollouts) bc_trainer = bc.BC( - observation_space=env.observation_space, + observation_space=multi_obs_venv.observation_space, policy=policy, - action_space=env.action_space, + action_space=multi_obs_venv.action_space, rng=rng, demonstrations=transitions, ) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index b72970a79..e05a8f728 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -3,17 +3,14 @@ from dataclasses import asdict from typing import Sequence, cast -import gymnasium as gym import numpy as np import pytest import stable_baselines3 -from stable_baselines3.common import envs as sb_envs -from stable_baselines3.common import policies, vec_env +from stable_baselines3.common import policies from imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types from imitation.data.types import TrajectoryWithRew -from imitation.data.wrappers import RolloutInfoWrapper from imitation.policies.base import RandomPolicy from imitation.testing import reward_improvement @@ -172,27 +169,11 @@ def test_density_trainer_raises( density_trainer.set_demonstrations("foo") # type: ignore[arg-type] -# TODO(GH#794): Remove after 
https://github.com/DLR-RM/stable-baselines3/pull/1676 -# merged and released. -class FloatReward(gym.RewardWrapper): - """Typecasts reward to a float.""" - - def reward(self, reward): - return float(reward) - - -def test_dict_space(): - def make_env(): - env = sb_envs.SimpleMultiObsEnv(channel_last=False) - env = FloatReward(env) - return RolloutInfoWrapper(env) - - venv = vec_env.DummyVecEnv([make_env, make_env]) - +def test_dict_space(multi_obs_venv): # multi-input policy to accept dict observations rl_algo = stable_baselines3.PPO( policies.MultiInputActorCriticPolicy, - venv, + multi_obs_venv, n_steps=10, # small value to make test faster n_epochs=2, # small value to make test faster ) @@ -202,14 +183,14 @@ def make_env(): sample_until = rollout.make_min_episodes(15) rollouts = rollout.rollout( policy=None, - venv=venv, + venv=multi_obs_venv, sample_until=sample_until, rng=rng, ) density_trainer = DensityAlgorithm( demonstrations=rollouts, kernel="gaussian", - venv=venv, + venv=multi_obs_venv, rl_algo=rl_algo, kernel_bandwidth=0.2, standardise_inputs=True, From 15541cd56e3aac8c01a026236a5bd9e72f750a72 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 00:06:49 -0700 Subject: [PATCH 72/85] new pytype need input directory or file --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f07fd58d9..1f2d8d3a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -78,7 +78,7 @@ repos: name: pytype language: system types: [python] - entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'" + entry: "bash -c 'pytype ./ -j ${NUM_CPUS:-auto}'" require_serial: true verbose: true - id: docs From 6884538ddbc3f2083b988c4541c2890f995f4f2d Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 00:38:34 -0700 Subject: [PATCH 73/85] fix np.dtype --- tests/data/test_buffer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index e7615461e..64f607df2 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -176,7 +176,7 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): @pytest.mark.parametrize("sample_shape", [(), (1,), (5, 2)]) def test_buffer_store_errors(sample_shape): capacity = 11 - dtype = "float32" + dtype = np.float32 def buf(): return Buffer(capacity, {"k": sample_shape}, {"k": dtype}) @@ -208,14 +208,14 @@ def buf(): def test_buffer_sample_errors(): - b = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool}) + b = Buffer(10, {"k": (2, 1)}, dtypes={"k": np.bool_}) with pytest.raises(ValueError): b.sample(5) def test_buffer_init_errors(): with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"): - Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a="float32", c=bool)) + Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a=np.float32, c=np.bool_)) def test_replay_buffer_init_errors(): @@ -225,13 +225,13 @@ def test_replay_buffer_init_errors(): ): ReplayBuffer(15, venv=gym.make("CartPole-v1"), obs_shape=(10, 10)) with pytest.raises(ValueError, match=r"Shape or dtype missing.*"): - ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool) + ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=np.bool_) with pytest.raises(ValueError, match=r"Shape or dtype missing.*"): - ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool) + ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=np.bool_, act_dtype=np.bool_) def 
test_buffer_from_data(): - data = np.ndarray([50, 30], dtype=bool) + data = np.ndarray([50, 30], dtype=np.bool_) buf = Buffer.from_data({"k": data}) assert buf._arrays["k"] is not data assert data.dtype == buf._arrays["k"].dtype From 5c6e5b883339323f61c678cd9281d726df99f202 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 00:39:24 -0700 Subject: [PATCH 74/85] ignore typed-dict-error --- tests/algorithms/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 802ac0d7f..ad7a930e6 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -144,7 +144,7 @@ def test_make_data_loader(): for batch, expected_batch in zip(data_loader, trans_mapping): assert batch.keys() == expected_batch.keys() for k in batch.keys(): - v = batch[k] + v = batch[k] # type: ignore[typed-dict-error] if isinstance(v, th.Tensor): v = v.numpy() assert np.all(v == expected_batch[k]) From 5c1d7515e99a1e6b7d9eb33f65cac37f107fd402 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 00:39:58 -0700 Subject: [PATCH 75/85] context manager related fix --- src/imitation/scripts/eval_policy.py | 4 +++- src/imitation/scripts/ingredients/demonstrations.py | 2 +- src/imitation/scripts/train_adversarial.py | 2 +- src/imitation/scripts/train_imitation.py | 6 +++--- src/imitation/scripts/train_preference_comparisons.py | 2 +- src/imitation/scripts/train_rl.py | 4 +++- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 938d7194b..cc11c5d82 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -95,7 +95,9 @@ def eval_policy( log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None - with environment.make_venv(post_wrappers=post_wrappers) as venv: + with environment.make_venv( # type: ignore[wrong-keyword-args] + post_wrappers=post_wrappers + ) as venv: if render: venv = InteractiveRender(venv, render_fps) diff --git a/src/imitation/scripts/ingredients/demonstrations.py b/src/imitation/scripts/ingredients/demonstrations.py index 1367c0722..57fe6d5c7 100644 --- a/src/imitation/scripts/ingredients/demonstrations.py +++ b/src/imitation/scripts/ingredients/demonstrations.py @@ -143,7 +143,7 @@ def _generate_expert_trajs( raise ValueError("n_expert_demos must be specified when generating demos.") logger.info(f"Generating {n_expert_demos} expert trajectories") - with environment.make_rollout_venv() as rollout_env: + with environment.make_rollout_venv() as rollout_env: # type: ignore[wrong-arg-count] return rollout.rollout( expert.get_expert_policy(rollout_env), rollout_env, diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 26c8d7bcf..9afc51135 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -119,7 +119,7 @@ def train_adversarial( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: + with environment.make_venv() as venv: # type: ignore[wrong-arg-count] reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git 
a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 58dae3484..292597561 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -73,7 +73,7 @@ def bc( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: + with environment.make_venv() as venv: # type: ignore[wrong-arg-count] bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger) bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"]) @@ -115,7 +115,7 @@ def dagger( if dagger["use_offline_rollouts"]: expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: + with environment.make_venv() as venv: # type: ignore[wrong-arg-count] bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger) bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"]) @@ -161,7 +161,7 @@ def sqil( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: + with environment.make_venv() as venv: # type: ignore[wrong-arg-count] sqil_trainer = sqil_algorithm.SQIL( venv=venv, demonstrations=expert_trajs, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 79ee4c136..8fb13f4c4 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -166,7 +166,7 @@ def train_preference_comparisons( custom_logger, log_dir = logging_ingredient.setup_logging() - with environment.make_venv() as venv: + with environment.make_venv() as venv: # type: ignore[wrong-arg-count] reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 96d35122c..4bb661fa6 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -99,7 +99,9 @@ def train_rl( policy_dir.mkdir(parents=True, exist_ok=True) post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] - with environment.make_venv(post_wrappers=post_wrappers) as venv: + with environment.make_venv( # type: ignore[wrong-keyword-args] + post_wrappers=post_wrappers + ) as venv: callback_objs = [] if reward_type is not None: reward_fn = load_reward( From f5288c692f346b76a262ec0497b2c4f540ef6a0f Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 00:41:19 -0700 Subject: [PATCH 76/85] keep pytype checking more failures --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f2d8d3a5..f6607a0c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -78,7 +78,7 @@ repos: name: pytype language: system types: [python] - entry: "bash -c 'pytype ./ -j ${NUM_CPUS:-auto}'" + entry: "bash -c 'pytype ./ --keep-going -j ${NUM_CPUS:-auto}'" require_serial: true verbose: true - id: docs From 6e94dea58d3e5831a102742c660d87e19f1eafac Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:28:28 -0700 Subject: [PATCH 77/85] Revert "keep pytype checking more failures" This reverts commit f5288c692f346b76a262ec0497b2c4f540ef6a0f. 
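The `environment.make_venv` calls that PATCH 75 annotated with `# type: ignore` (and that PATCH 78 below un-annotates) are Sacred captured functions used as context managers, so a static checker sees a different signature than the call site. A minimal sketch of the underlying pattern, assuming a plain function in place of the Sacred capture; the name, defaults, and environment id here are illustrative.

from contextlib import contextmanager
from typing import Iterator

import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv

@contextmanager
def make_venv(env_name: str = "CartPole-v1", num_envs: int = 2) -> Iterator[VecEnv]:
    """Yield a vectorized environment and guarantee close() on exit."""
    venv = DummyVecEnv([lambda: gym.make(env_name) for _ in range(num_envs)])
    try:
        yield venv
    finally:
        venv.close()

with make_venv() as venv:  # mirrors the call sites in the scripts above
    obs = venv.reset()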
--- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f6607a0c2..1f2d8d3a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -78,7 +78,7 @@ repos: name: pytype language: system types: [python] - entry: "bash -c 'pytype ./ --keep-going -j ${NUM_CPUS:-auto}'" + entry: "bash -c 'pytype ./ -j ${NUM_CPUS:-auto}'" require_serial: true verbose: true - id: docs From bb1f9cded886cf6037069840404d3e9c8e8dc509 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:28:40 -0700 Subject: [PATCH 78/85] Revert "context manager related fix" This reverts commit 5c1d7515e99a1e6b7d9eb33f65cac37f107fd402. --- src/imitation/scripts/eval_policy.py | 4 +--- src/imitation/scripts/ingredients/demonstrations.py | 2 +- src/imitation/scripts/train_adversarial.py | 2 +- src/imitation/scripts/train_imitation.py | 6 +++--- src/imitation/scripts/train_preference_comparisons.py | 2 +- src/imitation/scripts/train_rl.py | 4 +--- 6 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index cc11c5d82..938d7194b 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -95,9 +95,7 @@ def eval_policy( log_dir = logging_ingredient.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None - with environment.make_venv( # type: ignore[wrong-keyword-args] - post_wrappers=post_wrappers - ) as venv: + with environment.make_venv(post_wrappers=post_wrappers) as venv: if render: venv = InteractiveRender(venv, render_fps) diff --git a/src/imitation/scripts/ingredients/demonstrations.py b/src/imitation/scripts/ingredients/demonstrations.py index 57fe6d5c7..1367c0722 100644 --- a/src/imitation/scripts/ingredients/demonstrations.py +++ b/src/imitation/scripts/ingredients/demonstrations.py @@ -143,7 +143,7 @@ def _generate_expert_trajs( raise ValueError("n_expert_demos must be specified when generating demos.") logger.info(f"Generating {n_expert_demos} expert trajectories") - with environment.make_rollout_venv() as rollout_env: # type: ignore[wrong-arg-count] + with environment.make_rollout_venv() as rollout_env: return rollout.rollout( expert.get_expert_policy(rollout_env), rollout_env, diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 9afc51135..26c8d7bcf 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -119,7 +119,7 @@ def train_adversarial( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: # type: ignore[wrong-arg-count] + with environment.make_venv() as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 292597561..58dae3484 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -73,7 +73,7 @@ def bc( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: # type: ignore[wrong-arg-count] + with environment.make_venv() as venv: bc_trainer = 
bc_ingredient.make_bc(venv, expert_trajs, custom_logger) bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"]) @@ -115,7 +115,7 @@ def dagger( if dagger["use_offline_rollouts"]: expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: # type: ignore[wrong-arg-count] + with environment.make_venv() as venv: bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger) bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"]) @@ -161,7 +161,7 @@ def sqil( custom_logger, log_dir = logging_ingredient.setup_logging() expert_trajs = demonstrations.get_expert_trajectories() - with environment.make_venv() as venv: # type: ignore[wrong-arg-count] + with environment.make_venv() as venv: sqil_trainer = sqil_algorithm.SQIL( venv=venv, demonstrations=expert_trajs, diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 8fb13f4c4..79ee4c136 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -166,7 +166,7 @@ def train_preference_comparisons( custom_logger, log_dir = logging_ingredient.setup_logging() - with environment.make_venv() as venv: # type: ignore[wrong-arg-count] + with environment.make_venv() as venv: reward_net = reward.make_reward_net(venv) relabel_reward_fn = functools.partial( reward_net.predict_processed, diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 4bb661fa6..96d35122c 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -99,9 +99,7 @@ def train_rl( policy_dir.mkdir(parents=True, exist_ok=True) post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] - with environment.make_venv( # type: ignore[wrong-keyword-args] - post_wrappers=post_wrappers - ) as venv: + with environment.make_venv(post_wrappers=post_wrappers) as venv: callback_objs = [] if reward_type is not None: reward_fn = load_reward( From a07ea269e98e007443161e39de43c201b9de820a Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:29:04 -0700 Subject: [PATCH 79/85] Revert "ignore typed-dict-error" This reverts commit 5c6e5b883339323f61c678cd9281d726df99f202. --- tests/algorithms/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index ad7a930e6..802ac0d7f 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -144,7 +144,7 @@ def test_make_data_loader(): for batch, expected_batch in zip(data_loader, trans_mapping): assert batch.keys() == expected_batch.keys() for k in batch.keys(): - v = batch[k] # type: ignore[typed-dict-error] + v = batch[k] if isinstance(v, th.Tensor): v = v.numpy() assert np.all(v == expected_batch[k]) From b2cca2e84e57e0b34d9f5aca3dd466768347973c Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:29:45 -0700 Subject: [PATCH 80/85] Revert "fix np.dtype" This reverts commit 6884538ddbc3f2083b988c4541c2890f995f4f2d. 
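The dtype spellings this revert restores are interchangeable at runtime: NumPy normalises builtin types, dtype strings, and scalar types to the same `np.dtype` objects, so only static type-checking is affected either way. A quick demonstration:

import numpy as np

# All of these spellings resolve to the same dtype objects at runtime.
assert np.dtype(bool) == np.dtype(np.bool_)
assert np.dtype("float32") == np.dtype(np.float32)

buf = np.zeros((4, 2), dtype="float32")
assert buf.dtype == np.float32  # dtype comparison also accepts the scalar type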
--- tests/data/test_buffer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index 64f607df2..e7615461e 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -176,7 +176,7 @@ def test_replay_buffer(capacity, chunk_len, obs_shape, act_shape, dtype): @pytest.mark.parametrize("sample_shape", [(), (1,), (5, 2)]) def test_buffer_store_errors(sample_shape): capacity = 11 - dtype = np.float32 + dtype = "float32" def buf(): return Buffer(capacity, {"k": sample_shape}, {"k": dtype}) @@ -208,14 +208,14 @@ def buf(): def test_buffer_sample_errors(): - b = Buffer(10, {"k": (2, 1)}, dtypes={"k": np.bool_}) + b = Buffer(10, {"k": (2, 1)}, dtypes={"k": bool}) with pytest.raises(ValueError): b.sample(5) def test_buffer_init_errors(): with pytest.raises(KeyError, match=r"sample_shape and dtypes.*"): - Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a=np.float32, c=np.bool_)) + Buffer(10, dict(a=(2, 1), b=(3,)), dtypes=dict(a="float32", c=bool)) def test_replay_buffer_init_errors(): @@ -225,13 +225,13 @@ def test_replay_buffer_init_errors(): ): ReplayBuffer(15, venv=gym.make("CartPole-v1"), obs_shape=(10, 10)) with pytest.raises(ValueError, match=r"Shape or dtype missing.*"): - ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=np.bool_) + ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool) with pytest.raises(ValueError, match=r"Shape or dtype missing.*"): - ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=np.bool_, act_dtype=np.bool_) + ReplayBuffer(15, obs_shape=(10, 10), obs_dtype=bool, act_dtype=bool) def test_buffer_from_data(): - data = np.ndarray([50, 30], dtype=np.bool_) + data = np.ndarray([50, 30], dtype=bool) buf = Buffer.from_data({"k": data}) assert buf._arrays["k"] is not data assert data.dtype == buf._arrays["k"].dtype From 1a24ae530fa77657916b199c075ccbcfa6332a78 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:29:56 -0700 Subject: [PATCH 81/85] Revert "new pytype need input directory or file" This reverts commit 15541cd56e3aac8c01a026236a5bd9e72f750a72. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f2d8d3a5..f07fd58d9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -78,7 +78,7 @@ repos: name: pytype language: system types: [python] - entry: "bash -c 'pytype ./ -j ${NUM_CPUS:-auto}'" + entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'" require_serial: true verbose: true - id: docs From b989af8f9176349d69058bf2de6865b18a054554 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:30:44 -0700 Subject: [PATCH 82/85] Revert "Upgrade pytype and remove workaround for old versions" This reverts commit 194ec1ac336d5ecdb36a5df92d53e056458291be. --- setup.py | 2 +- tests/algorithms/test_base.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a461ced36..5fc3354ad 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ATARI_REQUIRE = [ "seals[atari]~=0.2.1", ] -PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else [] +PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] # Note: the versions of the test and doc requirements should be tightly pinned to known # working versions to make our CI/CD pipeline as stable as possible. 
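The workaround re-added in the next hunk exists because pytype 2022.7.26 rejects call sites that omit optional TypedDict keys (fixed in pytype 2023.04.11; see google/pytype#1195, linked in the comment below). A minimal reproduction of the pattern, with illustrative field names rather than imitation's actual TransitionMapping definition:

from typing import TypedDict

class TransitionBatch(TypedDict, total=False):
    # total=False makes every key optional at call sites.
    obs: list
    acts: list
    infos: list

# Omitting "infos" is valid for a total=False TypedDict, but older pytype
# flagged constructs like this with wrong-arg-types.
batch: TransitionBatch = {"obs": [0], "acts": [1]}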
diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 802ac0d7f..1654a1482 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -41,7 +41,10 @@ def test_check_fixed_horizon_flag(custom_logger): def _make_and_iterate_loader(*args, **kwargs): - loader = base.make_data_loader(*args, **kwargs) + # our pytype version doesn't understand optional arguments in TypedDict + # this is fixed in 2023.04.11, but we require 2022.7.26 + # See https://github.com/google/pytype/issues/1195 + loader = base.make_data_loader(*args, **kwargs) # pytype: disable=wrong-arg-types for batch in loader: pass From 4817c2f3d34594e299fa30b44f4d475809b8ed9b Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 07:41:48 -0700 Subject: [PATCH 83/85] lint fix --- tests/algorithms/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index edb5a1b36..687395f17 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -3,9 +3,9 @@ import gymnasium as gym import pytest -from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common import envs -from stable_baselines3.common.vec_env import VecEnv, DummyVecEnv +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv from imitation.algorithms import bc from imitation.data.types import TrajectoryWithRew From 94c3ecfe4f2c6a8d8525172ebb56dc9ff4a81354 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 08:04:54 -0700 Subject: [PATCH 84/85] fix type check --- tests/algorithms/conftest.py | 2 +- tests/algorithms/test_bc.py | 2 ++ tests/algorithms/test_density_baselines.py | 5 ++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/algorithms/conftest.py b/tests/algorithms/conftest.py index 687395f17..a453f047d 100644 --- a/tests/algorithms/conftest.py +++ b/tests/algorithms/conftest.py @@ -123,7 +123,7 @@ def reward(self, reward): @pytest.fixture -def multi_obs_venv(): +def multi_obs_venv() -> VecEnv: def make_env(): env = envs.SimpleMultiObsEnv(channel_last=False) env = FloatReward(env) diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 8be92368d..8de49c66e 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -4,6 +4,7 @@ import os from typing import Any, Callable, Optional, Sequence +import gymnasium as gym import hypothesis import hypothesis.strategies as st import numpy as np @@ -291,6 +292,7 @@ def test_that_policy_reconstruction_preserves_parameters( def test_dict_space(multi_obs_venv: vec_env.VecEnv): # multi-input policy to accept dict observations + assert isinstance(multi_obs_venv.observation_space, gym.spaces.Dict) policy = sb_policies.MultiInputActorCriticPolicy( multi_obs_venv.observation_space, multi_obs_venv.action_space, diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index e05a8f728..6bf8c598e 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -3,10 +3,12 @@ from dataclasses import asdict from typing import Sequence, cast +import gymnasium as gym import numpy as np import pytest import stable_baselines3 from stable_baselines3.common import policies +from stable_baselines3.common import vec_env from imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types @@ -169,8 
+171,9 @@ def test_density_trainer_raises( density_trainer.set_demonstrations("foo") # type: ignore[arg-type] -def test_dict_space(multi_obs_venv): +def test_dict_space(multi_obs_venv: vec_env.VecEnv): # multi-input policy to accept dict observations + assert isinstance(multi_obs_venv.observation_space, gym.spaces.Dict) rl_algo = stable_baselines3.PPO( policies.MultiInputActorCriticPolicy, multi_obs_venv, From d5d191880bd09fcf88e5aac8026df76f5652e336 Mon Sep 17 00:00:00 2001 From: ZiyueWang25 Date: Wed, 4 Oct 2023 08:25:29 -0700 Subject: [PATCH 85/85] fix lint --- tests/algorithms/test_density_baselines.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 6bf8c598e..5c92feb58 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -7,8 +7,7 @@ import numpy as np import pytest import stable_baselines3 -from stable_baselines3.common import policies -from stable_baselines3.common import vec_env +from stable_baselines3.common import policies, vec_env from imitation.algorithms.density import DensityAlgorithm, DensityType from imitation.data import rollout, types
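Taken together, these last patches settle on a shared fixture for the dict-observation tests: a SimpleMultiObsEnv wrapped to emit float rewards, vectorized, and driven by a multi-input policy. A condensed, self-contained sketch of that pipeline, using only pieces that appear in this series; the variable names are illustrative and the RolloutInfoWrapper from the real fixture is omitted for brevity.

import gymnasium as gym
from stable_baselines3.common import envs
from stable_baselines3.common.policies import MultiInputActorCriticPolicy
from stable_baselines3.common.vec_env import DummyVecEnv

class FloatReward(gym.RewardWrapper):
    """Typecasts reward to a float (see the GH#794 TODO above)."""

    def reward(self, reward):
        return float(reward)

def make_env():
    return FloatReward(envs.SimpleMultiObsEnv(channel_last=False))

venv = DummyVecEnv([make_env, make_env])
assert isinstance(venv.observation_space, gym.spaces.Dict)

# A multi-input policy accepts the Dict observation space directly.
policy = MultiInputActorCriticPolicy(
    venv.observation_space,
    venv.action_space,
    lambda _: 1e-3,  # constant learning-rate schedule
)
obs = venv.reset()
acts, _states = policy.predict(obs)
assert acts.shape == (venv.num_envs,)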