Skip to content

Commit

Permalink
Make test for exceptions raised by SQIL constructor more specific
Browse files Browse the repository at this point in the history
- also: adjust imports to conform with style guide
  • Loading branch information
jas-ho committed Aug 9, 2023
1 parent 5cbb6b2 commit d2124a2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/algorithms/test_sqil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Tests `imitation.algorithms.sqil`."""
from unittest.mock import MagicMock
from unittest import mock

import numpy as np
import pytest
Expand Down Expand Up @@ -139,9 +139,9 @@ def test_sqil_performance(

@pytest.mark.parametrize("illegal_kw", ["replay_buffer_class", "replay_buffer_kwargs"])
def test_sqil_constructor_raises(illegal_kw: str):
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=".*SQIL uses a custom replay buffer.*"):
sqil.SQIL(
venv=MagicMock(spec=vec_env.VecEnv),
venv=mock.MagicMock(spec=vec_env.VecEnv),
demonstrations=None,
policy="MlpPolicy",
dqn_kwargs={illegal_kw: None},
Expand Down

0 comments on commit d2124a2

Please sign in to comment.