From 8f2edc8835e5252531d98283c17411cef000280f Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Fri, 11 Oct 2024 15:18:24 +0100 Subject: [PATCH] Clean up tf.debugging.asserts --- .../test_continuous_thompson_sampling.py | 4 +- trieste/acquisition/function/entropy.py | 12 ++--- trieste/acquisition/function/function.py | 52 ++++++++----------- trieste/acquisition/function/greedy_batch.py | 8 +-- .../acquisition/function/multi_objective.py | 26 ++++------ trieste/acquisition/optimizer.py | 2 +- trieste/acquisition/rule.py | 2 +- .../models/gpflow/inducing_point_selectors.py | 6 +-- trieste/models/gpflow/sampler.py | 8 +-- trieste/models/gpflux/sampler.py | 8 +-- trieste/models/keras/sampler.py | 4 +- trieste/space.py | 12 ++++- 12 files changed, 64 insertions(+), 80 deletions(-) diff --git a/tests/unit/acquisition/function/test_continuous_thompson_sampling.py b/tests/unit/acquisition/function/test_continuous_thompson_sampling.py index 8110e33103..53cbeb7db0 100644 --- a/tests/unit/acquisition/function/test_continuous_thompson_sampling.py +++ b/tests/unit/acquisition/function/test_continuous_thompson_sampling.py @@ -42,9 +42,7 @@ class DumbTrajectorySampler(RandomFourierFeatureTrajectorySampler): """A RandomFourierFeatureTrajectorySampler that doesn't update trajectories in place.""" def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunction: - tf.debugging.Assert( - isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(trajectory, feature_decomposition_trajectory), []) return self.get_trajectory() diff --git a/trieste/acquisition/function/entropy.py b/trieste/acquisition/function/entropy.py index d53cb6a61b..e291e8d461 100644 --- a/trieste/acquisition/function/entropy.py +++ b/trieste/acquisition/function/entropy.py @@ -128,7 +128,7 @@ def prepare_acquisition_function( :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one. :raise tf.errors.InvalidArgumentError: If ``dataset`` is empty. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -150,10 +150,10 @@ def update_acquisition_function( :param model: The model. :param dataset: The data from the observer. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(isinstance(function, min_value_entropy_search), [tf.constant([])]) + tf.debugging.Assert(isinstance(function, min_value_entropy_search), []) query_points = self._search_space.sample(num_samples=self._grid_size) tf.debugging.assert_same_float_dtype([dataset.query_points, query_points]) @@ -334,7 +334,7 @@ def prepare_acquisition_function( f"covariance_between_points and get_observation_noise; received {model!r}" ) - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -363,10 +363,10 @@ def update_acquisition_function( for the current step. Defaults to ``True``. :return: The updated acquisition function. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(self._quality_term is not None, [tf.constant([])]) + tf.debugging.Assert(self._quality_term is not None, []) if new_optimization_step: self._update_quality_term(dataset, model) diff --git a/trieste/acquisition/function/function.py b/trieste/acquisition/function/function.py index e917b84e09..81d801f5cd 100644 --- a/trieste/acquisition/function/function.py +++ b/trieste/acquisition/function/function.py @@ -86,7 +86,7 @@ def update_acquisition_function( tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(isinstance(function, probability_below_threshold), [tf.constant([])]) + tf.debugging.Assert(isinstance(function, probability_below_threshold), []) mean, _ = model.predict(dataset.query_points) eta = tf.reduce_min(mean, axis=0)[0] function.update(eta) # type: ignore @@ -127,7 +127,7 @@ def prepare_acquisition_function( greater than one. :raise tf.errors.InvalidArgumentError: If ``dataset`` is empty. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -161,10 +161,10 @@ def update_acquisition_function( :param model: The model. :param dataset: The data from the observer. Must be populated. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(isinstance(function, expected_improvement), [tf.constant([])]) + tf.debugging.Assert(isinstance(function, expected_improvement), []) # Check feasibility against any explicit constraints in the search space. if self._search_space is not None and self._search_space.has_constraints: @@ -251,7 +251,7 @@ def prepare_acquisition_function( f"AugmentedExpectedImprovement only works with models that support " f"get_observation_noise; received {model!r}" ) - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") mean, _ = model.predict(dataset.query_points) @@ -269,10 +269,10 @@ def update_acquisition_function( :param model: The model. :param dataset: The data from the observer. Must be populated. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(isinstance(function, augmented_expected_improvement), [tf.constant([])]) + tf.debugging.Assert(isinstance(function, augmented_expected_improvement), []) mean, _ = model.predict(dataset.query_points) eta = tf.reduce_min(mean, axis=0) function.update(eta) # type: ignore @@ -669,7 +669,7 @@ def prepare_acquisition_function( :raise KeyError: If `objective_tag` is not found in ``datasets`` and ``models``. :raise tf.errors.InvalidArgumentError: If the objective data is empty. """ - tf.debugging.Assert(datasets is not None, [tf.constant([])]) + tf.debugging.Assert(datasets is not None, []) datasets = cast(Mapping[Tag, Dataset], datasets) objective_model = models[self._objective_tag] @@ -719,7 +719,7 @@ def update_acquisition_function( :param models: The models for each tag. :param datasets: The data from the observer. """ - tf.debugging.Assert(datasets is not None, [tf.constant([])]) + tf.debugging.Assert(datasets is not None, []) datasets = cast(Mapping[Tag, Dataset], datasets) objective_model = models[self._objective_tag] @@ -730,7 +730,7 @@ def update_acquisition_function( message="Expected improvement is defined with respect to existing points in the" " objective data, but the objective data is empty.", ) - tf.debugging.Assert(self._constraint_fn is not None, [tf.constant([])]) + tf.debugging.Assert(self._constraint_fn is not None, []) constraint_fn = cast(AcquisitionFunction, self._constraint_fn) self._constraint_builder.update_acquisition_function( @@ -777,9 +777,7 @@ def _update_expected_improvement_fn( if self._expected_improvement_fn is None: self._expected_improvement_fn = expected_improvement(objective_model, eta) else: - tf.debugging.Assert( - isinstance(self._expected_improvement_fn, expected_improvement), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(self._expected_improvement_fn, expected_improvement), []) self._expected_improvement_fn.update(eta) # type: ignore @@ -830,7 +828,7 @@ def prepare_acquisition_function( sampler = model.reparam_sampler(self._sample_size) - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -858,12 +856,10 @@ def update_acquisition_function( :param model: The model. Must have output dimension [1]. Unused here. :param dataset: The data from the observer. Cannot be empty """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert( - isinstance(function, monte_carlo_expected_improvement), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(function, monte_carlo_expected_improvement), []) sampler = function._sampler # type: ignore sampler.reset_sampler() samples_at_query_points = sampler.sample( @@ -974,7 +970,7 @@ def prepare_acquisition_function( sampler = model.reparam_sampler(self._sample_size) - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -1002,12 +998,10 @@ def update_acquisition_function( :param model: The model. Must have output dimension [1]. Unused here :param dataset: The data from the observer. Cannot be empty. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert( - isinstance(function, monte_carlo_augmented_expected_improvement), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(function, monte_carlo_augmented_expected_improvement), []) sampler = function._sampler # type: ignore sampler.reset_sampler() samples_at_query_points = sampler.sample( @@ -1111,7 +1105,7 @@ def prepare_acquisition_function( :raise ValueError (or InvalidArgumentError): If ``dataset`` is not populated, or ``model`` does not have an event shape of [1]. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -1135,12 +1129,10 @@ def update_acquisition_function( :param model: The model. Must have event shape [1]. :param dataset: The data from the observer. Must be populated. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert( - isinstance(function, batch_monte_carlo_expected_improvement), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(function, batch_monte_carlo_expected_improvement), []) mean, _ = model.predict(dataset.query_points) eta = tf.reduce_min(mean, axis=0) function.update(eta) # type: ignore @@ -1848,9 +1840,7 @@ def update_acquisition_function( :param model: The model. :param dataset: Unused. """ - tf.debugging.Assert( - isinstance(function, multiple_optimism_lower_confidence_bound), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(function, multiple_optimism_lower_confidence_bound), []) return function # nothing to update diff --git a/trieste/acquisition/function/greedy_batch.py b/trieste/acquisition/function/greedy_batch.py index 3de6050c85..2134f7587a 100644 --- a/trieste/acquisition/function/greedy_batch.py +++ b/trieste/acquisition/function/greedy_batch.py @@ -135,7 +135,7 @@ def prepare_acquisition_function( :return: The (log) expected improvement penalized with respect to the pending points. :raise tf.errors.InvalidArgumentError: If the ``dataset`` is empty. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") @@ -164,10 +164,10 @@ def update_acquisition_function( for the current step. Defaults to ``True``. :return: The updated acquisition function. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(self._base_acquisition_function is not None, [tf.constant([])]) + tf.debugging.Assert(self._base_acquisition_function is not None, []) if new_optimization_step: self._update_base_acquisition_function(dataset, model) @@ -447,7 +447,7 @@ def __init__( See class docs for more details. :raise tf.errors.InvalidArgumentError: If ``fantasize_method`` is not "KB" or "sample". """ - tf.debugging.Assert(fantasize_method in ["KB", "sample"], [tf.constant([])]) + tf.debugging.Assert(fantasize_method in ["KB", "sample"], []) if base_acquisition_function_builder is None: base_acquisition_function_builder = ExpectedImprovement() diff --git a/trieste/acquisition/function/multi_objective.py b/trieste/acquisition/function/multi_objective.py index d72c00eae1..fd5d106ebd 100644 --- a/trieste/acquisition/function/multi_objective.py +++ b/trieste/acquisition/function/multi_objective.py @@ -91,7 +91,7 @@ def prepare_acquisition_function( :param dataset: The data from the observer. Must be populated. :return: The expected hypervolume improvement acquisition function. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") mean, _ = model.predict(dataset.query_points) @@ -121,10 +121,10 @@ def update_acquisition_function( :param model: The model. :param dataset: The data from the observer. Must be populated. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") - tf.debugging.Assert(isinstance(function, expected_hv_improvement), [tf.constant([])]) + tf.debugging.Assert(isinstance(function, expected_hv_improvement), []) mean, _ = model.predict(dataset.query_points) if callable(self._ref_point_spec): @@ -319,7 +319,7 @@ def prepare_acquisition_function( :param dataset: The data from the observer. Must be populated. :return: The batch expected hypervolume improvement acquisition function. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) dataset = cast(Dataset, dataset) tf.debugging.assert_positive(len(dataset), message="Dataset must be populated.") mean, _ = model.predict(dataset.query_points) @@ -564,9 +564,9 @@ def prepare_acquisition_function( :return: The HIPPO acquisition function. :raise tf.errors.InvalidArgumentError: If the ``dataset`` is empty. """ - tf.debugging.Assert(datasets is not None, [tf.constant([])]) + tf.debugging.Assert(datasets is not None, []) datasets = cast(Mapping[Tag, Dataset], datasets) - tf.debugging.Assert(datasets[self._objective_tag] is not None, [tf.constant([])]) + tf.debugging.Assert(datasets[self._objective_tag] is not None, []) tf.debugging.assert_positive( len(datasets[self._objective_tag]), message=f"{self._objective_tag} dataset must be populated.", @@ -599,14 +599,14 @@ def update_acquisition_function( for the current step. Defaults to ``True``. :return: The updated acquisition function. """ - tf.debugging.Assert(datasets is not None, [tf.constant([])]) + tf.debugging.Assert(datasets is not None, []) datasets = cast(Mapping[Tag, Dataset], datasets) - tf.debugging.Assert(datasets[self._objective_tag] is not None, [tf.constant([])]) + tf.debugging.Assert(datasets[self._objective_tag] is not None, []) tf.debugging.assert_positive( len(datasets[self._objective_tag]), message=f"{self._objective_tag} dataset must be populated.", ) - tf.debugging.Assert(self._base_acquisition_function is not None, [tf.constant([])]) + tf.debugging.Assert(self._base_acquisition_function is not None, []) if new_optimization_step: self._update_base_acquisition_function(models, datasets) @@ -689,9 +689,7 @@ def __init__(self, model: ProbabilisticModel, pending_points: TensorType): :return: The penalization function. This function will raise :exc:`ValueError` or :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater than one.""" - tf.debugging.Assert( - pending_points is not None and len(pending_points) != 0, [tf.constant([])] - ) + tf.debugging.Assert(pending_points is not None and len(pending_points) != 0, []) self._model = model self._pending_points = tf.Variable(pending_points, shape=[None, *pending_points.shape[1:]]) @@ -701,9 +699,7 @@ def __init__(self, model: ProbabilisticModel, pending_points: TensorType): def update(self, pending_points: TensorType) -> None: """Update the penalizer with new pending points.""" - tf.debugging.Assert( - pending_points is not None and len(pending_points) != 0, [tf.constant([])] - ) + tf.debugging.Assert(pending_points is not None and len(pending_points) != 0, []) self._pending_points.assign(pending_points) pending_means, pending_vars = self._model.predict(self._pending_points) diff --git a/trieste/acquisition/optimizer.py b/trieste/acquisition/optimizer.py index fdd8c01fa8..d9cc67682c 100644 --- a/trieste/acquisition/optimizer.py +++ b/trieste/acquisition/optimizer.py @@ -760,7 +760,7 @@ def get_bounds_of_box_relaxation_around_point( :param current_point: The point at which to make the continuous relaxation. :return: Bounds for the Scipy optimizer. """ - tf.debugging.Assert(isinstance(space, TaggedProductSearchSpace), [tf.constant([])]) + tf.debugging.Assert(isinstance(space, TaggedProductSearchSpace), []) space_with_fixed_discrete = space for tag in space.subspace_tags: diff --git a/trieste/acquisition/rule.py b/trieste/acquisition/rule.py index d4ff018074..515957666d 100644 --- a/trieste/acquisition/rule.py +++ b/trieste/acquisition/rule.py @@ -636,7 +636,7 @@ def acquire( def state_func( state: AsynchronousRuleState | None, ) -> tuple[AsynchronousRuleState | None, TensorType]: - tf.debugging.Assert(self._acquisition_function is not None, [tf.constant([])]) + tf.debugging.Assert(self._acquisition_function is not None, []) if state is None: state = AsynchronousRuleState(None) diff --git a/trieste/models/gpflow/inducing_point_selectors.py b/trieste/models/gpflow/inducing_point_selectors.py index b5acb70137..f0a123234a 100644 --- a/trieste/models/gpflow/inducing_point_selectors.py +++ b/trieste/models/gpflow/inducing_point_selectors.py @@ -74,7 +74,7 @@ def calculate_inducing_points( :return: The new updated inducing points. :raise NotImplementedError: If model has more than one set of inducing variables. """ - tf.debugging.Assert(current_inducing_points is not None, [tf.constant([])]) + tf.debugging.Assert(current_inducing_points is not None, []) if isinstance(current_inducing_points, list): raise NotImplementedError( @@ -165,7 +165,7 @@ def _recalculate_inducing_points( :raise tf.errors.InvalidArgumentError: If ``dataset`` is empty. """ - tf.debugging.Assert(len(dataset.query_points) is not None, [tf.constant([])]) + tf.debugging.Assert(len(dataset.query_points) is not None, []) N = tf.shape(dataset.query_points)[0] # training data size shuffled_query_points = tf.random.shuffle(dataset.query_points) # [N, d] @@ -208,7 +208,7 @@ def _recalculate_inducing_points( :raise tf.errors.InvalidArgumentError: If ``dataset`` is empty. """ - tf.debugging.Assert(dataset is not None, [tf.constant([])]) + tf.debugging.Assert(dataset is not None, []) query_points = dataset.query_points # [N, d] N = tf.shape(query_points)[0] diff --git a/trieste/models/gpflow/sampler.py b/trieste/models/gpflow/sampler.py index d1aa8f1ce7..fc0324edd5 100644 --- a/trieste/models/gpflow/sampler.py +++ b/trieste/models/gpflow/sampler.py @@ -414,9 +414,7 @@ def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunctio :param trajectory: The trajectory function to be resampled. :return: The new resampled trajectory function. """ - tf.debugging.Assert( - isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(trajectory, feature_decomposition_trajectory), []) self._feature_functions.resample() # resample Fourier feature decomposition weight_sampler = self._prepare_weight_sampler() # recalculate weight distribution @@ -433,9 +431,7 @@ def resample_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunct :param trajectory: The trajectory function to be resampled. :return: The new resampled trajectory function. """ - tf.debugging.Assert( - isinstance(trajectory, feature_decomposition_trajectory), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(trajectory, feature_decomposition_trajectory), []) cast(feature_decomposition_trajectory, trajectory).resample() return trajectory # return trajectory with resampled weights diff --git a/trieste/models/gpflux/sampler.py b/trieste/models/gpflux/sampler.py index 4aef9937a5..fae6cc9d3a 100644 --- a/trieste/models/gpflux/sampler.py +++ b/trieste/models/gpflux/sampler.py @@ -187,9 +187,7 @@ def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunctio :class:`dgp_feature_decomposition_trajectory` """ - tf.debugging.Assert( - isinstance(trajectory, dgp_feature_decomposition_trajectory), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(trajectory, dgp_feature_decomposition_trajectory), []) cast(dgp_feature_decomposition_trajectory, trajectory).update() return trajectory @@ -204,9 +202,7 @@ def resample_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunct :raise InvalidArgumentError: If ``trajectory`` is not a :class:`dgp_feature_decomposition_trajectory` """ - tf.debugging.Assert( - isinstance(trajectory, dgp_feature_decomposition_trajectory), [tf.constant([])] - ) + tf.debugging.Assert(isinstance(trajectory, dgp_feature_decomposition_trajectory), []) cast(dgp_feature_decomposition_trajectory, trajectory).resample() return trajectory diff --git a/trieste/models/keras/sampler.py b/trieste/models/keras/sampler.py index f43f8fdba8..24ddc01259 100644 --- a/trieste/models/keras/sampler.py +++ b/trieste/models/keras/sampler.py @@ -89,7 +89,7 @@ def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunctio :param trajectory: The trajectory function to be resampled. :return: The new trajectory function updated for a new model """ - tf.debugging.Assert(isinstance(trajectory, deep_ensemble_trajectory), [tf.constant([])]) + tf.debugging.Assert(isinstance(trajectory, deep_ensemble_trajectory), []) trajectory.resample() # type: ignore return trajectory @@ -101,7 +101,7 @@ def resample_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunct :param trajectory: The trajectory function to be resampled. :return: The new resampled trajectory function. """ - tf.debugging.Assert(isinstance(trajectory, deep_ensemble_trajectory), [tf.constant([])]) + tf.debugging.Assert(isinstance(trajectory, deep_ensemble_trajectory), []) trajectory.resample() # type: ignore return trajectory diff --git a/trieste/space.py b/trieste/space.py index 0de9214b55..04b95fad14 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -656,12 +656,20 @@ def one_hot_encoder(self) -> EncoderFunction: def binary_encoder(x: TensorType) -> TensorType: # no need to one-hot encode binary categories (but we should still validate) - tf.debugging.Assert(tf.reduce_all((x == 0) | (x == 1)), [tf.constant([])]) + tf.debugging.Assert( + tf.reduce_all((x == 0) | (x == 1)), + ["Invalid binary values for one-hot encoding:", x], + ) return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) - tf.debugging.assert_equal(flat_x.shape[-1], len(self.tags)) + tf.debugging.assert_equal( + flat_x.shape[-1], + len(self.tags), + message="Invalid input for one-hot encoding: " + f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}", + ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ (