Add PosteriorTransform to get_optimal_samples and optimize_posterior_samples (#2576)

Summary:

Added a `posterior_transform` arg to `get_optimal_samples` to enable posterior sampling-based acquisition functions (xES, TestSet IG) on minimization problems. Intended for use in one-shot settings.
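
As a rough usage sketch (not part of this commit), the new argument lets minimization be expressed by passing a negating `ScalarizedPosteriorTransform`; the toy data, sizes, and variable names below are placeholders:

```python
import torch

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.acquisition.utils import get_optimal_samples
from botorch.models.gp_regression import SingleTaskGP

# Placeholder single-output model on the unit square.
train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

# A negative weight turns sample-path maximization into minimization.
neg = ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=torch.float64))
X_opt, f_opt = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=4,
    raw_samples=256,
    num_restarts=4,
    posterior_transform=neg,
)
```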

Differential Revision: D64266499
Carl Hvarfner authored and facebook-github-bot committed Oct 15, 2024
1 parent fb0c667 commit a3bee4a
Showing 4 changed files with 192 additions and 57 deletions.
35 changes: 30 additions & 5 deletions botorch/acquisition/utils.py
@@ -18,6 +18,7 @@
IdentityMCObjective,
MCAcquisitionObjective,
PosteriorTransform,
ScalarizedPosteriorTransform,
)
from botorch.exceptions.errors import (
BotorchTensorDimensionError,
@@ -28,10 +29,11 @@
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.pathwise import draw_matheron_paths
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_ensemble, normalize_indices
from gpytorch.models import GP
from torch import Tensor


@@ -486,12 +488,13 @@ def project_to_sample_points(X: Tensor, sample_points: Tensor) -> Tensor:


def get_optimal_samples(
model: Model,
model: GP,
bounds: Tensor,
num_optima: int,
raw_samples: int = 1024,
num_restarts: int = 20,
maximize: bool = True,
posterior_transform: PosteriorTransform | None = None,
) -> tuple[Tensor, Tensor]:
"""Draws sample paths from the posterior and maximizes the samples using GD.
@@ -505,17 +508,39 @@ def get_optimal_samples(
num_restarts (int, optional): The number of candidates to do gradient-based
optimization on. Defaults to 20.
maximize: Whether to maximize or minimize the samples.
posterior_transform: A PosteriorTransform, used to negate the objective or
scalarize multi-output models.
Returns:
Tuple[Tensor, Tensor]: The optimal input locations and corresponding
outputs, x* and f*.
"""
paths = draw_matheron_paths(model, sample_shape=torch.Size([num_optima]))
if posterior_transform and not isinstance(
posterior_transform, ScalarizedPosteriorTransform
):
raise ValueError(
"Only the ScalarizedPosteriorTransform is supported for "
"get_optimal_samples."
)

# To avoid ambiguity, we disallow two types of negation.
# TODO remove maximize argument entirely - currently in use in input constructors
# and acquisition functions
if not maximize and posterior_transform:
raise ValueError(
"Minimizing the samples are not supported with a `posterior_transform`."
)
elif not maximize:
posterior_transform = ScalarizedPosteriorTransform(
weights=-torch.ones(1).to(bounds)
)
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
optimal_inputs, optimal_outputs = optimize_posterior_samples(
paths,
paths=paths,
bounds=bounds,
raw_samples=raw_samples,
num_restarts=num_restarts,
maximize=maximize,
posterior_transform=posterior_transform,
)
return optimal_inputs, optimal_outputs
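
For reference (not part of the diff), a minimal sketch of how `ScalarizedPosteriorTransform` negates or linearly combines model outputs, which is what the new argument relies on for minimization and for scalarizing multi-output models; the tensors below are made up:

```python
import torch

from botorch.acquisition.objective import ScalarizedPosteriorTransform

# Negating a single output, equivalent to what maximize=False builds above.
neg = ScalarizedPosteriorTransform(weights=-torch.ones(1))
Y = torch.tensor([[1.0], [2.0], [3.0]])
print(neg.evaluate(Y))  # tensor([-1., -2., -3.])

# Linearly combining two outputs of a multi-output model.
scal = ScalarizedPosteriorTransform(weights=torch.tensor([1.0, 0.5]))
Y2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(scal.evaluate(Y2))  # tensor([2., 5.])
```
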
54 changes: 30 additions & 24 deletions botorch/utils/sampling.py
@@ -21,7 +21,7 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable
from contextlib import contextmanager
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING

import numpy as np
import scipy
@@ -36,7 +36,12 @@


if TYPE_CHECKING:
from botorch.sampling.pathwise.paths import SamplePath # pragma: no cover
from botorch.acquisition.objective import ( # pragma: no cover
ScalarizedPosteriorTransform,
)
from botorch.models.deterministic import ( # pragma: no cover
GenericDeterministicModel,
)


@contextmanager
@@ -988,13 +993,12 @@ def sparse_to_dense_constraints(


def optimize_posterior_samples(
paths: SamplePath,
paths: GenericDeterministicModel,
bounds: Tensor,
candidates: Tensor | None = None,
raw_samples: int | None = 1024,
raw_samples: int = 1024,
num_restarts: int = 20,
maximize: bool = True,
**kwargs: Any,
posterior_transform: ScalarizedPosteriorTransform | None = None,
) -> tuple[Tensor, Tensor]:
r"""Cheaply maximizes posterior samples by random querying followed by vanilla
gradient descent on the best num_restarts points.
@@ -1006,38 +1010,40 @@
which acts as extra initial guesses for the optimization routine.
raw_samples: The number of initial points at which to query the sample paths.
num_restarts: The number of points selected for gradient-based optimization.
maximize: Boolean indicating whether to maximize or minimize.
posterior_transform: A ScalarizedPosteriorTransform used to negate the
objective or linearly combine multiple outputs.
Returns:
A two-element tuple containing:
- X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
- f_opt: A `num_optima x [batch_size] x 1`-dim tensor of optimal outputs f*.
X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
f_opt: A `num_optima x [batch_size] x m`-dim tensor of optimal outputs f*,
where m is the number of outputs.
"""
if maximize:

def path_func(x):
return paths(x)

else:
def path_func(x) -> Tensor:
res = paths(x)
if posterior_transform:
res = posterior_transform.evaluate(res)

def path_func(x):
return -paths(x)
return res.squeeze(-1)

candidate_set = unnormalize(
SobolEngine(dimension=bounds.shape[1], scramble=True).draw(raw_samples), bounds
SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
bounds=bounds,
)

# queries all samples on all candidates - output shape
# raw_samples * num_optima * num_models
candidate_queries = path_func(candidate_set)
argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
X_top_k = candidate_set[argtop_k, :]

# to avoid circular import, the import occurs here
from botorch.generation.gen import gen_candidates_torch
from botorch.generation.gen import gen_candidates_scipy

X_top_k, f_top_k = gen_candidates_torch(
X_top_k, path_func, lower_bounds=bounds[0], upper_bounds=bounds[1], **kwargs
X_top_k, f_top_k = gen_candidates_scipy(
X_top_k,
path_func,
lower_bounds=bounds[0],
upper_bounds=bounds[1],
)
f_opt, arg_opt = f_top_k.max(dim=-1, keepdim=True)

@@ -1050,6 +1056,6 @@ def path_func(x):
X_opt = X_top_k.reshape(final_shape.numel(), num_restarts, -1)[
torch.arange(final_shape.numel()), arg_opt.flatten()
].reshape(*final_shape, -1)
if not maximize:
f_opt = -f_opt
f_opt = paths(X_opt.unsqueeze(-2)).squeeze(-2)

return X_opt, f_opt
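
As a hedged sketch of the updated call path (placeholder data; mirroring the new multi-objective test further below), a path model drawn via `get_matheron_path_model` can be optimized under a scalarizing transform:

```python
import torch

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models.gp_regression import SingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.sampling import optimize_posterior_samples

# Placeholder two-output model on the unit square.
X = torch.rand(6, 2, dtype=torch.float64)
Y = torch.stack([X.sum(dim=-1), (X ** 2).sum(dim=-1)], dim=-1)
model = SingleTaskGP(X, Y)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

paths = get_matheron_path_model(model=model, sample_shape=torch.Size([3]))
X_opt, f_opt = optimize_posterior_samples(
    paths=paths,
    bounds=bounds,
    posterior_transform=ScalarizedPosteriorTransform(
        weights=torch.ones(2, dtype=torch.float64)
    ),
    raw_samples=256,
    num_restarts=4,
)
# X_opt: 3 x 2 optimal inputs; f_opt: 3 x 2 per-output values at those inputs.
```
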
82 changes: 63 additions & 19 deletions test/acquisition/test_utils.py
@@ -10,7 +10,12 @@

import torch

from botorch.acquisition.objective import GenericMCObjective, LearnedObjective
from botorch.acquisition.objective import (
ExpectationPosteriorTransform,
GenericMCObjective,
LearnedObjective,
ScalarizedPosteriorTransform,
)
from botorch.acquisition.utils import (
compute_best_feasible_objective,
expand_trace_observations,
@@ -418,26 +423,65 @@ def test_get_optimal_samples(self):

bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
Y = torch.sin(X).sum(dim=-1, keepdim=True).to(dtype)
model = SingleTaskGP(X, Y)
X_opt, f_opt = get_optimal_samples(
model, bounds, num_optima=num_optima, **for_testing_speed_kwargs
)
X_opt, f_opt_min = get_optimal_samples(
model,
bounds,
num_optima=num_optima,
maximize=False,
**for_testing_speed_kwargs,
Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
model = SingleTaskGP(train_X=X, train_Y=Y)
posterior_transform = ScalarizedPosteriorTransform(
weights=-torch.ones(1, dtype=dtype)
)

correct_X_shape = (num_optima,) + batch_shape + (dims,)
correct_f_shape = (num_optima,) + batch_shape + (1,)
self.assertEqual(X_opt.shape, correct_X_shape)
self.assertEqual(f_opt.shape, correct_f_shape)
# asserting that the solutions found by minimizing the samples are smaller
# than those found by maximization
self.assertTrue(torch.all(f_opt_min < f_opt))
for ps in [None, posterior_transform]:
with torch.random.fork_rng():
torch.manual_seed(0)
X_opt, f_opt = get_optimal_samples(
model=model,
bounds=bounds,
num_optima=num_optima,
posterior_transform=ps,
**for_testing_speed_kwargs,
)
correct_X_shape = (num_optima,) + batch_shape + (dims,)
correct_f_shape = (num_optima,) + batch_shape + (1,)
self.assertEqual(X_opt.shape, correct_X_shape)
self.assertEqual(f_opt.shape, correct_f_shape)

with torch.random.fork_rng():
torch.manual_seed(0)
X_opt_min, f_opt_min = get_optimal_samples(
model=model,
bounds=bounds,
num_optima=num_optima,
maximize=False,
**for_testing_speed_kwargs,
)
# check that the same minimizers are found with maximize=False and with
# a negating posterior transform
self.assertAllClose(X_opt_min, X_opt)

with self.assertRaisesRegex(
ValueError,
"Minimizing the samples are not supported with a `posterior_transform`.",
):
get_optimal_samples(
model=model,
bounds=bounds,
num_optima=num_optima,
maximize=False,
posterior_transform=posterior_transform,
**for_testing_speed_kwargs,
)
with self.assertRaisesRegex(
ValueError,
"Only the ScalarizedPosteriorTransform is supported for "
"get_optimal_samples.",
):
get_optimal_samples(
model=model,
bounds=bounds,
num_optima=num_optima,
maximize=False,
posterior_transform=ExpectationPosteriorTransform(n_w=5),
**for_testing_speed_kwargs,
)


class TestPreferenceUtils(BotorchTestCase):
78 changes: 69 additions & 9 deletions test/utils/test_sampling.py
@@ -14,10 +14,11 @@

import numpy as np
import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions.errors import BotorchError
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.gp_regression import SingleTaskGP
from botorch.sampling.pathwise import draw_matheron_paths
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.sampling import (
_convert_bounds_to_inequality_constraints,
batched_multinomial,
@@ -552,23 +553,33 @@ def test_optimize_posterior_samples(self):
torch.manual_seed(1)
dims = 2
dtype = torch.float64
eps = 1e-6
for_testing_speed_kwargs = {"raw_samples": 512, "num_restarts": 10}
eps = 1e-4
for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
nums_optima = (1, 7)
batch_shapes = ((), (3,), (5, 2))
for num_optima, batch_shape in itertools.product(nums_optima, batch_shapes):
batch_shapes = ((), (2,), (3, 2))
posterior_transforms = (
None,
ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),
)
for num_optima, batch_shape, posterior_transform in itertools.product(
nums_optima, batch_shapes, posterior_transforms
):
bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 13, dims, dtype=dtype)
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)

# having a noiseless model all but guarantees that the found optima
# will be better than the observations
model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
paths = draw_matheron_paths(
model.covar_module.lengthscale = 0.5
paths = get_matheron_path_model(
model=model, sample_shape=torch.Size([num_optima])
)
X_opt, f_opt = optimize_posterior_samples(
paths, bounds, **for_testing_speed_kwargs
paths=paths,
bounds=bounds,
posterior_transform=posterior_transform,
**for_testing_speed_kwargs,
)

correct_X_shape = (num_optima,) + batch_shape + (dims,)
@@ -581,4 +592,53 @@ def test_optimize_posterior_samples(self):

# Check that all the found optima are larger than the observations
# This is not 100% deterministic, but just about.
self.assertTrue(torch.all(f_opt > Y.max(dim=-2).values))
Y_queries = paths(X)
# this is when we negate, so the values should be smaller
if posterior_transform:
self.assertTrue(torch.all(f_opt < Y_queries.min(dim=-2).values))

# otherwise, larger
else:
self.assertTrue(torch.all(f_opt > Y_queries.max(dim=-2).values))

def test_optimize_posterior_samples_multi_objective(self):
# Fix the random seed to prevent flaky failures.
torch.manual_seed(1)
dims = 2
dtype = torch.float64
eps = 1e-4
for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
num_optima = 5
batch_shape = (3,)

# test that multi-output models are supported if there is an appropriate
# scalarization
bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
Y1 = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)
Y2 = torch.cos(X * 3).sum(dim=-1, keepdim=True)
Y = torch.cat([Y1, Y2], dim=-1)
# having a noiseless model all but guarantees that the found optima
# will be better than the observations
model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
model.covar_module.lengthscale = 0.5
posterior_transform = ScalarizedPosteriorTransform(
weights=torch.ones(2, dtype=dtype)
)
paths = get_matheron_path_model(
model=model,
sample_shape=torch.Size([num_optima]),
)
X_opt, f_opt = optimize_posterior_samples(
paths=paths,
bounds=bounds,
posterior_transform=posterior_transform,
**for_testing_speed_kwargs,
)

correct_X_shape = (num_optima,) + batch_shape + (dims,)
correct_f_shape = (num_optima,) + batch_shape + (2,)
self.assertEqual(X_opt.shape, correct_X_shape)
self.assertEqual(f_opt.shape, correct_f_shape)
self.assertTrue(torch.all(X_opt >= bounds[0]))
self.assertTrue(torch.all(X_opt <= bounds[1]))
