-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: HuggingFace save_to_disk takes
PathLike
type which is defined …
…as str, bytes or os.PathLike. imitation.util.parse_path always returned pathlib.Path which is not one of these types. This commit converts pathlib.Path to str before calling the HF fn.
- Loading branch information
1 parent
a8b079c
commit f5cb8a4
Showing
4 changed files
with
126 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import gymnasium as gym | ||
import numpy as np | ||
import pytest | ||
|
||
from imitation.data import types | ||
|
||
SPACES = [ | ||
gym.spaces.Discrete(3), | ||
gym.spaces.MultiDiscrete([3, 4]), | ||
gym.spaces.Box(-1, 1, shape=(1,)), | ||
gym.spaces.Box(-1, 1, shape=(2,)), | ||
gym.spaces.Box(-np.inf, np.inf, shape=(2,)), | ||
] | ||
DICT_SPACE = gym.spaces.Dict( | ||
{"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))}, | ||
) | ||
LENGTHS = [0, 1, 2, 10] | ||
|
||
|
||
@pytest.fixture(params=SPACES) | ||
def act_space(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(params=SPACES + [DICT_SPACE]) | ||
def obs_space(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(params=LENGTHS) | ||
def length(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture | ||
def trajectory( | ||
obs_space: gym.Space, | ||
act_space: gym.Space, | ||
length: int, | ||
) -> types.Trajectory: | ||
"""Fixture to generate trajectory of length `length` iid sampled from spaces.""" | ||
if length == 0: | ||
pytest.skip() | ||
|
||
raw_obs = [obs_space.sample() for _ in range(length + 1)] | ||
if isinstance(obs_space, gym.spaces.Dict): | ||
obs: types.Observation = types.DictObs.from_obs_list(raw_obs) | ||
else: | ||
obs = np.array(raw_obs) | ||
acts = np.array([act_space.sample() for _ in range(length)]) | ||
infos = np.array([{f"key{i}": i} for i in range(length)]) | ||
return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True) | ||
|
||
|
||
@pytest.fixture | ||
def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew: | ||
"""Like `trajectory` but with reward randomly sampled from a Gaussian.""" | ||
rews = np.random.randn(len(trajectory)) | ||
return types.TrajectoryWithRew( | ||
**types.dataclass_quick_asdict(trajectory), | ||
rews=rews, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
"""Tests for `imitation.data.serialize`.""" | ||
|
||
import pathlib | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
import pytest | ||
|
||
from imitation.data import serialize, types | ||
from imitation.data.types import DictObs | ||
|
||
|
||
@pytest.fixture | ||
def data_path(tmp_path): | ||
return tmp_path / "data" | ||
|
||
|
||
@pytest.mark.parametrize("path_type", [str, pathlib.Path]) | ||
def test_save_trajectory(data_path, trajectory, path_type): | ||
if isinstance(trajectory.obs, DictObs): | ||
pytest.skip("serialize.save does not yet support DictObs") | ||
|
||
serialize.save(path_type(data_path), [trajectory]) | ||
assert data_path.exists() | ||
|
||
|
||
@pytest.mark.parametrize("path_type", [str, pathlib.Path]) | ||
def test_save_trajectory_rew(data_path, trajectory_rew, path_type): | ||
if isinstance(trajectory_rew.obs, DictObs): | ||
pytest.skip("serialize.save does not yet support DictObs") | ||
serialize.save(path_type(data_path), [trajectory_rew]) | ||
assert data_path.exists() | ||
|
||
|
||
def test_save_load_trajectory(data_path, trajectory): | ||
if isinstance(trajectory.obs, DictObs): | ||
pytest.skip("serialize.save does not yet support DictObs") | ||
serialize.save(data_path, [trajectory]) | ||
|
||
reconstructed = list(serialize.load(data_path)) | ||
reconstructedi = reconstructed[0] | ||
|
||
assert len(reconstructed) == 1 | ||
assert np.allclose(reconstructedi.obs, trajectory.obs) | ||
assert np.allclose(reconstructedi.acts, trajectory.acts) | ||
assert np.allclose(reconstructedi.terminal, trajectory.terminal) | ||
assert not hasattr(reconstructedi, "rews") | ||
|
||
|
||
@pytest.mark.parametrize("load_fn", [serialize.load, serialize.load_with_rewards]) | ||
def test_save_load_trajectory_rew(data_path, trajectory_rew, load_fn): | ||
if isinstance(trajectory_rew.obs, DictObs): | ||
pytest.skip("serialize.save does not yet support DictObs") | ||
serialize.save(data_path, [trajectory_rew]) | ||
|
||
reconstructed = list(load_fn(data_path)) | ||
reconstructedi = reconstructed[0] | ||
|
||
assert len(reconstructed) == 1 | ||
assert np.allclose(reconstructedi.obs, trajectory_rew.obs) | ||
assert np.allclose(reconstructedi.acts, trajectory_rew.acts) | ||
assert np.allclose(reconstructedi.terminal, trajectory_rew.terminal) | ||
assert np.allclose(reconstructedi.rews, trajectory_rew.rews) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters