Skip to content

Commit

Permalink
Make it possible to import the package without scikit-optimize
Browse files Browse the repository at this point in the history
Only needed for SkoptOptimizer
  • Loading branch information
mstimberg committed Oct 6, 2023
1 parent 099d218 commit 284a80f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
31 changes: 17 additions & 14 deletions brian2modelfitting/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import abc
import numpy as np
import warnings
import sklearn

from brian2.utils.logger import get_logger
try:
import skopt
except ImportError:
skopt = None

# Prevent sklearn from adding a filter by monkey-patching the warnings module
# TODO: Remove when we depend on a newer version of scikit-learn (with
# https://github.com/scikit-learn/scikit-learn/pull/15080 merged)
_filterwarnings = warnings.filterwarnings
warnings.filterwarnings = lambda *args, **kwds: None
from skopt.space import Real
from skopt import Optimizer as skoptOptimizer
from sklearn.base import RegressorMixin
warnings.filterwarnings = _filterwarnings
from brian2.utils.logger import get_logger

import nevergrad
from nevergrad.optimization import optimizerlib, registry

logger = get_logger(__name__)


def _check_skopt():
if skopt is None:
raise ImportError("The SkoptOptimizer requires the `scikit-optimize` "
"library to be installed.")


def calc_bounds(parameter_names, **params):
"""
Verify and get the provided for parameters bounds
Expand Down Expand Up @@ -254,9 +256,10 @@ class SkoptOptimizer(Optimizer):
Number of calls to ``func``. Defaults to 100.
"""
def __init__(self, method='GP', **kwds):
_check_skopt()
super(Optimizer, self).__init__()
if not(method.upper() in ["GP", "RF", "ET", "GBRT"] or
isinstance(method, RegressorMixin)):
isinstance(method, sklearn.base.RegressorMixin)):
raise AssertionError("Provided method: {} is not an skopt "
"optimization or a regressor".format(method))

Expand All @@ -277,10 +280,10 @@ def initialize(self, parameter_names, popsize, rounds, **params):

instruments = []
for i, name in enumerate(parameter_names):
instrumentation = Real(*np.asarray(bounds[i]), transform='normalize')
instrumentation = skopt.space.Real(*np.asarray(bounds[i]), transform='normalize')
instruments.append(instrumentation)

self.optim = skoptOptimizer(
self.optim = skopt.Optimizer(
dimensions=instruments,
base_estimator=self.method,
**self.kwds)
Expand Down
20 changes: 12 additions & 8 deletions brian2modelfitting/tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
'''
import numpy as np
from numpy.testing import assert_equal, assert_raises
from brian2modelfitting import Optimizer, NevergradOptimizer, SkoptOptimizer, calc_bounds
from brian2modelfitting import NevergradOptimizer, SkoptOptimizer, calc_bounds

from skopt import Optimizer as SOptimizer
from nevergrad.optimization.base import Optimizer as NOptimizer
try:
import skopt
except ImportError:
skopt = None

def test_init():
# Optimizer()
def test_init_nevergrad():
NevergradOptimizer()
SkoptOptimizer()

NevergradOptimizer(method='DE')

def test_init_skopt():
SkoptOptimizer()
SkoptOptimizer(method='GP')


Expand Down Expand Up @@ -51,13 +54,14 @@ def test_initialize_nevergrad():


def test_initialize_skopt():
assert skopt is not None
s_opt = SkoptOptimizer()
s_opt.initialize({'g'}, g=[1, 30], popsize=30, rounds=2)
assert isinstance(s_opt.optim, SOptimizer)
assert isinstance(s_opt.optim, skopt.Optimizer)
assert_equal(len(s_opt.optim.space.dimensions), 1)

s_opt.initialize({'g', 'E'}, g=[1, 30], E=[2, 20], popsize=30, rounds=2)
assert isinstance(s_opt.optim, SOptimizer)
assert isinstance(s_opt.optim, skopt.Optimizer)
assert_equal(len(s_opt.optim.space.dimensions), 2)

assert_raises(TypeError, s_opt.initialize, ['g'], g=[1], popsize=30)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
'numpy>=1.21',
'brian2>=2.2',
'nevergrad>=0.4',
'scikit-learn>=0.22',
'tqdm',
'pandas',
]
Expand Down

0 comments on commit 284a80f

Please sign in to comment.