Skip to content

Commit

Permalink
Adapt hyperparameters in test_sqil_performance to reduce flakiness
Browse files Browse the repository at this point in the history
  • Loading branch information
jas-ho committed Aug 8, 2023
1 parent 4ccea30 commit 68cbce8
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/algorithms/test_sqil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Tests `imitation.algorithms.sqil`."""

import numpy as np
import pytest
from stable_baselines3.common import policies, vec_env
Expand Down Expand Up @@ -99,13 +98,15 @@ def test_sqil_performance(
cartpole_venv: vec_env.VecEnv,
):
demonstrations = get_demos(rng, pytestconfig, "transitions")
demonstrations = demonstrations[:20]

model = sqil.SQIL(
venv=cartpole_venv,
demonstrations=demonstrations,
policy="MlpPolicy",
dqn_kwargs=dict(learning_starts=1000),
dqn_kwargs=dict(
learning_starts=500,
learning_rate=0.002,
batch_size=220,
),
)

rewards_before, _ = evaluate_policy(
Expand Down

0 comments on commit 68cbce8

Please sign in to comment.