[Bug] OrthogonalAdditiveKernel doesn't work with input transforms because they generate x values outside the unit hypercube #2270

Open
esantorella opened this issue Apr 1, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@esantorella
Member

esantorella commented Apr 1, 2024

🐛 Bug

OrthogonalAdditiveKernel raises a ValueError (from its _check_hypercube helper) when given x values outside the unit hypercube, [0, 1]^d. Unfortunately, when this kernel is combined with basic BoTorch functionality, such values are hard to avoid. For example, if the search space is [0, 1], a model is trained on points ranging from 0.25 to 0.75, and a Normalize input transform is used, then Normalize learns the bounds [0.25, 0.75] from the training data. The search-space edges 0 and 1 therefore transform to -0.5 and 1.5, which lie outside the hypercube.
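A minimal sketch of the arithmetic, assuming that Normalize without explicit bounds learns each dimension's min and max from the training inputs and applies min-max scaling. Plain Python stands in for the actual transform, and min_max_normalize is a hypothetical helper, not a BoTorch function.

```python
def min_max_normalize(xs, lower, upper):
    """Min-max scale each value: (x - lower) / (upper - lower)."""
    span = upper - lower
    return [(x - lower) / span for x in xs]


# Without explicit bounds, the scaling bounds come from the training inputs.
train_xs = [0.25, 0.5, 0.75]
lower, upper = min(train_xs), max(train_xs)  # 0.25, 0.75

# The training points themselves land exactly on [0, 1] ...
print(min_max_normalize(train_xs, lower, upper))  # [0.0, 0.5, 1.0]

# ... but the edges of the search space [0, 1] land outside it.
edges = min_max_normalize([0.0, 1.0], lower, upper)
print(edges)  # [-0.5, 1.5]
```

Any candidate the optimizer proposes outside the training range is affected in the same way, which is why acquisition optimization also trips the check.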

To reproduce

Code snippet to reproduce

from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel

from botorch.models.gp_regression import SingleTaskGP, get_matern_kernel_with_gamma_prior
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models.transforms.outcome import Standardize
from botorch.models.transforms.input import Normalize

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.optim import optimize_acqf

train_X = torch.tensor([[0.3], [0.7]], dtype=torch.float64)
train_Y = torch.tensor([[0.3], [0.7]], dtype=torch.float64)

kernel = OrthogonalAdditiveKernel(
    base_kernel=get_matern_kernel_with_gamma_prior(
        ard_num_dims=None,
    ),
    dim=1,
    dtype=torch.double,
)

model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
    covar_module=kernel
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)

model.posterior(train_X)  # works

# errors: the bounds learned from train_X are [0.3, 0.7], so 0.2 and 0.8 map to -0.25 and 1.25
model.posterior(torch.tensor([[0.2], [0.8]], dtype=torch.float64))

Stack trace/error message

Traceback (most recent call last):
  File "/Users/lizs/oak_issue.py", line 36, in <module>
    model.posterior(torch.tensor([[0.2], [0.8]]), dtype=torch.float64)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
    K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
    _check_hypercube(x1, "x1")
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
    raise ValueError(name + " is not in hypercube [0, 1]^d.")
ValueError: x1 is not in hypercube [0, 1]^d.
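
Optimizing an acquisition function on the same model fails with the same error:

acqf = qLogNoisyExpectedImprovement(
    model,
    X_baseline=train_X,
)
optimize_acqf(
    acqf,
    bounds=torch.tensor([[0.0], [1.0]], dtype=torch.float64),
    q=1,
    num_restarts=16,
    raw_samples=32,
)
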
Traceback (most recent call last):
  File "/Users/lizs/oak_issue.py", line 43, in <module>
    optimize_acqf(
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
                 ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 305, in decorated
    return method(cls, X, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 274, in forward
    non_reduced_acqval = self._non_reduced_forward(X=X)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 287, in _non_reduced_forward
    samples, obj = self._get_samples_and_objectives(X)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/logei.py", line 465, in _get_samples_and_objectives
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
    K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
    _check_hypercube(x1, "x1")
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
    raise ValueError(name + " is not in hypercube [0, 1]^d.")
ValueError: x1 is not in hypercube [0, 1]^d.

Expected Behavior

This should not error. Currently, the only way to use the OrthogonalAdditiveKernel (OAK) is to manually normalize the search space (rather than the training data) to [0, 1], which is neither documented nor well supported.
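The manual workaround can be sketched as follows, assuming the model is built on inputs that were already rescaled with the fixed search-space bounds (so no learned Normalize is attached) and that candidates are mapped back to the original space before evaluation. The helper names to_unit_cube and from_unit_cube and the 2-d search space are hypothetical, and plain lists stand in for tensors.

```python
from typing import List, Sequence


def to_unit_cube(x: Sequence[float], lower: Sequence[float], upper: Sequence[float]) -> List[float]:
    """Rescale a point from the search space [lower, upper] to [0, 1]^d."""
    return [(xi - lo) / (hi - lo) for xi, lo, hi in zip(x, lower, upper)]


def from_unit_cube(u: Sequence[float], lower: Sequence[float], upper: Sequence[float]) -> List[float]:
    """Map a point from [0, 1]^d back to the original search space."""
    return [lo + ui * (hi - lo) for ui, lo, hi in zip(u, lower, upper)]


# Hypothetical 2-d search space: x0 in [-5, 5], x1 in [10, 20].
lower = [-5.0, 10.0]
upper = [5.0, 20.0]

# Training points are rescaled with the search-space bounds, not their own min/max,
# so every point inside the search space (trained on or not) stays in [0, 1]^d.
train_points = [[-1.0, 12.0], [1.0, 14.0]]
train_unit = [to_unit_cube(p, lower, upper) for p in train_points]

corner = to_unit_cube([5.0, 10.0], lower, upper)  # [1.0, 0.0]

# A candidate proposed in the unit cube is mapped back before it is evaluated.
candidate = from_unit_cube([0.5, 0.25], lower, upper)  # [0.0, 12.5]
```

The key design point is that the scaling depends only on the search space, so the kernel's [0, 1]^d domain covers every point the optimizer can propose.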

System information


  • BoTorch Version: 0.10.0
  • GPyTorch Version: 1.11
  • PyTorch Version: 2.2.2
  • Computer OS: OS X
esantorella added the bug label on Apr 1, 2024
@Balandat
Contributor

Balandat commented Apr 1, 2024

@SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?

@SebastianAment
Contributor

Thanks for raising this. I added this check to ensure that the search space bounds are passed to Normalize; otherwise, the orthogonality condition can only be guaranteed on the training set. In the example above, passing the search-space bounds, i.e. Normalize(d=1, bounds=bounds) with bounds = torch.tensor([[0.0], [1.0]], dtype=torch.float64), would work. We can add this hint to the error message.

> @SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?

In principle we could also open the kernel up to be evaluated outside of the orthogonality domain, but I think it's better to error out in this case, at least by default, as orthogonality is the defining property that users would expect from the kernel.
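A sketch of how such a check could behave if both suggestions were adopted: it errors by default with a message pointing to the search-space bounds for Normalize, and offers an explicit opt-in for evaluation outside the unit cube. check_unit_hypercube and allow_outside are hypothetical names, not part of BoTorch, and plain lists stand in for tensors.

```python
from typing import Sequence


def check_unit_hypercube(
    points: Sequence[Sequence[float]], name: str, allow_outside: bool = False
) -> None:
    """Raise ValueError if any coordinate lies outside [0, 1], unless explicitly allowed."""
    if allow_outside:
        # Opt-in: the caller accepts that orthogonality is not guaranteed here.
        return
    if any(not 0.0 <= v <= 1.0 for p in points for v in p):
        raise ValueError(
            f"{name} is not in hypercube [0, 1]^d. If inputs are normalized with "
            "Normalize, pass the search-space bounds (e.g. Normalize(d=d, bounds=bounds)) "
            "so that every point in the search space maps into [0, 1]^d."
        )


check_unit_hypercube([[0.0], [1.0]], "x1")  # inside the cube: no error
try:
    check_unit_hypercube([[-0.25], [1.25]], "x1")
except ValueError as err:
    print(err)
check_unit_hypercube([[-0.25], [1.25]], "x1", allow_outside=True)  # opted in: no error
```

Keeping the strict behavior as the default matches the point that orthogonality is the property users expect from this kernel.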

@Balandat
Contributor

cc @hvarfner
