diff --git a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py index f071d6c12..4f1bc0ced 100644 --- a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py +++ b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py @@ -73,7 +73,6 @@ import json import time from typing import Optional, Sequence - from absl import logging import attr import numpy as np @@ -136,8 +135,10 @@ def __init__( # ensure non-repeated behavior. seed = int(time.time()) logging.info( - ('A seed was not provided to Eagle Strategy designer constructor. ' - 'Setting the seed to %s'), + ( + 'A seed was not provided to Eagle Strategy designer constructor. ' + 'Setting the seed to %s' + ), str(seed), ) self._scaler = converters.ProblemAndTrialsScaler(problem_statement) @@ -172,7 +173,8 @@ def dump(self) -> vz.Metadata: metadata = vz.Metadata() metadata.ns('eagle')['rng'] = serialization.serialize_rng(self._rng) metadata.ns('eagle')['firefly_pool'] = ( - serialization.partially_serialize_firefly_pool(self._firefly_pool)) + serialization.partially_serialize_firefly_pool(self._firefly_pool) + ) metadata.ns('eagle')['serialization_version'] = 'v1' metadata.ns('eagle')['dump_timestamp'] = str(time.time()) metadata.ns('eagle').ns('random_designer').attach( @@ -196,23 +198,28 @@ def load(self, metadata: vz.Metadata) -> None: """ if metadata.ns('eagle').get('serialization_version', default=None) is None: # First time the designer is called, so the namespace doesn't exist yet. - logging.info('Eagle designer was called for the first time. No state was' - ' recovered.') + logging.info( + 'Eagle designer was called for the first time. No state was' + ' recovered.' + ) else: try: self._rng = serialization.restore_rng(metadata.ns('eagle')['rng']) except Exception as e: raise serializable.FatalDecodeError( - "Couldn't load random generator from metadata.") from e + "Couldn't load random generator from metadata." + ) from e self._utils.rng = self._rng try: firefly_pool = metadata.ns('eagle')['firefly_pool'] self._firefly_pool = serialization.restore_firefly_pool( - self._utils, firefly_pool) + self._utils, firefly_pool + ) except Exception as e: raise serializable.HarmlessDecodeError( - "Couldn't load firefly pool from metadata.") from e + "Couldn't load firefly pool from metadata." + ) from e try: self._initial_designer = quasi_random.QuasiRandomDesigner( @@ -225,8 +232,10 @@ def load(self, metadata: vz.Metadata) -> None: ) from e logging.info( - ('Eagle designer restored state from timestamp %s. Firefly pool' - ' now contains %s fireflies.'), + ( + 'Eagle designer restored state from timestamp %s. Firefly pool' + ' now contains %s fireflies.' + ), metadata.ns('eagle')['dump_timestamp'], self._firefly_pool.size, ) @@ -386,7 +395,8 @@ def _update_one(self, trial: vz.Trial) -> None: return elif not trial.infeasible and self._utils.is_better_than( - trial, parent_fly.trial): + trial, parent_fly.trial + ): # There's improvement. Update the parent with the new trial. parent_fly.trial = trial parent_fly.generation += 1 diff --git a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_test.py b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_test.py index ea93e3717..e63526448 100644 --- a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_test.py +++ b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_test.py @@ -14,7 +14,7 @@ from __future__ import annotations -"""Tests for eagle_strategy.""" +"""Tests for eagle_strategy module.""" import numpy as np from vizier import algorithms as vza @@ -22,6 +22,7 @@ from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy_utils from vizier._src.algorithms.designers.eagle_strategy import testing +from vizier.testing import test_studies from absl.testing import absltest from absl.testing import parameterized @@ -328,6 +329,25 @@ def _suggest_and_update( _suggest_and_update(eagle_designer, tid, infeasible=False) tid += 1 + def test_on_singleton_search_space(self): + problem = vz.ProblemStatement( + test_studies.flat_space_with_all_types_with_singletons() + ) + problem.metric_information.append( + vz.MetricInformation( + name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE + ) + ) + designer = eagle_strategy.EagleStrategyDesigner(problem) + initial_suggestions = designer.suggest(25) + trials = [] + for idx, suggestion in enumerate(initial_suggestions): + trial = suggestion.to_trial(idx) + trial.complete(vz.Measurement({'metric': 1.0})) + trials.append(trial) + designer.update(vza.CompletedTrials(trials), vza.ActiveTrials([])) + self.assertLen(designer.suggest(25), 25) + if __name__ == '__main__': absltest.main() diff --git a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py index fc1543c9e..c1bd0b361 100644 --- a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py +++ b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py @@ -198,7 +198,10 @@ def _compute_canonical_distance_squared_by_type( dist_squared_by_type[param_config.type] += int(p1_value == p2_value) else: min_value, max_value = param_config.bounds - dist = (p1_value - p2_value) / (max_value - min_value) + if max_value == min_value: + dist = 0 + else: + dist = (p1_value - p2_value) / (max_value - min_value) dist_squared_by_type[param_config.type] += dist * dist return dist_squared_by_type diff --git a/vizier/testing/test_studies.py b/vizier/testing/test_studies.py index cdbac44fa..0a2e9537b 100644 --- a/vizier/testing/test_studies.py +++ b/vizier/testing/test_studies.py @@ -74,6 +74,24 @@ def flat_space_with_all_types() -> vz.SearchSpace: return space +def flat_space_with_all_types_with_singletons() -> vz.SearchSpace: + """Search space with all parameter types.""" + + space = vz.SearchSpace() + root = space.root + root.add_float_param('double_singleton', min_value=1.0, max_value=1.0) + root.add_int_param('integer_singleton', min_value=5, max_value=5) + root.add_categorical_param('categorical_singleton', feasible_values=['a']) + root.add_discrete_param('discrete_singleton', feasible_values=[3]) + root.add_float_param('double', min_value=0.0, max_value=5.0) + root.add_int_param('integer', min_value=0, max_value=10) + root.add_categorical_param( + 'categorical', feasible_values=['a', '1', 'b', '2'] + ) + root.add_discrete_param('discrete', feasible_values=[0.0, 0.6]) + return space + + def conditional_automl_space() -> vz.SearchSpace: """Conditional space for a simple AutoML task.""" space = vz.SearchSpace()