Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for priors in OAK Kernel #2535

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 94 additions & 1 deletion botorch/models/kernels/orthogonal_additive_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
from botorch.exceptions.errors import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.module import Module
from gpytorch.priors import Prior

from torch import nn, Tensor

# Shared default for the `coeff_constraint` argument of the kernel constructor.
_positivity_constraint = Positive()
# Message of the error raised by the kernel constructor when `coeffs_2_prior` is
# passed while `second_order` is False.
SECOND_ORDER_PRIOR_ERROR_MSG = (
    "Second order interactions are disabled, but there is a prior on the second order "
    "coefficients. Please remove the second order prior or enable second order terms."
)


class OrthogonalAdditiveKernel(Kernel):
Expand All @@ -40,6 +46,9 @@ def __init__(
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
coeff_constraint: Interval = _positivity_constraint,
offset_prior: Optional[Prior] = None,
coeffs_1_prior: Optional[Prior] = None,
coeffs_2_prior: Optional[Prior] = None,
):
"""
Args:
Expand All @@ -52,9 +61,18 @@ def __init__(
dtype: Initialization dtype for required Tensors.
device: Initialization device for required Tensors.
coeff_constraint: Constraint on the coefficients of the additive kernel.
offset_prior: Prior on the offset coefficient. Should be prior with non-
negative support.
coeffs_1_prior: Prior on the parameter main effects. Should be prior with
non-negative support.
coeffs_2_prior: Prior on the parameter interactions. Should
be a prior with non-negative support.
"""
super().__init__(batch_shape=batch_shape)
self.base_kernel = base_kernel
if not second_order and coeffs_2_prior is not None:
raise AttributeError(SECOND_ORDER_PRIOR_ERROR_MSG)

# integration nodes, weights for [0, 1]
tkwargs = {"dtype": dtype, "device": device}
z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
Expand Down Expand Up @@ -82,6 +100,29 @@ def __init__(
else None
),
)
if offset_prior is not None:
self.register_prior(
name="offset_prior",
prior=offset_prior,
param_or_closure=_offset_param,
setting_closure=_offset_closure,
)
if coeffs_1_prior is not None:
self.register_prior(
name="coeffs_1_prior",
prior=coeffs_1_prior,
param_or_closure=_coeffs_1_param,
setting_closure=_coeffs_1_closure,
)
if coeffs_2_prior is not None:
self.register_prior(
name="coeffs_2_prior",
prior=coeffs_2_prior,
param_or_closure=_coeffs_2_param,
setting_closure=_coeffs_2_closure,
)

# for second order interactions, we only
if second_order:
self._rev_triu_indices = torch.tensor(
_reverse_triu_indices(dim),
Expand All @@ -95,7 +136,7 @@ def __init__(
self.coeff_constraint = coeff_constraint
self.dim = dim

def k(self, x1, x2) -> Tensor:
def k(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
independently.

Expand Down Expand Up @@ -140,6 +181,34 @@ def coeffs_2(self) -> Optional[Tensor]:
else:
return None

def _set_coeffs_1(self, value: Tensor) -> None:
    """Store new first-order coefficients in `raw_coeffs_1`.

    Args:
        value: Coefficient value(s) on the constrained scale. Anything accepted
            by `torch.as_tensor` that can be expanded to `batch_shape x dim`.
    """
    # match dtype / device of the parameter that is being overwritten
    coeffs = torch.as_tensor(value).to(self.raw_coeffs_1)
    target_shape = (*self.batch_shape, self.dim)
    # the parameter lives on the unconstrained scale, hence the inverse transform
    raw = self.coeff_constraint.inverse_transform(coeffs.expand(target_shape))
    self.initialize(raw_coeffs_1=raw)

def _set_coeffs_2(self, value: Tensor) -> None:
    """Store new second-order coefficients in `raw_coeffs_2`.

    Only the strictly upper-triangular entries (offset=1) of the expanded
    `dim x dim` matrix are kept; the diagonal and the lower triangle of
    `value` are ignored.

    Args:
        value: Coefficient value(s) on the constrained scale. Anything accepted
            by `torch.as_tensor` that can be expanded to
            `batch_shape x dim x dim`.
    """
    # convert once, against the parameter that is actually being overwritten
    value = torch.as_tensor(value).to(self.raw_coeffs_2)
    value = value.expand(*self.batch_shape, self.dim, self.dim)
    # build the indices on the same device as the values they index
    row_idcs, col_idcs = torch.triu_indices(
        self.dim, self.dim, offset=1, device=value.device
    )
    upper = value[..., row_idcs, col_idcs]
    # the parameter lives on the unconstrained scale, hence the inverse transform
    self.initialize(raw_coeffs_2=self.coeff_constraint.inverse_transform(upper))

def _set_offset(self, value: Tensor) -> None:
    """Store a new offset coefficient in `raw_offset`.

    Args:
        value: Offset value on the constrained scale; anything accepted by
            `torch.as_tensor`.
    """
    # match dtype / device of the parameter that is being overwritten
    offset = torch.as_tensor(value).to(self.raw_offset)
    # the parameter lives on the unconstrained scale, hence the inverse transform
    raw_offset = self.coeff_constraint.inverse_transform(offset)
    self.initialize(raw_offset=raw_offset)

@coeffs_1.setter
def coeffs_1(self, value) -> None:
    """Property setter for `coeffs_1`; delegates to `_set_coeffs_1`."""
    self._set_coeffs_1(value)

@coeffs_2.setter
def coeffs_2(self, value) -> None:
    """Property setter for `coeffs_2`; delegates to `_set_coeffs_2`."""
    self._set_coeffs_2(value)

@offset.setter
def offset(self, value) -> None:
    """Property setter for `offset`; delegates to `_set_offset`."""
    self._set_offset(value)

def forward(
self,
x1: Tensor,
Expand Down Expand Up @@ -296,3 +365,27 @@ def _reverse_triu_indices(d: int) -> list[int]:
indices.extend(range(j, j + d - i - 1)) # indexing coeffs (super-diagonal)
j += d - i - 1
return indices


def _coeffs_1_param(m: Module) -> Tensor:
    """Return `m.coeffs_1`; passed as `param_or_closure` for `coeffs_1_prior`."""
    return m.coeffs_1


def _coeffs_2_param(m: Module) -> Tensor:
    """Return `m.coeffs_2`; passed as `param_or_closure` for `coeffs_2_prior`."""
    return m.coeffs_2


def _offset_param(m: Module) -> Tensor:
    """Return `m.offset`; passed as `param_or_closure` for `offset_prior`."""
    return m.offset


def _coeffs_1_closure(m: Module, v: Tensor) -> None:
    """Set the first-order coefficients of `m` to `v`.

    Passed as `setting_closure` for `coeffs_1_prior`. The underlying setter
    returns nothing, so neither does this closure.
    """
    m._set_coeffs_1(v)


def _coeffs_2_closure(m: Module, v: Tensor) -> None:
    """Set the second-order coefficients of `m` to `v`.

    Passed as `setting_closure` for `coeffs_2_prior`. The underlying setter
    returns nothing, so neither does this closure.
    """
    m._set_coeffs_2(v)


def _offset_closure(m: Module, v: Tensor) -> None:
    """Set the offset coefficient of `m` to `v`.

    Passed as `setting_closure` for `offset_prior`. The underlying setter
    returns nothing, so neither does this closure.
    """
    m._set_offset(v)
191 changes: 190 additions & 1 deletion test/models/kernels/test_orthogonal_additive_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.kernels.orthogonal_additive_kernel import (
OrthogonalAdditiveKernel,
SECOND_ORDER_PRIOR_ERROR_MSG,
)
from botorch.utils.testing import BotorchTestCase
from gpytorch.constraints import Positive
from gpytorch.kernels import MaternKernel, RBFKernel
from gpytorch.lazy import LazyEvaluatedKernelTensor
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import LogNormalPrior
from gpytorch.priors.torch_priors import GammaPrior, HalfCauchyPrior, UniformPrior
from torch import nn, Tensor


Expand Down Expand Up @@ -118,6 +129,184 @@ def test_kernel(self):
tol = 1e-5
self.assertTrue(((K_ortho @ oak.w).squeeze(-1) < tol).all())

def test_priors(self):
    """Exercise prior registration, validation, fitting and sampling."""
    d = 5
    dtypes = [torch.float, torch.double]
    batch_shapes = [(), (2,), (7, 2)]

    # constructing the kernel without any prior must keep working
    oak = OrthogonalAdditiveKernel(
        RBFKernel(), dim=d, batch_shape=None, second_order=True
    )
    for dtype, batch_shape in itertools.product(dtypes, batch_shapes):
        # all three priors together with second-order terms enabled
        tkwargs = {"dtype": dtype, "device": self.device}
        offset_prior = HalfCauchyPrior(0.1).to(**tkwargs)
        coeffs_1_prior = LogNormalPrior(0, 1).to(**tkwargs)
        coeffs_2_prior = GammaPrior(3, 6).to(**tkwargs)
        oak = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            second_order=True,
            offset_prior=offset_prior,
            coeffs_1_prior=coeffs_1_prior,
            coeffs_2_prior=coeffs_2_prior,
            batch_shape=batch_shape,
            **tkwargs,
        )

        self.assertIsInstance(oak.offset_prior, HalfCauchyPrior)
        self.assertIsInstance(oak.coeffs_1_prior, LogNormalPrior)
        self.assertEqual(oak.coeffs_1_prior.scale, 1)
        self.assertEqual(oak.coeffs_2_prior.concentration, 3)

        # only the second-order prior: the first-order one must stay unregistered
        oak = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            second_order=True,
            coeffs_1_prior=None,
            coeffs_2_prior=coeffs_2_prior,
            batch_shape=batch_shape,
            **tkwargs,
        )
        self.assertEqual(oak.coeffs_2_prior.concentration, 3)
        with self.assertRaisesRegex(
            AttributeError,
            "'OrthogonalAdditiveKernel' object has no attribute 'coeffs_1_prior",
        ):
            _ = oak.coeffs_1_prior
        # only the first-order prior with second-order terms enabled must not raise
        oak = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            second_order=True,
            coeffs_1_prior=coeffs_1_prior,
            batch_shape=batch_shape,
            **tkwargs,
        )
        # a second-order prior without second-order terms is rejected
        with self.assertRaisesRegex(AttributeError, SECOND_ORDER_PRIOR_ERROR_MSG):
            OrthogonalAdditiveKernel(
                RBFKernel(),
                dim=d,
                batch_shape=None,
                second_order=False,
                coeffs_2_prior=GammaPrior(1, 1),
            )

        # train the model to ensure that param setters are called
        train_X = torch.rand(5, d, dtype=dtype, device=self.device)
        train_Y = torch.randn(5, 1, dtype=dtype, device=self.device)

        oak = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            batch_shape=None,
            second_order=True,
            offset_prior=offset_prior,
            coeffs_1_prior=coeffs_1_prior,
            coeffs_2_prior=coeffs_2_prior,
            **tkwargs,
        )
        model = SingleTaskGP(train_X=train_X, train_Y=train_Y, covar_module=oak)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}})

        unif_prior = UniformPrior(10, 11)
        # coeff_constraint is not enforced so that we can check the raw parameter
        # values and not the reshaped (triu transformed) ones
        oak_for_sample = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            batch_shape=None,
            second_order=True,
            offset_prior=unif_prior,
            coeffs_1_prior=unif_prior,
            coeffs_2_prior=unif_prior,
            coeff_constraint=Positive(transform=None, inv_transform=None),
            **tkwargs,
        )
        oak_for_sample.sample_from_prior("offset_prior")
        oak_for_sample.sample_from_prior("coeffs_1_prior")
        oak_for_sample.sample_from_prior("coeffs_2_prior")

        # check that all sampled values are within the bounds set by the priors;
        # element-wise masks are combined with `&` because a chained comparison
        # (`10 <= t <= 11`) would call bool() on a tensor
        self.assertTrue(
            torch.all(
                (10 <= oak_for_sample.raw_offset)
                & (oak_for_sample.raw_offset <= 11)
            )
        )
        self.assertTrue(
            torch.all(
                (10 <= oak_for_sample.raw_coeffs_1)
                & (oak_for_sample.raw_coeffs_1 <= 11)
            )
        )
        self.assertTrue(
            torch.all(
                (10 <= oak_for_sample.raw_coeffs_2)
                & (oak_for_sample.raw_coeffs_2 <= 11)
            )
        )

