diff --git a/src/icepack/optimization.py b/src/icepack/optimization.py index 42e8f720..103b3d0b 100644 --- a/src/icepack/optimization.py +++ b/src/icepack/optimization.py @@ -14,6 +14,7 @@ from .utilities import default_solver_parameters from firedrake.adjoint_utils.blocks.solving import NonlinearVariationalSolveBlock +import inspect import ufl @@ -143,12 +144,17 @@ def solve(self, *, annotate=None): class Block(NonlinearVariationalSolveBlock): def __init__(self, solver): + if "adj_cache" in inspect.signature(NonlinearVariationalSolveBlock).parameters: + kwargs = {"adj_cache": {}} + else: + # Backwards compatibility + kwargs = {"dFdm_cache": {}} super().__init__( solver.F == 0, solver.problem.u, solver.problem.bcs, - adj_F=firedrake.adjoint(solver.J), adj_cache={}, + adj_F=firedrake.adjoint(solver.J), problem_J=solver.J, solver_params=solver.search_direction_solver.parameters, - solver_kwargs={}) + solver_kwargs={}, **kwargs) for dep in ufl.algorithms.extract_coefficients(solver.problem.S): self.add_dependency(dep, no_duplicates=True) self._icepack__solver = solver