diff --git a/setup.py b/setup.py index fa1d03f31..6d1f2489c 100644 --- a/setup.py +++ b/setup.py @@ -207,7 +207,10 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: STABLE_BASELINES3, "sacred>=0.8.4", "tensorboard>=1.14", - "huggingface_sb3>=2.2.1", + # TODO: remove once https://github.com/huggingface/huggingface_sb3/issues/37 is + # fixed + "huggingface_sb3==2.2.5", + "optuna>=3.0.1", "datasets>=2.8.0", ], tests_require=TESTS_REQUIRE,