Skip to content

Commit

Permalink
Temporarily avoid isinstance(..., ProbabilisticModel)
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Apr 8, 2024
1 parent c6a039a commit 488e268
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tests/unit/models/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _model_stack() -> (

def test_model_stack_predict() -> None:
stack, (model01, model2, model3) = _model_stack()
assert all(
isinstance(model, TrainableProbabilisticModel) for model in (stack, model01, model2, model3)
)
query_points = tf.random.uniform([5, 7, 3])
mean, var = stack.predict(query_points)

Expand Down
2 changes: 1 addition & 1 deletion trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ def with_input_active_dims(

# No selection for models.
# Nothing to do if active dimensions are not set.
if isinstance(value, ProbabilisticModel) or self.input_active_dims is None:
if not isinstance(value, (Dataset, tf.Tensor)) or self.input_active_dims is None:
return value

# Select components of query points for datasets.
Expand Down

0 comments on commit 488e268

Please sign in to comment.