Skip to content

Commit

Permalink
Added posterior_transform to posterior method in ApproximateGPyTorchM…
Browse files Browse the repository at this point in the history
…odel (pytorch#2531)

Summary:
## Motivation

This PR fixes pytorch#2530. Adds a new posterior_transform parameter to the posterior method of ApproximateGPyTorchModel.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#2531

Test Plan:
I was able to generate candidate points with ExpectedImprovement acquisition function with a SingleTaskVariationalGP that was trained on 2 output columns.

## Related PRs

NA

Reviewed By: mgarrard

Differential Revision: D62652630

Pulled By: Balandat

fbshipit-source-id: 6870c8a1f47454e70951e7e7eb420cd05a2fb246
  • Loading branch information
SaiAakash authored and facebook-github-bot committed Sep 16, 2024
1 parent 6ebfa82 commit e9ce11f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
16 changes: 13 additions & 3 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@

import copy
import warnings

from typing import Optional, Union

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
Expand Down Expand Up @@ -146,8 +146,16 @@ def train(self, mode: bool = True) -> Self:
return Module.train(self, mode=mode)

def posterior(
self, X, output_indices=None, observation_noise=False, *args, **kwargs
self,
X,
output_indices: Optional[list[int]] = None,
observation_noise: bool = False,
posterior_transform: Optional[PosteriorTransform] = None,
) -> GPyTorchPosterior:
if output_indices is not None:
raise NotImplementedError( # pragma: no cover
f"{self.__class__.__name__}.posterior does not support output indices."
)
self.eval() # make sure model is in eval mode

# input transforms are applied at `posterior` in `eval` mode, and at
Expand All @@ -161,11 +169,13 @@ def posterior(
X = X.unsqueeze(-3).repeat(*[1] * (X_ndim - 2), self.num_outputs, 1, 1)
dist = self.model(X)
if observation_noise:
dist = self.likelihood(dist, *args, **kwargs)
dist = self.likelihood(dist)

posterior = GPyTorchPosterior(distribution=dist)
if hasattr(self, "outcome_transform"):
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
posterior = posterior_transform(posterior)
return posterior

def forward(self, X) -> MultivariateNormal:
Expand Down
11 changes: 11 additions & 0 deletions test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions.warnings import UserInputWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models.approximate_gp import (
Expand Down Expand Up @@ -103,6 +104,16 @@ def test_posterior(self):
# test batch_shape property
self.assertEqual(model.batch_shape, tx.shape[:-2])

# Test that checks if posterior_transform is correctly applied
[tx1, ty1, test1] = all_tests["non_batched_mo"]
model1 = SingleTaskVariationalGP(tx1, ty1, inducing_points=tx1)
posterior_transform = ScalarizedPosteriorTransform(
weights=torch.tensor([1.0, 1.0], device=self.device)
)
posterior1 = model1.posterior(test1, posterior_transform=posterior_transform)
self.assertIsInstance(posterior1, GPyTorchPosterior)
self.assertEqual(posterior1.mean.shape[1], 1)

def test_variational_setUp(self):
for dtype in [torch.float, torch.double]:
train_X = torch.rand(10, 1, device=self.device, dtype=dtype)
Expand Down

0 comments on commit e9ce11f

Please sign in to comment.