From aa06d679360a6431fcd14aca6c911c9b4be8e182 Mon Sep 17 00:00:00 2001 From: uri-granta <50578464+uri-granta@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:56:57 +0100 Subject: [PATCH] Categorical search space tweaks (#869) --- tests/unit/test_space.py | 32 ++++++++++++++++++++++++++++++-- trieste/space.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 6fb473d13..7cbc34248 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -38,6 +38,7 @@ SearchSpace, TaggedMultiSearchSpace, TaggedProductSearchSpace, + cast_encoder, one_hot_encoder, ) from trieste.types import TensorType @@ -1759,6 +1760,11 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: tf.constant([[0], [0]], dtype=tf.float64), tf.constant([[1], [1]], dtype=tf.float64), ), + ( + CategoricalSearchSpace(["Y", "N"]), + tf.constant([[0], [1], [0]], dtype=tf.float64), + tf.constant([[0], [1], [0]], dtype=tf.float64), + ), ( CategoricalSearchSpace(["R", "G", "B"], dtype=tf.float32), tf.constant([[0], [2], [1]], dtype=tf.float32), @@ -1777,13 +1783,13 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), tf.constant([[0, 0], [2, 0], [1, 1]], dtype=tf.float64), - tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float64), + tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]], dtype=tf.float64), ), ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]], dtype=tf.float64), tf.constant( - [[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]], + [[[1, 0, 0, 0], [1, 0, 0, 0]], [[0, 0, 1, 0], [0, 1, 0, 1]]], dtype=tf.float64, ), ), @@ -1824,6 +1830,12 @@ def test_categorical_search_space_one_hot_encoding( pytest.param( CategoricalSearchSpace(["Y", "N"]), tf.constant([[0], [2], [1]]), + ValueError, + id="Out of range binary input value", + ), + pytest.param( + CategoricalSearchSpace(["Y", "N", "maybe"]), + tf.constant([[0], [3], [1]]), InvalidArgumentError, id="Out of range input value", ), @@ -1859,3 +1871,19 @@ def test_unbound_search_spaces( space.lower with pytest.raises(AttributeError): space.upper + + +@pytest.mark.parametrize("input_dtype", [None, tf.float64, tf.float32]) +@pytest.mark.parametrize("output_dtype", [None, tf.float64, tf.float32]) +def test_cast_encoder(input_dtype: Optional[tf.DType], output_dtype: Optional[tf.DType]) -> None: + + query_points = tf.constant([1, 2, 3], dtype=tf.int32) + + def add_encoder(x: TensorType) -> TensorType: + assert x.dtype is (input_dtype or tf.int32) + return x + 1 + + encoder = cast_encoder(add_encoder, input_dtype=input_dtype, output_dtype=output_dtype) + points = encoder(query_points) + assert points.dtype is (output_dtype or input_dtype or tf.int32) + npt.assert_array_equal(tf.cast(query_points + 1, points.dtype), points) diff --git a/trieste/space.py b/trieste/space.py index 213498a85..4b32181cc 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -518,6 +518,24 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction: return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x +def cast_encoder( + encoder: EncoderFunction, + input_dtype: Optional[tf.DType] = None, + output_dtype: Optional[tf.DType] = None, +) -> EncoderFunction: + "A utility function for casting the input and/or output of an encoder." + + def cast_and_encode(x: TensorType) -> TensorType: + if input_dtype is not None: + x = tf.cast(x, input_dtype) + y = encoder(x) + if output_dtype is not None: + y = tf.cast(y, output_dtype) + return y + + return cast_and_encode + + def one_hot_encoded_space(space: SearchSpace) -> SearchSpace: "A bounded search space corresponding to the one-hot encoding of the given space." @@ -633,7 +651,14 @@ def tags(self) -> Sequence[Sequence[str]]: @property def one_hot_encoder(self) -> EncoderFunction: - """A one-hot encoder for the numerical indices.""" + """A one-hot encoder for the numerical indices. Note that binary categories + are left unchanged instead of adding an unnecessary second feature.""" + + def binary_encoder(x: TensorType) -> TensorType: + # no need to one-hot encode binary categories (but we should still validate) + if tf.reduce_any((x != 0) & (x != 1)): + raise ValueError(f"Invalid values {tf.boolean_mask(x, ((x != 0) & (x != 1)))}") + return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) @@ -644,7 +669,11 @@ def encoder(x: TensorType) -> TensorType: ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ - tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + ( + binary_encoder + if len(ts) == 2 + else tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + ) for ts in self.tags ] encoded = tf.concat(