diff --git a/src/icepack/optimization.py b/src/icepack/optimization.py index 87b2b9808..7cca34497 100644 --- a/src/icepack/optimization.py +++ b/src/icepack/optimization.py @@ -13,6 +13,8 @@ import firedrake from .utilities import default_solver_parameters +from firedrake.adjoint_utils.blocks.solving import NonlinearVariationalSolveBlock + class MinimizationProblem: def __init__(self, E, S, u, bcs, form_compiler_parameters): @@ -138,12 +140,67 @@ def solve(self): r"""Step the Newton iteration until convergence""" self.reinit() - dE_dv = self.dE_dv - S = self.problem.S - _assemble = self.problem.assemble - while abs(_assemble(dE_dv)) > self.tolerance * _assemble(S): - self.step() - if self.iteration >= self.max_iterations: - raise firedrake.ConvergenceError( - f"Newton search did not converge after {self.max_iterations} iterations!" - ) + class Block(NonlinearVariationalSolveBlock): + def __init__(self, solver): + F = solver.F + super().__init__( + F == 0, solver.problem.u, solver.problem.bcs, + firedrake.adjoint(solver.J), {}, problem_J=solver.J, + solver_params=solver.search_direction_solver.parameters, + solver_kwargs={}) + self._icepack__solver = solver + + def _forward_solve(self, lhs, rhs, func, bcs, **kwargs): + # Re-use the NewtonSolver by copying and restoring the values of + # dependencies + deps = {bv_dep.output: bv_dep.saved_output + for bv_dep in self.get_dependencies()} + vals = {} + for eq_dep, dep in deps.items(): + if isinstance(eq_dep, firedrake.Function): + vals[eq_dep] = eq_dep.copy(deepcopy=True) + eq_dep.assign(dep) + elif isinstance(eq_dep, firedrake.Constant): + # Scalar only + vals[eq_dep], = eq_dep.values() + eq_dep.assign(dep) + elif isinstance(eq_dep, firedrake.DirichletBC): + vals[eq_dep] = eq_dep.function_arg.copy(deepcopy=True) + eq_dep.function_arg = dep.function_arg + elif isinstance(eq_dep, firedrake.mesh.MeshGeometry): + # Assume fixed mesh + pass + else: + raise TypeError(f"Unexpected type: {type(eq_dep)}") + try: + self._icepack__solver.solve(**kwargs) + func.assign(self._icepack__solver.problem.u) + finally: + for eq_dep, eq_dep_val in vals.items(): + if isinstance(eq_dep, firedrake.Function): + eq_dep.assign(eq_dep_val) + elif isinstance(eq_dep, firedrake.Constant): + eq_dep.assign(eq_dep_val) + elif isinstance(eq_dep, firedrake.DirichletBC): + eq_dep.function_arg = eq_dep_val + elif isinstance(eq_dep, firedrake.mesh.MeshGeometry): + pass + else: + raise TypeError(f"Unexpected type: {type(eq_dep)}") + return func + + block = Block(self) + firedrake.adjoint.get_working_tape().add_block(block) + + with firedrake.adjoint.stop_annotating(): + dE_dv = self.dE_dv + S = self.problem.S + _assemble = self.problem.assemble + while abs(_assemble(dE_dv)) > self.tolerance * _assemble(S): + self.step() + if self.iteration >= self.max_iterations: + raise firedrake.ConvergenceError( + f"Newton search did not converge after {self.max_iterations} iterations!" + ) + + block.add_output(self.problem.u.create_block_variable()) diff --git a/test/statistics_test.py b/test/statistics_test.py index d3611af57..9c9dbd2e3 100644 --- a/test/statistics_test.py +++ b/test/statistics_test.py @@ -78,9 +78,14 @@ def regularization(q): assert firedrake.norm(q - q_true) < 0.25 -@pytest.mark.skipif(not icepack.statistics.has_rol, reason="Couldn't import ROL") @pytest.mark.parametrize("with_noise", [False, True]) -def test_ice_shelf_inverse(with_noise): +@pytest.mark.parametrize("diagnostic_solver_type", ["icepack", "petsc"]) +def test_ice_shelf_inverse(with_noise, diagnostic_solver_type): + if with_noise: + np.random.seed(561284280) + else: + np.random.seed(462461651) + Nx, Ny = 32, 32 Lx, Ly = 20e3, 20e3 @@ -122,7 +127,7 @@ def viscosity(**kwargs): flow_solver = icepack.solvers.FlowSolver( model, dirichlet_ids=dirichlet_ids, - diagnostic_solver_type="petsc", + diagnostic_solver_type=diagnostic_solver_type, diagnostic_solver_parameters={ "snes_type": "newtonls", "ksp_type": "preonly", @@ -160,10 +165,25 @@ def simulation(q): velocity=u_initial, thickness=h, log_fluidity=q ) - stats_problem = StatisticsProblem( - simulation, loss_functional, regularization, q_initial - ) - estimator = MaximumProbabilityEstimator(stats_problem) - q = estimator.solve() - - assert firedrake.norm(q - q_true) / firedrake.norm(q_initial - q_true) < 0.25 + def forward(q): + u = simulation(q) + return firedrake.assemble(loss_functional(u) + regularization(q)) + + q_test = firedrake.Function(q_initial.function_space()).assign(2.0) + firedrake.adjoint.continue_annotation() + J = forward(q_test) + firedrake.adjoint.pause_annotation() + + q_control = firedrake.adjoint.Control(q_test) + min_order = firedrake.adjoint.taylor_test( + firedrake.adjoint.ReducedFunctional(J, q_control), q_test, + q_test.assign(0.1)) + assert min_order > 1.97 + + # stats_problem = StatisticsProblem( + # simulation, loss_functional, regularization, q_initial + # ) + # estimator = MaximumProbabilityEstimator(stats_problem) + # q = estimator.solve() + + # assert firedrake.norm(q - q_true) / firedrake.norm(q_initial - q_true) < 0.25