AskTell support for indices to differently sized datasets
Uri Granta committed Nov 4, 2024
1 parent d6f70ca commit dc9e175
Showing 3 changed files with 115 additions and 55 deletions.
18 changes: 8 additions & 10 deletions tests/unit/test_ask_tell_optimization.py
@@ -189,7 +189,7 @@ def test_ask_tell_optimizer_returns_complete_state(
assert_datasets_allclose(state.record.dataset, init_dataset)
assert isinstance(state.record.model, type(model))
assert state.record.acquisition_state is None
assert state.local_data_ixs is not None
assert isinstance(state.local_data_ixs, Sequence)
assert state.local_data_len == 2
npt.assert_array_equal(
state.local_data_ixs,
@@ -229,8 +229,8 @@ def test_ask_tell_optimizer_loads_from_state(

assert_datasets_allclose(new_state.record.dataset, old_state.record.dataset)
assert old_state.record.model is new_state.record.model
assert new_state.local_data_ixs is not None
assert old_state.local_data_ixs is not None
assert isinstance(new_state.local_data_ixs, Sequence)
assert isinstance(old_state.local_data_ixs, Sequence)
npt.assert_array_equal(new_state.local_data_ixs, old_state.local_data_ixs)
assert old_state.local_data_len == new_state.local_data_len == len(init_dataset.query_points)

@@ -948,15 +948,13 @@ def test_ask_tell_optimizer_dataset_len_variables(
assert AskTellOptimizer.dataset_len({"tag1": dataset, "tag2": dataset}) == 2


def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets(
def test_ask_tell_optimizer_dataset_len_returns_dict_on_inconsistently_sized_datasets(
init_dataset: Dataset,
) -> None:
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len(
{"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
)
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len({})
assert AskTellOptimizer.dataset_len(
{"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
) == {"tag": 2, "empty": 0}
assert AskTellOptimizer.dataset_len({}) == {}
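
A minimal sketch of the relaxed dataset_len contract that the rewritten test above exercises: matching dataset sizes still yield a single int, while mismatched sizes now yield a per-tag dict instead of raising. Tag names and tensor shapes below are made up for illustration.

import tensorflow as tf

from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.data import Dataset

full = Dataset(tf.zeros([2, 2]), tf.zeros([2, 1]))   # two query points
empty = Dataset(tf.zeros([0, 2]), tf.zeros([0, 1]))  # no query points

# equal sizes: a single global length is returned
assert AskTellOptimizer.dataset_len({"a": full, "b": full}) == 2
# differing sizes: a mapping from tag to length is returned
assert AskTellOptimizer.dataset_len({"a": full, "b": empty}) == {"a": 2, "b": 0}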


@pytest.mark.parametrize("optimizer", OPTIMIZERS)
23 changes: 15 additions & 8 deletions trieste/acquisition/utils.py
@@ -162,7 +162,9 @@ def copy_to_local_models(
def with_local_datasets(
datasets: Mapping[Tag, Dataset],
num_local_datasets: int,
local_dataset_indices: Optional[Sequence[TensorType]] = None,
local_dataset_indices: Optional[
Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]
] = None,
) -> Dict[Tag, Dataset]:
"""
Helper method to add local datasets if they do not already exist, by copying global datasets
@@ -174,17 +176,22 @@ def with_local_datasets(
the global datasets should be copied. If None then the entire datasets are copied.
:return: The updated mapping of datasets.
"""
if local_dataset_indices is not None and len(local_dataset_indices) != num_local_datasets:
raise ValueError(
f"local_dataset_indices should have {num_local_datasets} entries, "
f"has {len(local_dataset_indices)}"
)
if isinstance(local_dataset_indices, Sequence):
local_dataset_indices = {tag: local_dataset_indices for tag in datasets}

updated_datasets = {}
for tag in datasets:
updated_datasets[tag] = datasets[tag]
ltag = LocalizedTag.from_tag(tag)
if not ltag.is_local:
if local_dataset_indices is not None:
if tag not in local_dataset_indices:
raise ValueError(f"local_dataset_indices missing tag {tag}")
elif len(local_dataset_indices[tag]) != num_local_datasets:
raise ValueError(
f"local_dataset_indices for tag {tag} should have {num_local_datasets} "
f"entries, but has {len(local_dataset_indices[tag])}"
)
for i in range(num_local_datasets):
target_ltag = LocalizedTag(ltag.global_tag, i)
if target_ltag not in datasets:
@@ -194,10 +201,10 @@
# TODO: use sparse tensors instead
updated_datasets[target_ltag] = Dataset(
query_points=tf.gather(
datasets[tag].query_points, local_dataset_indices[i]
datasets[tag].query_points, local_dataset_indices[tag][i]
),
observations=tf.gather(
datasets[tag].observations, local_dataset_indices[i]
datasets[tag].observations, local_dataset_indices[tag][i]
),
)
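
with_local_datasets now accepts either a single sequence of index tensors (applied to every global tag) or a mapping from tag to its own sequence, which is what allows differently sized global datasets. A minimal sketch of the mapping form, with made-up tags and shapes:

import tensorflow as tf

from trieste.acquisition.utils import with_local_datasets
from trieste.data import Dataset

datasets = {
    "OBJECTIVE": Dataset(tf.zeros([4, 2]), tf.zeros([4, 1])),
    "CONSTRAINT": Dataset(tf.zeros([2, 2]), tf.zeros([2, 1])),
}
# two local datasets per tag, each selecting its own rows of the global data
local_dataset_indices = {
    "OBJECTIVE": [tf.constant([0, 1]), tf.constant([2, 3])],
    "CONSTRAINT": [tf.constant([0]), tf.constant([1])],
}
updated = with_local_datasets(datasets, 2, local_dataset_indices)
# updated also contains entries keyed by LocalizedTag("OBJECTIVE", 0) and so on,
# each holding only the rows gathered by the corresponding index tensor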

129 changes: 92 additions & 37 deletions trieste/ask_tell_optimization.py
@@ -82,7 +82,7 @@ class AskTellOptimizerState(Generic[StateType, ProbabilisticModelType]):
record: Record[StateType, ProbabilisticModelType]
""" A record of the current state of the optimization. """

local_data_ixs: Optional[Sequence[TensorType]]
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]]
""" Indices to the local data, for LocalDatasetsAcquisitionRule rules
when `track_data` is `False`. """

@@ -108,7 +108,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -122,7 +122,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -139,7 +139,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -152,7 +152,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -166,7 +166,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -183,7 +183,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -204,7 +204,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
):
"""
@@ -225,9 +225,12 @@ def __init__(
updates to the global datasets (optionally using `local_data_ixs` and indices passed
in to `tell`).
:param local_data_ixs: Indices to the local data in the initial datasets. If unspecified,
assumes that the initial datasets are global.
assumes that the initial datasets are global. Can be a single sequence for all
datasets, or a mapping with separate values for each dataset.
:param local_data_len: Optional length of the data when the passed in `local_data_ixs`
were measured. If the data has increased since then, the indices are extended.
(Note that this is only supported when all datasets have the same length. If not,
then it is up to the caller to update the indices before initialization.)
:raise ValueError: If any of the following are true:
- the keys in ``datasets`` and ``models`` do not match
- ``datasets`` or ``models`` are empty
@@ -287,12 +290,41 @@ def __init__(
if self.track_data:
datasets = self._datasets = with_local_datasets(self._datasets, num_local_datasets)
else:
self._dataset_len = self.dataset_len(self._datasets)
if local_data_ixs is not None:
dataset_len = self.dataset_len(self._datasets)
self._dataset_len = dataset_len if isinstance(dataset_len, int) else None
self._dataset_ixs: list[TensorType] | Mapping[Tag, list[TensorType]]

if local_data_ixs is None:
# assume that the initial datasets are global
if isinstance(dataset_len, int):
self._dataset_ixs = [
tf.range(dataset_len) for _ in range(num_local_datasets)
]
else:
self._dataset_ixs = {
t: [tf.range(l) for _ in range(num_local_datasets)]
for t, l in dataset_len.items()
}

elif isinstance(local_data_ixs, Mapping):
self._dataset_ixs = {t: list(ixs) for t, ixs in local_data_ixs.items()}
if local_data_len is not None:
raise ValueError(
"Cannot infer new data points for datasets with different "
"local data indices. Pass in full indices instead."
)

else:
self._dataset_ixs = list(local_data_ixs)

if local_data_len is not None:
# infer new dataset indices from change in dataset sizes
num_new_points = self._dataset_len - local_data_len
if isinstance(dataset_len, Mapping):
raise ValueError(
"Cannot infer new data points for datasets with different "
"lengths. Pass in full indices instead."
)
num_new_points = dataset_len - local_data_len
if num_new_points < 0 or (
num_local_datasets > 0 and num_new_points % num_local_datasets != 0
):
@@ -310,10 +342,6 @@ def __init__(
],
-1,
)
else:
self._dataset_ixs = [
tf.range(self._dataset_len) for _ in range(num_local_datasets)
]

datasets = with_local_datasets(
self._datasets, num_local_datasets, self._dataset_ixs
@@ -375,7 +403,7 @@ def dataset(self) -> Dataset:
raise ValueError(f"Expected a single dataset, found {len(datasets)}")

@property
def local_data_ixs(self) -> Optional[Sequence[TensorType]]:
def local_data_ixs(self) -> Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]]:
"""Indices to the local data. Only stored for LocalDatasetsAcquisitionRule rules
when `track_data` is `False`."""
if isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule) and not self.track_data:
@@ -433,8 +461,8 @@ def acquisition_state(self) -> StateType | None:
return self._acquisition_state

@classmethod
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
"""Helper method for inferring the global dataset size."""
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int | Mapping[Tag, int]:
"""Helper method for inferring the global dataset size(s)."""
dataset_lens = {
tag: int(tf.shape(dataset.query_points)[0])
for tag, dataset in datasets.items()
@@ -444,9 +472,7 @@ def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
if len(unique_lens) == 1:
return int(unique_lens[0])
else:
raise ValueError(
f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}"
)
return dataset_lens

@classmethod
def from_record(
@@ -465,7 +491,7 @@ def from_record(
| None
) = None,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
) -> AskTellOptimizerType:
"""Creates new :class:`~AskTellOptimizer` instance from provided optimization state.
@@ -634,14 +660,15 @@ def ask(self) -> TensorType:
def tell(
self,
new_data: Mapping[Tag, Dataset] | Dataset,
new_data_ixs: Optional[Sequence[TensorType]] = None,
new_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
) -> None:
"""Updates optimizer state with new data.
:param new_data: New observed data. If `track_data` is `False`, this refers to all
the data.
:param new_data_ixs: Indices to the new observed local data, if `track_data` is `False`.
If unspecified, inferred from the change in dataset sizes.
If unspecified, inferred from the change in dataset sizes (as long as all the
datasets have the same size).
:raise ValueError: If keys in ``new_data`` do not match those in already built dataset.
"""
if isinstance(new_data, Dataset):
@@ -670,10 +697,45 @@ def tell(
elif not isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule):
datasets = new_data
else:
num_local_datasets = len(self._dataset_ixs)
if new_data_ixs is None:
num_local_datasets = (
len(self._dataset_ixs)
if isinstance(self._dataset_ixs, Sequence)
else len(next(iter(self._dataset_ixs.values())))
)

if new_data_ixs is not None:
# use explicit indices
def update_ixs(ixs: list[TensorType], new_ixs: Sequence[TensorType]) -> None:
if len(ixs) != len(new_ixs):
raise ValueError(
f"new_data_ixs has {len(new_ixs)} entries, expected {len(ixs)}"
)
for i in range(len(ixs)):
ixs[i] = tf.concat([ixs[i], new_ixs[i]], -1)

if isinstance(new_data_ixs, Sequence) and isinstance(self._dataset_ixs, Mapping):
raise ValueError("separate new_data_ixs required for each dataset")
if isinstance(new_data_ixs, Mapping) and isinstance(self._dataset_ixs, Sequence):
self._dataset_ixs = {tag: list(self._dataset_ixs) for tag in self._datasets}
if isinstance(new_data_ixs, Mapping):
assert isinstance(self._dataset_ixs, Mapping)
for tag in self._datasets:
update_ixs(self._dataset_ixs[tag], new_data_ixs[tag])
else:
assert isinstance(self._dataset_ixs, list)
update_ixs(self._dataset_ixs, new_data_ixs)

else:
# infer dataset indices from change in dataset sizes
if isinstance(self._dataset_ixs, Mapping) or not isinstance(self._dataset_len, int):
raise NotImplementedError(
"new data indices cannot be inferred for datasets with different sizes"
)
new_dataset_len = self.dataset_len(new_data)
if not isinstance(new_dataset_len, int):
raise NotImplementedError(
"new data indices cannot be inferred for new data with different sizes"
)
num_new_points = new_dataset_len - self._dataset_len
if num_new_points < 0 or (
num_local_datasets > 0 and num_new_points % num_local_datasets != 0
Expand All @@ -690,17 +752,10 @@ def tell(
],
-1,
)
else:
# use explicit indices
if len(new_data_ixs) != num_local_datasets:
raise ValueError(
f"new_data_ixs has {len(new_data_ixs)} entries, "
f"expected {num_local_datasets}"
)
for i in range(num_local_datasets):
self._dataset_ixs[i] = tf.concat([self._dataset_ixs[i], new_data_ixs[i]], -1)

datasets = with_local_datasets(new_data, num_local_datasets, self._dataset_ixs)
self._dataset_len = self.dataset_len(datasets)
dataset_len = self.dataset_len(datasets)
self._dataset_len = dataset_len if isinstance(dataset_len, int) else None

filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets)
if callable(filtered_datasets):
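
When track_data is False and explicit new_data_ixs are passed to tell, the optimizer appends the new per-region indices to the ones it already tracks, per tag when a mapping is used. A standalone sketch of just that bookkeeping in plain TensorFlow (made-up tags and indices; not a full optimizer run):

import tensorflow as tf

# indices currently tracked for two local regions of each of two tags
dataset_ixs = {
    "OBJECTIVE": [tf.constant([0, 2]), tf.constant([1, 3])],
    "CONSTRAINT": [tf.constant([0]), tf.constant([1])],
}
# indices of the newly observed points for each tag and region
new_data_ixs = {
    "OBJECTIVE": [tf.constant([4]), tf.constant([5])],
    "CONSTRAINT": [tf.constant([2]), tf.constant([3])],
}
# extend each region's indices, mirroring the explicit-indices branch of tell()
for tag, new_ixs in new_data_ixs.items():
    dataset_ixs[tag] = [
        tf.concat([ixs, nixs], -1) for ixs, nixs in zip(dataset_ixs[tag], new_ixs)
    ]
# dataset_ixs["OBJECTIVE"] is now [[0, 2, 4], [1, 3, 5]]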
