diff --git a/setup.py b/setup.py index 620f7edca..5a1846697 100644 --- a/setup.py +++ b/setup.py @@ -207,10 +207,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: STABLE_BASELINES3, "sacred>=0.8.4", "tensorboard>=1.14", - # TODO: remove once https://github.com/huggingface/huggingface_sb3/issues/37 is - # fixed - "huggingface_sb3==2.2.5", - "optuna>=3.0.1", + "huggingface_sb3~=2.3", "datasets>=2.8.0", ], tests_require=TESTS_REQUIRE,