Skip to content

Commit

Permalink
fix: restore the pruning test
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Oct 17, 2024
1 parent 038a735 commit fc544a6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion botorch/optim/optimize_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def optimize_acqf_homotopy(
("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
]:
if "return_best_only" in kwarg_dict:
if kwarg_dict.get("return_best_only", None) is not False:
warnings.warn(
f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
)
Expand Down
28 changes: 20 additions & 8 deletions test/optim/test_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ def test_optimize_acqf_homotopy(self):
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
},
)
self.assertEqual(candidate, torch.zeros(1, **tkwargs))
self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
Expand All @@ -141,8 +144,11 @@ def test_optimize_acqf_homotopy(self):
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features,
},
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

Expand All @@ -153,8 +159,11 @@ def test_optimize_acqf_homotopy(self):
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
optimize_acqf_loop_kwargs={
**optimize_acqf_core_kwargs,
"fixed_features": fixed_features,
},
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
)
self.assertEqual(candidate.shape, torch.Size([3, 2]))
self.assertEqual(acqf_val.shape, torch.Size([3]))
Expand Down Expand Up @@ -207,8 +216,8 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
acqf = PosteriorMean(model=model)
optimize_acqf_core_kwargs = {
"num_restarts":2,
"raw_samples":16,
"num_restarts": 4,
"raw_samples": 16,
}

candidate, acqf_val = optimize_acqf_homotopy(
Expand All @@ -217,7 +226,10 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
optimize_acqf_final_kwargs={
**optimize_acqf_core_kwargs,
"post_processing_func": lambda x: x.round(),
},
)
# First time we expect to call `prune_candidates` with 4 candidates
self.assertEqual(
Expand Down

0 comments on commit fc544a6

Please sign in to comment.