Skip to content

Commit

Permalink
test: add a test using linear constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Oct 18, 2024
1 parent fc544a6 commit d088be6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
5 changes: 3 additions & 2 deletions botorch/optim/optimize_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def optimize_acqf_homotopy(
]:
if kwarg_dict.get("return_best_only", None) is not False:
warnings.warn(
f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
f"`return_best_only` is not False in `{kwarg_dict_name}`, override to False."
)
kwarg_dict["return_best_only"] = False

Expand All @@ -101,6 +101,7 @@ def optimize_acqf_homotopy(
"removing it in favour of `batch_initial_conditions` given to "
"`optimize_acqf_homotopy`."
)
# NOTE: are pops dangerous here given no copy? If the caller reuses kwarg_dict across calls, this mutation could cause issues.
kwarg_dict.pop("batch_initial_conditions")

Check warning on line 105 in botorch/optim/optimize_homotopy.py

View check run for this annotation

Codecov / codecov/patch

botorch/optim/optimize_homotopy.py#L105

Added line #L105 was not covered by tests

for arg_name, arg_value in [("acq_function", acq_function), ("bounds", bounds)]:
Expand Down Expand Up @@ -133,7 +134,7 @@ def optimize_acqf_homotopy(
"removing it as we set `batch_initial_conditions` to the candidates "
"returned by homotopy loop for the final optimization."
)
optimize_acqf_final_kwargs.pop("raw_samples")
optimize_acqf_final_kwargs.pop("raw_samples") # are pops dangerous here given no copy? see note above

candidate_list, acq_value_list = [], []
if q > 1:
Expand Down
33 changes: 28 additions & 5 deletions test/optim/test_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_optimize_acqf_homotopy(self):
acq_function=acqf,
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
Expand All @@ -146,9 +146,9 @@ def test_optimize_acqf_homotopy(self):
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features,
"fixed_features": fixed_features, # this is done to mimic old behaviour which was perhaps a bug?
},
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

Expand All @@ -163,12 +163,35 @@ def test_optimize_acqf_homotopy(self):
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features,
},
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_final_kwargs={**optimize_acqf_core_kwargs},
)
self.assertEqual(candidate.shape, torch.Size([3, 2]))
self.assertEqual(acqf_val.shape, torch.Size([3]))

# with linear constraints
constraints = [( # X[..., 0] + X[..., 1] >= 2.
torch.tensor([0, 1], device=self.device),
torch.ones(2, device=self.device, dtype=torch.double),
2.0,
)]

acqf = PosteriorMean(model=model)
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"inequality_constraints": constraints,
},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"inequality_constraints": constraints,
},
)
self.assertEqual(candidate.shape, torch.Size([1, 2]))
self.assertGreaterEqual(candidate.sum(), 2 * torch.ones(1, **tkwargs))

def test_prune_candidates(self):
tkwargs = {"device": self.device, "dtype": torch.double}
Expand Down Expand Up @@ -225,7 +248,7 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
acq_function=acqf,
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_loop_kwargs={**optimize_acqf_core_kwargs},
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
Expand Down

0 comments on commit d088be6

Please sign in to comment.