Skip to content

Commit

Permalink
Patch Eagle designer to handle singleton parameters.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688167486
  • Loading branch information
belenkil authored and copybara-github committed Oct 21, 2024
1 parent fa018db commit dbaf04f
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
34 changes: 22 additions & 12 deletions vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
import json
import time
from typing import Optional, Sequence

from absl import logging
import attr
import numpy as np
Expand Down Expand Up @@ -136,8 +135,10 @@ def __init__(
# ensure non-repeated behavior.
seed = int(time.time())
logging.info(
('A seed was not provided to Eagle Strategy designer constructor. '
'Setting the seed to %s'),
(
'A seed was not provided to Eagle Strategy designer constructor. '
'Setting the seed to %s'
),
str(seed),
)
self._scaler = converters.ProblemAndTrialsScaler(problem_statement)
Expand Down Expand Up @@ -172,7 +173,8 @@ def dump(self) -> vz.Metadata:
metadata = vz.Metadata()
metadata.ns('eagle')['rng'] = serialization.serialize_rng(self._rng)
metadata.ns('eagle')['firefly_pool'] = (
serialization.partially_serialize_firefly_pool(self._firefly_pool))
serialization.partially_serialize_firefly_pool(self._firefly_pool)
)
metadata.ns('eagle')['serialization_version'] = 'v1'
metadata.ns('eagle')['dump_timestamp'] = str(time.time())
metadata.ns('eagle').ns('random_designer').attach(
Expand All @@ -196,23 +198,28 @@ def load(self, metadata: vz.Metadata) -> None:
"""
if metadata.ns('eagle').get('serialization_version', default=None) is None:
# First time the designer is called, so the namespace doesn't exist yet.
logging.info('Eagle designer was called for the first time. No state was'
' recovered.')
logging.info(
'Eagle designer was called for the first time. No state was'
' recovered.'
)
else:
try:
self._rng = serialization.restore_rng(metadata.ns('eagle')['rng'])
except Exception as e:
raise serializable.FatalDecodeError(
"Couldn't load random generator from metadata.") from e
"Couldn't load random generator from metadata."
) from e
self._utils.rng = self._rng

try:
firefly_pool = metadata.ns('eagle')['firefly_pool']
self._firefly_pool = serialization.restore_firefly_pool(
self._utils, firefly_pool)
self._utils, firefly_pool
)
except Exception as e:
raise serializable.HarmlessDecodeError(
"Couldn't load firefly pool from metadata.") from e
"Couldn't load firefly pool from metadata."
) from e

try:
self._initial_designer = quasi_random.QuasiRandomDesigner(
Expand All @@ -225,8 +232,10 @@ def load(self, metadata: vz.Metadata) -> None:
) from e

logging.info(
('Eagle designer restored state from timestamp %s. Firefly pool'
' now contains %s fireflies.'),
(
'Eagle designer restored state from timestamp %s. Firefly pool'
' now contains %s fireflies.'
),
metadata.ns('eagle')['dump_timestamp'],
self._firefly_pool.size,
)
Expand Down Expand Up @@ -386,7 +395,8 @@ def _update_one(self, trial: vz.Trial) -> None:
return

elif not trial.infeasible and self._utils.is_better_than(
trial, parent_fly.trial):
trial, parent_fly.trial
):
# There's improvement. Update the parent with the new trial.
parent_fly.trial = trial
parent_fly.generation += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

from __future__ import annotations

"""Tests for eagle_strategy."""
"""Tests for eagle_strategy module."""

import numpy as np
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy
from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy_utils
from vizier._src.algorithms.designers.eagle_strategy import testing
from vizier.testing import test_studies

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -328,6 +329,25 @@ def _suggest_and_update(
_suggest_and_update(eagle_designer, tid, infeasible=False)
tid += 1

def test_on_singleton_search_space(self):
problem = vz.ProblemStatement(
test_studies.flat_space_with_all_types_with_singletons()
)
problem.metric_information.append(
vz.MetricInformation(
name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE
)
)
designer = eagle_strategy.EagleStrategyDesigner(problem)
initial_suggestions = designer.suggest(25)
trials = []
for idx, suggestion in enumerate(initial_suggestions):
trial = suggestion.to_trial(idx)
trial.complete(vz.Measurement({'metric': 1.0}))
trials.append(trial)
designer.update(vza.CompletedTrials(trials), vza.ActiveTrials([]))
self.assertLen(designer.suggest(25), 25)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ def _compute_canonical_distance_squared_by_type(
dist_squared_by_type[param_config.type] += int(p1_value == p2_value)
else:
min_value, max_value = param_config.bounds
dist = (p1_value - p2_value) / (max_value - min_value)
if max_value == min_value:
dist = 0
else:
dist = (p1_value - p2_value) / (max_value - min_value)
dist_squared_by_type[param_config.type] += dist * dist

return dist_squared_by_type
Expand Down
18 changes: 18 additions & 0 deletions vizier/testing/test_studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,24 @@ def flat_space_with_all_types() -> vz.SearchSpace:
return space


def flat_space_with_all_types_with_singletons() -> vz.SearchSpace:
"""Search space with all parameter types."""

space = vz.SearchSpace()
root = space.root
root.add_float_param('double_singleton', min_value=1.0, max_value=1.0)
root.add_int_param('integer_singleton', min_value=5, max_value=5)
root.add_categorical_param('categorical_singleton', feasible_values=['a'])
root.add_discrete_param('discrete_singleton', feasible_values=[3])
root.add_float_param('double', min_value=0.0, max_value=5.0)
root.add_int_param('integer', min_value=0, max_value=10)
root.add_categorical_param(
'categorical', feasible_values=['a', '1', 'b', '2']
)
root.add_discrete_param('discrete', feasible_values=[0.0, 0.6])
return space


def conditional_automl_space() -> vz.SearchSpace:
"""Conditional space for a simple AutoML task."""
space = vz.SearchSpace()
Expand Down

0 comments on commit dbaf04f

Please sign in to comment.