From 926abab815b6ab05ad95cec1044afb9a20dee1d6 Mon Sep 17 00:00:00 2001 From: uri-granta <50578464+uri-granta@users.noreply.github.com> Date: Mon, 15 Jul 2024 11:22:30 +0100 Subject: [PATCH] Allow TR num_local_datasets to be set to 0 (#859) Co-authored-by: Uri Granta --- tests/unit/test_ask_tell_optimization.py | 23 +++++++++++++++++++++++ trieste/ask_tell_optimization.py | 8 ++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_ask_tell_optimization.py b/tests/unit/test_ask_tell_optimization.py index e75c777c38..c3aeab8997 100644 --- a/tests/unit/test_ask_tell_optimization.py +++ b/tests/unit/test_ask_tell_optimization.py @@ -957,3 +957,26 @@ def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets( ) with pytest.raises(ValueError): AskTellOptimizer.dataset_len({}) + + +@pytest.mark.parametrize("optimizer", OPTIMIZERS) +def test_ask_tell_optimizer_doesnt_blow_up_with_no_local_datasets( + search_space: Box, + init_dataset: Dataset, + model: TrainableProbabilisticModel, + optimizer: OptimizerType, +) -> None: + pseudo_local_acquisition_rule = FixedLocalAcquisitionRule([[0.0]], 0) + ask_tell = optimizer( + search_space, init_dataset, model, pseudo_local_acquisition_rule, track_data=False + ) + ask_tell._datasets[OBJECTIVE] += mk_dataset( + [[x / 100] for x in range(75, 75 + 5)], [[x / 100] for x in range(75, 75 + 5)] + ) + ask_tell.tell(ask_tell.dataset) + optimizer.from_state( + ask_tell.to_state(), + search_space, + pseudo_local_acquisition_rule, + track_data=False, + ) diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index e5e4579c0a..eaeccd153d 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -297,7 +297,9 @@ def __init__( if local_data_len is not None: # infer new dataset indices from change in dataset sizes num_new_points = self._dataset_len - local_data_len - if num_new_points < 0 or num_new_points % num_local_datasets != 0: + if num_new_points < 0 or ( + num_local_datasets > 0 and num_new_points % num_local_datasets != 0 + ): raise ValueError( "Cannot infer new data points as datasets haven't increased by " f"a multiple of {num_local_datasets}" @@ -669,7 +671,9 @@ def tell( # infer dataset indices from change in dataset sizes new_dataset_len = self.dataset_len(new_data) num_new_points = new_dataset_len - self._dataset_len - if num_new_points < 0 or num_new_points % num_local_datasets != 0: + if num_new_points < 0 or ( + num_local_datasets > 0 and num_new_points % num_local_datasets != 0 + ): raise ValueError( "Cannot infer new data points as datasets haven't increased by " f"a multiple of {num_local_datasets}"