Skip to content

Commit

Permalink
Fix tags assert in filter_datasets (#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani authored Jun 10, 2024
1 parent 744d901 commit 7a595dc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
50 changes: 42 additions & 8 deletions tests/unit/acquisition/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import copy
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence, Union, cast
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from unittest.mock import ANY, MagicMock

import gpflow
Expand Down Expand Up @@ -1612,7 +1612,8 @@ def test_multi_trust_region_box_raises_on_mismatched_global_search_space() -> No
mtb.acquire(Box([0.0, 0.0], [2.0, 2.0]), {})


def test_multi_trust_region_box_raises_on_mismatched_tags() -> None:
@pytest.mark.parametrize("acquire", [True, False])
def test_multi_trust_region_box_raises_on_mismatched_tags(acquire: bool) -> None:
search_space = Box([0.0, 0.0], [1.0, 1.0])
dataset = Dataset(
tf.constant([[0.0, 0.0], [1.0, 1.0]], dtype=tf.float64),
Expand All @@ -1626,17 +1627,50 @@ def test_multi_trust_region_box_raises_on_mismatched_tags() -> None:

state = BatchTrustRegionState[UpdatableTrustRegionBox](subspaces, ["a", "b"])
models = {OBJECTIVE: QuadraticMeanAndRBFKernelWithSamplers(dataset)}
state_func = mtb.acquire(
search_space,
models,
{OBJECTIVE: dataset},
)
mtb.filter_datasets(models, {OBJECTIVE: dataset})
if acquire:
state_func = mtb.acquire(
search_space,
models,
{OBJECTIVE: dataset},
)
else:
state_func = mtb.filter_datasets(models, {OBJECTIVE: dataset})

with pytest.raises(AssertionError, match="The tags of the state acquisition space"):
_, _ = state_func(state)


@pytest.mark.parametrize("acquire", [True, False])
@pytest.mark.parametrize("as_list", [True, False])
def test_multi_trust_region_box_state_supports_different_tags(acquire: bool, as_list: bool) -> None:
search_space = Box([0.0], [1.0])
datasets = {OBJECTIVE: mk_dataset([[0.0], [1.0]], [[0.0], [1.0]])}
model = QuadraticMeanAndRBFKernelWithSamplers(
dataset=datasets[OBJECTIVE], noise_variance=tf.constant(1.0, dtype=tf.float64)
)
model.kernel = (
gpflow.kernels.RBF()
) # need a gpflow kernel object for random feature decompositions
models = {OBJECTIVE: model}

subspaces = [SingleObjectiveTrustRegionBox(search_space) for _ in range(2)]
base_rule = EfficientGlobalOptimization( # type: ignore[var-annotated]
builder=ParallelContinuousThompsonSampling(), num_query_points=2
)
mtb = BatchTrustRegionBox(subspaces, base_rule)

tags: Union[List[str], Tuple[str, ...]] = ["0", "1"]
if not as_list:
tags = tuple(tags)

state = BatchTrustRegionState[UpdatableTrustRegionBox](subspaces, tags)
if acquire:
state_func = mtb.acquire(search_space, models, datasets)
else:
state_func = mtb.filter_datasets(models, datasets)
state_func(state) # Check that this does not raise an error.


class TestTrustRegionBox(SingleObjectiveTrustRegionBox):
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,7 @@ def state_func(
state: BatchTrustRegionState[UpdatableTrustRegionType] | None,
) -> Tuple[BatchTrustRegionState[UpdatableTrustRegionType] | None, Mapping[Tag, Dataset]]:
if state is not None:
assert self._tags == state.subspace_tags, (
assert self._tags == tuple(state.subspace_tags), (
f"The tags of the state acquisition space "
f"{state.subspace_tags} should be the same as the tags of the "
f"BatchTrustRegion acquisition rule {self._tags}"
Expand Down

0 comments on commit 7a595dc

Please sign in to comment.