Skip to content

Commit

Permalink
Prepare for 0.1.7 release by cleaning up multiobjective.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651474895
  • Loading branch information
xingyousong authored and copybara-github committed Jul 11, 2024
1 parent ab8bc4b commit 41c00cc
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 23 deletions.
2 changes: 1 addition & 1 deletion vizier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@

sys.path.append(PROTO_ROOT)

__version__ = "0.1.16"
__version__ = "0.1.17"
18 changes: 4 additions & 14 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
_output_warper: output_warpers.OutputWarper = attr.field(
factory=output_warpers.create_default_warper, kw_only=True
)

# Multi-objective parameters.
_num_scalarizations: int = attr.field(default=1000, kw_only=True)
_ref_scaling: float = attr.field(default=0.01, kw_only=True)
_num_ehvi_samples: Optional[int] = attr.field(default=None, kw_only=True)

# ------------------------------------------------------------------
# Internal attributes which should not be set by callers.
# ------------------------------------------------------------------
Expand Down Expand Up @@ -212,16 +211,7 @@ def __attrs_post_init__(self):
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

if self._num_ehvi_samples: # Sampled EHVI.
reduction_fn = lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1])
acquisition_fn = acq_lib.Sample(self._num_ehvi_samples)
else: # Scalarized UCB.
reduction_fn = lambda x: jnp.mean(x, axis=0)
acquisition_fn = acq_lib.UCB()

def acq_fn_factory(
data: types.ModelData,
) -> acq_lib.AcquisitionFunction:
def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
Expand All @@ -236,9 +226,9 @@ def acq_fn_factory(
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acquisition_fn,
acq_lib.UCB(),
scalarizer,
reduction_fn=reduction_fn,
reduction_fn=lambda x: jnp.mean(x, axis=0),
max_scalarized=max_scalarized,
)

Expand Down
10 changes: 2 additions & 8 deletions vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,7 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction:
iters * n_parallel,
)

@parameterized.parameters(
dict(num_ehvi_samples=11),
dict(num_ehvi_samples=None),
)
def test_multi_metrics(self, num_ehvi_samples: int | None):
def test_multi_metrics(self):
search_space = vz.SearchSpace()
search_space.root.add_float_param('x0', -5.0, 5.0)
problem = vz.ProblemStatement(
Expand All @@ -487,9 +483,7 @@ def test_multi_metrics(self, num_ehvi_samples: int | None):
)

iters = 2
designer = gp_bandit.VizierGPBandit(
problem, num_ehvi_samples=num_ehvi_samples
)
designer = gp_bandit.VizierGPBandit(problem)
self.assertLen(
test_runners.RandomMetricsRunner(
problem,
Expand Down

0 comments on commit 41c00cc

Please sign in to comment.