From 284a80f9695571ca09fcde3c734704d63c48c131 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 6 Oct 2023 10:48:15 +0200 Subject: [PATCH] Make it possible to import the package without scikit-optimize Only needed for SkoptOptimizer --- brian2modelfitting/optimizer.py | 31 ++++++++++++---------- brian2modelfitting/tests/test_optimizer.py | 20 ++++++++------ pyproject.toml | 1 + 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/brian2modelfitting/optimizer.py b/brian2modelfitting/optimizer.py index 0a0d576..e163a02 100644 --- a/brian2modelfitting/optimizer.py +++ b/brian2modelfitting/optimizer.py @@ -1,24 +1,26 @@ import abc import numpy as np -import warnings +import sklearn -from brian2.utils.logger import get_logger +try: + import skopt +except ImportError: + skopt = None -# Prevent sklearn from adding a filter by monkey-patching the warnings module -# TODO: Remove when we depend on a newer version of scikit-learn (with -# https://github.com/scikit-learn/scikit-learn/pull/15080 merged) -_filterwarnings = warnings.filterwarnings -warnings.filterwarnings = lambda *args, **kwds: None -from skopt.space import Real -from skopt import Optimizer as skoptOptimizer -from sklearn.base import RegressorMixin -warnings.filterwarnings = _filterwarnings +from brian2.utils.logger import get_logger import nevergrad from nevergrad.optimization import optimizerlib, registry logger = get_logger(__name__) + +def _check_skopt(): + if skopt is None: + raise ImportError("The SkoptOptimizer requires the `scikit-optimize` " + "library to be installed.") + + def calc_bounds(parameter_names, **params): """ Verify and get the provided for parameters bounds @@ -254,9 +256,10 @@ class SkoptOptimizer(Optimizer): Number of calls to ``func``. Defaults to 100. """ def __init__(self, method='GP', **kwds): + _check_skopt() super(Optimizer, self).__init__() if not(method.upper() in ["GP", "RF", "ET", "GBRT"] or - isinstance(method, RegressorMixin)): + isinstance(method, sklearn.base.RegressorMixin)): raise AssertionError("Provided method: {} is not an skopt " "optimization or a regressor".format(method)) @@ -277,10 +280,10 @@ def initialize(self, parameter_names, popsize, rounds, **params): instruments = [] for i, name in enumerate(parameter_names): - instrumentation = Real(*np.asarray(bounds[i]), transform='normalize') + instrumentation = skopt.space.Real(*np.asarray(bounds[i]), transform='normalize') instruments.append(instrumentation) - self.optim = skoptOptimizer( + self.optim = skopt.Optimizer( dimensions=instruments, base_estimator=self.method, **self.kwds) diff --git a/brian2modelfitting/tests/test_optimizer.py b/brian2modelfitting/tests/test_optimizer.py index 6f1bd28..1063512 100644 --- a/brian2modelfitting/tests/test_optimizer.py +++ b/brian2modelfitting/tests/test_optimizer.py @@ -3,17 +3,20 @@ ''' import numpy as np from numpy.testing import assert_equal, assert_raises -from brian2modelfitting import Optimizer, NevergradOptimizer, SkoptOptimizer, calc_bounds +from brian2modelfitting import NevergradOptimizer, SkoptOptimizer, calc_bounds -from skopt import Optimizer as SOptimizer from nevergrad.optimization.base import Optimizer as NOptimizer +try: + import skopt +except ImportError: + skopt = None -def test_init(): - # Optimizer() +def test_init_nevergrad(): NevergradOptimizer() - SkoptOptimizer() - NevergradOptimizer(method='DE') + +def test_init_skopt(): + SkoptOptimizer() SkoptOptimizer(method='GP') @@ -51,13 +54,14 @@ def test_initialize_nevergrad(): def test_initialize_skopt(): + assert skopt is not None s_opt = SkoptOptimizer() s_opt.initialize({'g'}, g=[1, 30], popsize=30, rounds=2) - assert isinstance(s_opt.optim, SOptimizer) + assert isinstance(s_opt.optim, skopt.Optimizer) assert_equal(len(s_opt.optim.space.dimensions), 1) s_opt.initialize({'g', 'E'}, g=[1, 30], E=[2, 20], popsize=30, rounds=2) - assert isinstance(s_opt.optim, SOptimizer) + assert isinstance(s_opt.optim, skopt.Optimizer) assert_equal(len(s_opt.optim.space.dimensions), 2) assert_raises(TypeError, s_opt.initialize, ['g'], g=[1], popsize=30) diff --git a/pyproject.toml b/pyproject.toml index 2fc20f4..b6f4b1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ 'numpy>=1.21', 'brian2>=2.2', 'nevergrad>=0.4', + 'scikit-learn>=0.22', 'tqdm', 'pandas', ]