
[Bug] An exception should be raised when the training data has requires_grad=True #2253

Open
rexxy-sasori opened this issue Mar 19, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@rexxy-sasori

rexxy-sasori commented Mar 19, 2024

🐛 Bug

Passing train_X and train_Y tensors with requires_grad=True (i.e. with grad_fn not None) to SingleTaskGP causes fit_gpytorch_mll to fail. The error goes away when the tensors have requires_grad=False.

To reproduce

**Code snippet to reproduce**

import botorch, gpytorch, torch
from botorch import fit_gpytorch_mll
from botorch.models import ModelListGP, SingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

print(botorch.__version__)
print(gpytorch.__version__)
print(torch.__version__)

# requires_grad=True on the inputs is what triggers the failure below;
# the derived targets therefore carry a grad_fn.
train_x = torch.randn(30, 6, requires_grad=True)

train_obj = 3 * train_x + train_x**2
train_con = 4 * train_x - train_x**3

print(train_obj.grad_fn)
print(train_con.grad_fn)

model_obj = SingleTaskGP(train_x, train_obj).to(train_x)
model_con = SingleTaskGP(train_x, train_con).to(train_x)

model = ModelListGP(model_obj, model_con)
mll = SumMarginalLogLikelihood(model.likelihood, model)

fit_gpytorch_mll(mll)

**Stack trace/error message**

RuntimeError                              Traceback (most recent call last)
Cell In[55], line 25
     22 model = ModelListGP(model_obj, model_con)
     23 mll = SumMarginalLogLikelihood(model.likelihood, model)
---> 25 fit_gpytorch_mll(mll)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:305, in _fit_list(mll, _, __, **kwargs)
    303 mll.train()
    304 for sub_mll in mll.mlls:
--> 305     fit_gpytorch_mll(sub_mll, **kwargs)
    307 return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
    250 with catch_warnings(record=True) as warning_list, debug(True):
    251     simplefilter("always", category=OptimizationWarning)
--> 252     optimizer(mll, closure=closure, **optimizer_kwargs)
    254 # Resolved warnings and determine whether or not to retry
    255 done = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     89 if closure_kwargs is not None:
     90     closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
     93     closure=closure,
     94     parameters=parameters,
     95     bounds=bounds,
     96     method=method,
     97     options=options,
     98     callback=callback,
     99     timeout_sec=timeout_sec,
    100 )
    101 if result.status != OptimizationStatus.SUCCESS:
    102     warn(
    103         f"`scipy_minimize` terminated with status {result.status}, displaying"
    104         f" original message from `scipy.optimize.minimize`: {result.message}",
    105         OptimizationWarning,
    106     )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    101         result = OptimizationResult(
    102             step=next(call_counter),
    103             fval=float(wrapped_closure(x)[0]),
    104             status=OptimizationStatus.RUNNING,
    105             runtime=monotonic() - start_time,
    106         )
    107         return callback(parameters, result)  # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
    110     wrapped_closure,
    111     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    112     jac=True,
    113     bounds=bounds_np,
    114     method=method,
    115     options=options,
    116     callback=wrapped_callback,
    117     timeout_sec=timeout_sec,
    118 )
    120 # Post-processing and outcome handling
    121 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     77     wrapped_callback = callback
     79 try:
---> 80     return optimize.minimize(
     81         fun=fun,
     82         x0=x0,
     83         args=args,
     84         method=method,
     85         jac=jac,
     86         hess=hess,
     87         hessp=hessp,
     88         bounds=bounds,
     89         constraints=constraints,
     90         tol=tol,
     91         callback=wrapped_callback,
     92         options=options,
     93     )
     94 except OptimizationTimeoutError as e:
     95     msg = f"Optimization timed out after {e.runtime} seconds."

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    707     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    708                              **options)
    709 elif meth == 'l-bfgs-b':
--> 710     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    711                            callback=callback, **options)
    712 elif meth == 'tnc':
    713     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    714                         **options)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    359 task_str = task.tobytes()
    360 if task_str.startswith(b'FG'):
    361     # The minimization routine wants f and g at the current x.
    362     # Note that interruptions due to maxfun are postponed
    363     # until the completion of the current minimization iteration.
    364     # Overwrite f and g:
--> 365     f, g = func_and_grad(x)
    366 elif task_str.startswith(b'NEW_X'):
    367     # new iteration
    368     n_iterations += 1

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
    283 if not np.array_equal(x, self.x):
    284     self._update_x_impl(x)
--> 285 self._update_fun()
    286 self._update_grad()
    287 return self.f, self.g

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
     75 def __call__(self, x, *args):
     76     """ returns the function value """
---> 77     self._compute_if_needed(x, *args)
     78     return self._value

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
     69 if not np.all(x == self.x) or self._value is None or self.jac is None:
     70     self.x = np.asarray(x).copy()
---> 71     fg = self.fun(x, *args)
     72     self.jac = fg[1]
     73     self._value = fg[0]

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:160, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    158         index += size
    159 except RuntimeError as e:
--> 160     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    162 return value, grads

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/common.py:52, in _handle_numerical_errors(error, x, dtype)
     50     _dtype = x.dtype if dtype is None else dtype
     51     return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 52 raise error

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    147     self.state = state
    149 try:
--> 150     value_tensor, grad_tensors = self.closure(**kwargs)
    151     value = self.as_array(value_tensor)
    152     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
   (...)
    520         inputs=inputs,
    521     )
--> 522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/autograd/__init__.py:266, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    261     retain_graph = create_graph
    263 # The reason we repeat the same comment below is that
    264 # some Python versions print out the first line of a multi-line function
    265 # calls in the traceback and some print out the last line
--> 266 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267     tensors,
    268     grad_tensors_,
    269     retain_graph,
    270     create_graph,
    271     inputs,
    272     allow_unreachable=True,
    273     accumulate_grad=True,
    274 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Expected Behavior

Currently, the following error message is thrown:

"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward"

System information

Please complete the following information:

  • BoTorch 0.10.1.dev16+g3e34a4fc
  • GPyTorch 1.12.dev28+g392dd41e
  • PyTorch 2.2.1
  • macOS 14.4 (23E214) Sonoma on an M3 Pro

Additional context

@rexxy-sasori rexxy-sasori added the bug Something isn't working label Mar 19, 2024
@Balandat
Contributor

Is there a particular use case you have for training data with requires_grad=True? Or is this an issue that you ran into without deliberately setting this? In general it's not clear what it would mean to fit the model if the training data itself were a parameter (you'd get a perfect fit all the time...). One could compute the gradient of the model parameters w.r.t. the training data at the optimum (MAP maximizer), but that's a different thing.

@esantorella
Member

Thanks for reporting this issue. I'm not surprised this fails, because BoTorch figures out which tensors are parameters that need to be optimized by looking at which have requires_grad=True. I second Max's question about the use case, since I'm not sure whether it would make sense to support this. Would it work to (perhaps temporarily) detach the input data?
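A minimal sketch of that detach workaround, reusing the names from the reproduction above (illustrative only, not an official recommendation):

import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Data that inadvertently carries an autograd graph, as in the repro.
train_x = torch.randn(30, 6, requires_grad=True)
train_obj = (3 * train_x + train_x**2).sum(dim=-1, keepdim=True)

# Workaround: detach so the model sees plain leaf tensors with requires_grad=False.
model = SingleTaskGP(train_x.detach(), train_obj.detach())
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # no stale autograd graph left to backprop through twice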

@rexxy-sasori
Author

rexxy-sasori commented Mar 19, 2024

Sorry, perhaps I should give a little bit of context. This is actually an issue I ran into without deliberately setting requires_grad; the example here is just a demonstration so anyone can reproduce the bug. I want to use risk-averse BO for model predictive control, so I first built a model in PyTorch that maps my control variable to my objective. However, fit_gpytorch_mll always failed until I figured out that the issue went away when I used torch.no_grad()
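In case it helps others, a minimal sketch of that pattern; objective_model here is a hypothetical stand-in for the PyTorch model mapping control variables to the objective:

import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Hypothetical stand-in for the model mapping controls to the objective.
objective_model = torch.nn.Sequential(
    torch.nn.Linear(6, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
)

train_x = torch.rand(30, 6)
# Evaluate the model without building an autograd graph, so the training
# targets have grad_fn=None and fit_gpytorch_mll does not attempt to
# backprop through a graph that has already been freed.
with torch.no_grad():
    train_y = objective_model(train_x)

gp = SingleTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)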

@Balandat
Contributor

> However, fit_gpytorch_mll always failed until I figured out that the issue went away when I used torch.no_grad()

Just to make sure there is no confusion here, you are not putting the fit_gpytorch_mll() call into a no_grad() context, right? Just making sure the inputs to the GP model don't require gradients, i.e. do whatever prediction you do on your model in a no_grad() context.

It may make sense on our end to explicitly check whether the training data requires grad when calling fit_gpytorch_mll to emit a more informative error message.

@rexxy-sasori
Author

No, I am not putting the fit_gpytorch_mll() call into a no_grad() context.

Yes, I agree with your suggestion to explicitly check whether the training data requires grad when calling fit_gpytorch_mll to emit a more informative error message.

@esantorella esantorella self-assigned this Mar 19, 2024
@esantorella esantorella removed their assignment Apr 5, 2024
@esantorella esantorella changed the title [Bug] Fitting gpytorch model fails when the training data has requires_grad=True [Bug] An exception should be raised when the training data has requires_grad=True Apr 5, 2024
saitcakmak added a commit to saitcakmak/botorch that referenced this issue Jul 24, 2024
Summary:
Addresses pytorch#2253

When the model inputs have gradients enabled, we get errors during model fitting.

Differential Revision: D60184299
@saitcakmak
Contributor

I looked into adding an exception in model input validation but disallowing inputs that require gradients breaks acquisition functions that utilize fantasy models. If we wanted to prevent this error, the validation would have to happen in fit_gpytorch_mll rather than in model constructors.
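For illustration only, a rough sketch of what such a check inside fit_gpytorch_mll could look like; the helper name and its placement are hypothetical, not the actual implementation:

from botorch.exceptions.errors import InputDataError

def _validate_training_data_requires_grad(model) -> None:
    # Hypothetical helper: raise an informative error if any training input
    # or target of an exact GP still carries requires_grad=True.
    tensors = list(model.train_inputs or ()) + [model.train_targets]
    if any(t is not None and t.requires_grad for t in tensors):
        raise InputDataError(
            "Training inputs/targets have requires_grad=True; detach them "
            "(e.g. via `.detach()` or by generating them under "
            "`torch.no_grad()`) before calling `fit_gpytorch_mll`."
        )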
