Skip to content

Commit

Permalink
Robustify prune_inferior_points tests against sorting order (#2548)
Browse files Browse the repository at this point in the history
Summary:
Our nightly CI started failing, likely due to a sorting order change introduced in pytorch/pytorch#127936

This change robustifies the tests against the point order (and also fixes a torch deprecation warning).

NOTE: Even though pytorch/pytorch#127936 was unlanded, getting these changes in will help robustify the tests going forward.

NOTE: As this makes `torch.sort` use `stable=True`, this will come at a slight performance hit. However, the tensor sizes typically involved in `prune_inferior_points` are quite small (order of a few hundred items maybe), so this should be negligible.

Pull Request resolved: #2548

Test Plan: unit tests

Reviewed By: sdaulton, saitcakmak

Differential Revision: D63260870

Pulled By: Balandat
  • Loading branch information
Balandat authored and facebook-github-bot committed Sep 23, 2024
1 parent 9fc39fa commit f1ffb7e
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 29 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/multi_objective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def prune_inferior_points_multi_objective(
probs = pareto_mask.to(dtype=X.dtype).mean(dim=0)
idcs = probs.nonzero().view(-1)
if idcs.shape[0] > max_points:
counts, order_idcs = torch.sort(probs, descending=True)
counts, order_idcs = torch.sort(probs, stable=True, descending=True)
idcs = order_idcs[:max_points]
effective_n_w = obj_vals.shape[-2] // X.shape[-2]
idcs = (idcs / effective_n_w).long().unique()
Expand Down
7 changes: 4 additions & 3 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,16 @@ def prune_inferior_points(
marginalize_dim=marginalize_dim,
)
if infeas.any():
# set infeasible points to worse than worst objective
# across all samples
# set infeasible points to worse than worst objective across all samples
# Use clone() here to avoid deprecated `index_put_` on an expanded tensor
obj_vals = obj_vals.clone()
obj_vals[infeas] = obj_vals.min() - 1

is_best = torch.argmax(obj_vals, dim=-1)
idcs, counts = torch.unique(is_best, return_counts=True)

if len(idcs) > max_points:
counts, order_idcs = torch.sort(counts, descending=True)
counts, order_idcs = torch.sort(counts, stable=True, descending=True)
idcs = order_idcs[:max_points]

return X[idcs]
Expand Down
23 changes: 8 additions & 15 deletions test/acquisition/multi_objective/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_get_default_partitioning_alpha(self):


class DummyMCMultiOutputObjective(MCMultiOutputObjective):
def forward(self, samples: Tensor) -> Tensor:
def forward(self, samples: Tensor, X: Tensor | None) -> Tensor:
return samples


Expand Down Expand Up @@ -130,13 +130,12 @@ def test_prune_inferior_points_multi_objective(self):
X_pruned = prune_inferior_points_multi_objective(
model=mm, X=X, ref_point=ref_point, max_frac=2 / 3
)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(
torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X[:2], stable=True).values,
)
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand Down Expand Up @@ -276,10 +275,7 @@ def test_random_search_optimizer(self):
input_dim = 3
num_initial = 5
tkwargs = {"device": self.device}
optimizer_kwargs = {
"pop_size": 1000,
"max_tries": 5,
}
optimizer_kwargs = {"pop_size": 1000, "max_tries": 5}

for (
dtype,
Expand Down Expand Up @@ -350,10 +346,7 @@ def test_sample_optimal_points(self):
input_dim = 3
num_initial = 5
tkwargs = {"device": self.device}
optimizer_kwargs = {
"pop_size": 100,
"max_tries": 1,
}
optimizer_kwargs = {"pop_size": 100, "max_tries": 1}
num_samples = 2
num_points = 1

Expand Down
17 changes: 7 additions & 10 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,12 @@ def test_prune_inferior_points(self):
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X[:2], stable=True).values,
)
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand All @@ -289,11 +290,7 @@ def test_prune_inferior_points(self):
device=self.device,
dtype=dtype,
)
mm = MockModel(
MockPosterior(
samples=samples,
)
)
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(
model=mm,
X=X,
Expand Down

0 comments on commit f1ffb7e

Please sign in to comment.