diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py
index 87ba7b41cc..5f218eeca6 100644
--- a/botorch/optim/optimize_homotopy.py
+++ b/botorch/optim/optimize_homotopy.py
@@ -85,7 +85,7 @@ def optimize_acqf_homotopy(
         ("optimize_acqf_loop_kwargs", optimize_acqf_loop_kwargs),
         ("optimize_acqf_final_kwargs", optimize_acqf_final_kwargs),
     ]:
-        if "return_best_only" in kwarg_dict:
+        if kwarg_dict.get("return_best_only", None) is not False:
             warnings.warn(
                 f"`return_best_only` is set to True in `{kwarg_dict_name}`, override to False."
             )
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
index e2f2d1c611..a531e525fb 100644
--- a/test/optim/test_homotopy.py
+++ b/test/optim/test_homotopy.py
@@ -125,7 +125,10 @@ def test_optimize_acqf_homotopy(self):
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         self.assertEqual(candidate, torch.zeros(1, **tkwargs))
         self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
@@ -141,8 +144,11 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,
+            },
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

@@ -153,8 +159,11 @@ def test_optimize_acqf_homotopy(self):
             acq_function=acqf,
             bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
-            optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs.update({"fixed_features":fixed_features}),
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs
+            optimize_acqf_loop_kwargs={
+                **optimize_acqf_core_kwargs,
+                "fixed_features": fixed_features,
+            },
+            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs,
         )
         self.assertEqual(candidate.shape, torch.Size([3, 2]))
         self.assertEqual(acqf_val.shape, torch.Size([3]))
@@ -207,8 +216,8 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
         model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
         acqf = PosteriorMean(model=model)
         optimize_acqf_core_kwargs = {
-            "num_restarts":2,
-            "raw_samples":16,
+            "num_restarts": 4,
+            "raw_samples": 16,
         }

         candidate, acqf_val = optimize_acqf_homotopy(
@@ -217,7 +226,10 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
             bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
             homotopy=Homotopy(homotopy_parameters=[hp]),
             optimize_acqf_loop_kwargs=optimize_acqf_core_kwargs,
-            optimize_acqf_final_kwargs=optimize_acqf_core_kwargs.update({"post_processing_func":lambda x: x.round()}),
+            optimize_acqf_final_kwargs={
+                **optimize_acqf_core_kwargs,
+                "post_processing_func": lambda x: x.round(),
+            },
         )
         # First time we expect to call `prune_candidates` with 4 candidates
         self.assertEqual(