Skip to content

Commit

Permalink
Categorical search space tweaks (#869)
Browse files (browse the repository at this point in the history)
  • Loading branch information
uri-granta authored Sep 10, 2024
1 parent 9e5cdb1 commit aa06d67
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 4 deletions.
32 changes: 30 additions & 2 deletions tests/unit/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SearchSpace,
TaggedMultiSearchSpace,
TaggedProductSearchSpace,
cast_encoder,
one_hot_encoder,
)
from trieste.types import TensorType
Expand Down Expand Up @@ -1759,6 +1760,11 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None:
tf.constant([[0], [0]], dtype=tf.float64),
tf.constant([[1], [1]], dtype=tf.float64),
),
(
CategoricalSearchSpace(["Y", "N"]),
tf.constant([[0], [1], [0]], dtype=tf.float64),
tf.constant([[0], [1], [0]], dtype=tf.float64),
),
(
CategoricalSearchSpace(["R", "G", "B"], dtype=tf.float32),
tf.constant([[0], [2], [1]], dtype=tf.float32),
Expand All @@ -1777,13 +1783,13 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None:
(
CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
tf.constant([[0, 0], [2, 0], [1, 1]], dtype=tf.float64),
tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float64),
tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]], dtype=tf.float64),
),
(
CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]),
tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]], dtype=tf.float64),
tf.constant(
[[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]],
[[[1, 0, 0, 0], [1, 0, 0, 0]], [[0, 0, 1, 0], [0, 1, 0, 1]]],
dtype=tf.float64,
),
),
Expand Down Expand Up @@ -1824,6 +1830,12 @@ def test_categorical_search_space_one_hot_encoding(
pytest.param(
CategoricalSearchSpace(["Y", "N"]),
tf.constant([[0], [2], [1]]),
ValueError,
id="Out of range binary input value",
),
pytest.param(
CategoricalSearchSpace(["Y", "N", "maybe"]),
tf.constant([[0], [3], [1]]),
InvalidArgumentError,
id="Out of range input value",
),
Expand Down Expand Up @@ -1859,3 +1871,19 @@ def test_unbound_search_spaces(
space.lower
with pytest.raises(AttributeError):
space.upper


@pytest.mark.parametrize("input_dtype", [None, tf.float64, tf.float32])
@pytest.mark.parametrize("output_dtype", [None, tf.float64, tf.float32])
def test_cast_encoder(input_dtype: Optional[tf.DType], output_dtype: Optional[tf.DType]) -> None:
    """
    Check that ``cast_encoder`` casts the encoder's input and output as requested.

    The wrapped encoder must receive points in ``input_dtype`` (or the original int32
    when no input dtype is given), and the final result must be in ``output_dtype``
    (or, when that is not given, whatever dtype the encoder received and returned).
    """
    query_points = tf.constant([1, 2, 3], dtype=tf.int32)

    # dtypes are compared with == rather than `is`, so the assertions do not depend
    # on DType objects being interned singletons
    expected_input_dtype = input_dtype or tf.int32
    expected_output_dtype = output_dtype or expected_input_dtype

    def add_encoder(x: TensorType) -> TensorType:
        # the encoder must see the points already cast to the requested input dtype
        assert x.dtype == expected_input_dtype
        return x + 1

    encoder = cast_encoder(add_encoder, input_dtype=input_dtype, output_dtype=output_dtype)
    points = encoder(query_points)
    assert points.dtype == expected_output_dtype
    npt.assert_array_equal(tf.cast(query_points + 1, points.dtype), points)
33 changes: 31 additions & 2 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,24 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction:
return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x


def cast_encoder(
    encoder: EncoderFunction,
    input_dtype: Optional[tf.DType] = None,
    output_dtype: Optional[tf.DType] = None,
) -> EncoderFunction:
    """
    A utility function for casting the input and/or output of an encoder.

    :param encoder: The encoder function to wrap.
    :param input_dtype: If not None, the points are cast to this dtype before being
        passed to ``encoder``.
    :param output_dtype: If not None, the result of ``encoder`` is cast to this dtype
        before being returned.
    :return: A new encoder function that applies the requested casts around ``encoder``.
    """

    def encode_with_casts(x: TensorType) -> TensorType:
        # a dtype of None means "leave as is" on that side of the encoder
        to_encode = x if input_dtype is None else tf.cast(x, input_dtype)
        encoded = encoder(to_encode)
        return encoded if output_dtype is None else tf.cast(encoded, output_dtype)

    return encode_with_casts


def one_hot_encoded_space(space: SearchSpace) -> SearchSpace:
"A bounded search space corresponding to the one-hot encoding of the given space."

Expand Down Expand Up @@ -633,7 +651,14 @@ def tags(self) -> Sequence[Sequence[str]]:

@property
def one_hot_encoder(self) -> EncoderFunction:
"""A one-hot encoder for the numerical indices."""
"""A one-hot encoder for the numerical indices. Note that binary categories
are left unchanged instead of adding an unnecessary second feature."""

def binary_encoder(x: TensorType) -> TensorType:
    """
    Pass binary category indices through unchanged, after validating them.

    :param x: The category indices to validate.
    :return: ``x`` itself, unmodified.
    :raise ValueError: If any element of ``x`` is neither 0 nor 1.
    """
    # no need to one-hot encode binary categories (but we should still validate)
    invalid = (x != 0) & (x != 1)
    if tf.reduce_any(invalid):
        raise ValueError(f"Invalid values {tf.boolean_mask(x, invalid)}")
    return x

def encoder(x: TensorType) -> TensorType:
flat_x, unflatten = flatten_leading_dims(x)
Expand All @@ -644,7 +669,11 @@ def encoder(x: TensorType) -> TensorType:
)
columns = tf.split(flat_x, flat_x.shape[-1], axis=1)
encoders = [
tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot")
(
binary_encoder
if len(ts) == 2
else tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot")
)
for ts in self.tags
]
encoded = tf.concat(
Expand Down

0 comments on commit aa06d67

Please sign in to comment.