def test_set_coeffs(self):
    """Exercise the `offset`, `coeffs_1` and `coeffs_2` property setters."""
    d = 5
    # raw_coeffs_2 holds one entry per strictly upper-triangular position
    num_pairs = d * (d - 1) // 2
    dtype = torch.double
    oak = OrthogonalAdditiveKernel(
        RBFKernel(),
        dim=d,
        batch_shape=None,
        second_order=True,
        dtype=dtype,
    )
    constraint = oak.coeff_constraint
    coeffs_1 = torch.arange(d, dtype=dtype)
    coeffs_2 = torch.ones(d, d, dtype=dtype).triu()
    oak.coeffs_1 = coeffs_1
    oak.coeffs_2 = coeffs_2

    self.assertAllClose(
        oak.raw_coeffs_1,
        constraint.inverse_transform(coeffs_1),
    )
    self.assertAllClose(
        oak.raw_coeffs_2,
        constraint.inverse_transform(torch.ones(num_pairs, dtype=dtype)),
    )

    batch_shapes = torch.Size([2]), torch.Size([5, 2])
    for batch_shape in batch_shapes:
        # identity constraint, so that raw and constrained values coincide
        oak = OrthogonalAdditiveKernel(
            RBFKernel(),
            dim=d,
            batch_shape=batch_shape,
            second_order=True,
            dtype=dtype,
            coeff_constraint=Positive(transform=None, inv_transform=None),
        )
        coeffs_1 = torch.arange(d, dtype=dtype)
        coeffs_2 = torch.ones(d, d, dtype=dtype).triu()
        oak.coeffs_1 = coeffs_1
        oak.coeffs_2 = coeffs_2

        # unbatched values are expanded to the kernel's batch shape
        self.assertEqual(oak.raw_coeffs_1.shape, batch_shape + torch.Size([d]))
        self.assertEqual(
            oak.raw_coeffs_2.shape, batch_shape + torch.Size([num_pairs])
        )

        # test setting value as float
        oak.offset = 0.5
        self.assertAllClose(oak.offset, 0.5 * torch.ones_like(oak.offset))
        oak.coeffs_1 = 0.2
        self.assertAllClose(
            oak.raw_coeffs_1, 0.2 * torch.ones_like(oak.raw_coeffs_1)
        )
        oak.coeffs_2 = 0.3
        self.assertAllClose(
            oak.raw_coeffs_2, 0.3 * torch.ones_like(oak.raw_coeffs_2)
        )
        # the setter keeps only the strictly upper-triangular entries, so the
        # strictly lower-triangular part of coeffs_2 is expected to be zero
        self.assertAllClose(
            oak.coeffs_2.tril(diagonal=-1), torch.zeros_like(oak.coeffs_2)
        )


def isposdef(A: Tensor) -> bool:
"""Determines whether A is positive definite or not, by attempting a Cholesky
Expand Down
Loading