Skip to content

Commit

Permalink
pyadjoint annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Nov 23, 2023
1 parent 6550b70 commit 1f0a66d
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
75 changes: 66 additions & 9 deletions src/icepack/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import firedrake
from .utilities import default_solver_parameters

from firedrake.adjoint_utils.blocks.solving import NonlinearVariationalSolveBlock


class MinimizationProblem:
def __init__(self, E, S, u, bcs, form_compiler_parameters):
Expand Down Expand Up @@ -138,12 +140,67 @@ def solve(self):
r"""Step the Newton iteration until convergence"""
self.reinit()

dE_dv = self.dE_dv
S = self.problem.S
_assemble = self.problem.assemble
while abs(_assemble(dE_dv)) > self.tolerance * _assemble(S):
self.step()
if self.iteration >= self.max_iterations:
raise firedrake.ConvergenceError(
f"Newton search did not converge after {self.max_iterations} iterations!"
)
class Block(NonlinearVariationalSolveBlock):
    """pyadjoint tape block recording the custom Newton solve.

    Recording the whole iteration as a single
    ``NonlinearVariationalSolveBlock`` lets pyadjoint differentiate
    through the solve with the adjoint of ``F == 0`` rather than taping
    every individual Newton step.
    """

    def __init__(self, solver):
        # ``solver`` is the enclosing Newton solver; a handle is kept so
        # that recomputation can re-run the same iteration (see
        # ``_forward_solve`` below).
        F = solver.F
        super().__init__(
            F == 0, solver.problem.u, solver.problem.bcs,
            # NOTE(review): presumably the UFL adjoint of the Jacobian
            # form used for the adjoint solve — confirm that
            # ``firedrake.adjoint`` resolves to ``ufl.adjoint`` here and
            # is not shadowed by the ``firedrake.adjoint`` module.
            firedrake.adjoint(solver.J), {}, problem_J=solver.J,
            solver_params=solver.search_direction_solver.parameters,
            solver_kwargs={})
        # Unusual attribute name chosen to avoid clashing with any
        # attribute of the pyadjoint base class.
        self._icepack__solver = solver

    def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
        # Re-use the NewtonSolver by copying and restoring the values of
        # dependencies
        #
        # Map each dependency's live object to the value recorded on the
        # tape at annotation time.
        deps = {bv_dep.output: bv_dep.saved_output
                for bv_dep in self.get_dependencies()}
        # Current values of the live objects, saved so they can be put
        # back after the solve.
        vals = {}
        for eq_dep, dep in deps.items():
            if isinstance(eq_dep, firedrake.Function):
                vals[eq_dep] = eq_dep.copy(deepcopy=True)
                eq_dep.assign(dep)
            elif isinstance(eq_dep, firedrake.Constant):
                # Scalar only
                vals[eq_dep], = eq_dep.values()
                eq_dep.assign(dep)
            elif isinstance(eq_dep, firedrake.DirichletBC):
                # Only the boundary value is swapped; the subdomain and
                # function space of the BC are assumed unchanged.
                vals[eq_dep] = eq_dep.function_arg.copy(deepcopy=True)
                eq_dep.function_arg = dep.function_arg
            elif isinstance(eq_dep, firedrake.mesh.MeshGeometry):
                # Assume fixed mesh
                pass
            else:
                raise TypeError(f"Unexpected type: {type(eq_dep)}")
        try:
            # Re-run the Newton iteration with the taped values in place
            # and copy the converged solution into the block output.
            self._icepack__solver.solve(**kwargs)
            func.assign(self._icepack__solver.problem.u)
        finally:
            # Restore the pre-recomputation state even if the solve
            # raised, so the live objects are never left clobbered.
            for eq_dep, eq_dep_val in vals.items():
                if isinstance(eq_dep, firedrake.Function):
                    eq_dep.assign(eq_dep_val)
                elif isinstance(eq_dep, firedrake.Constant):
                    eq_dep.assign(eq_dep_val)
                elif isinstance(eq_dep, firedrake.DirichletBC):
                    eq_dep.function_arg = eq_dep_val
                elif isinstance(eq_dep, firedrake.mesh.MeshGeometry):
                    pass
                else:
                    raise TypeError(f"Unexpected type: {type(eq_dep)}")
        return func

block = Block(self)
firedrake.adjoint.get_working_tape().add_block(block)

with firedrake.adjoint.stop_annotating():
dE_dv = self.dE_dv
S = self.problem.S
_assemble = self.problem.assemble
while abs(_assemble(dE_dv)) > self.tolerance * _assemble(S):
self.step()
if self.iteration >= self.max_iterations:
raise firedrake.ConvergenceError(
f"Newton search did not converge after {self.max_iterations} iterations!"
)

block.add_output(self.problem.u.create_block_variable())
40 changes: 30 additions & 10 deletions test/statistics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,14 @@ def regularization(q):
assert firedrake.norm(q - q_true) < 0.25


@pytest.mark.skipif(not icepack.statistics.has_rol, reason="Couldn't import ROL")
@pytest.mark.parametrize("with_noise", [False, True])
def test_ice_shelf_inverse(with_noise):
@pytest.mark.parametrize("diagnostic_solver_type", ["icepack", "petsc"])
def test_ice_shelf_inverse(with_noise, diagnostic_solver_type):
if with_noise:
np.random.seed(561284280)
else:
np.random.seed(462461651)

Nx, Ny = 32, 32
Lx, Ly = 20e3, 20e3

Expand Down Expand Up @@ -122,7 +127,7 @@ def viscosity(**kwargs):
flow_solver = icepack.solvers.FlowSolver(
model,
dirichlet_ids=dirichlet_ids,
diagnostic_solver_type="petsc",
diagnostic_solver_type=diagnostic_solver_type,
diagnostic_solver_parameters={
"snes_type": "newtonls",
"ksp_type": "preonly",
Expand Down Expand Up @@ -160,10 +165,25 @@ def simulation(q):
velocity=u_initial, thickness=h, log_fluidity=q
)

stats_problem = StatisticsProblem(
simulation, loss_functional, regularization, q_initial
)
estimator = MaximumProbabilityEstimator(stats_problem)
q = estimator.solve()

assert firedrake.norm(q - q_true) / firedrake.norm(q_initial - q_true) < 0.25
def forward(q):
u = simulation(q)
return firedrake.assemble(loss_functional(u) + regularization(q))

q_test = firedrake.Function(q_initial.function_space()).assign(2.0)
firedrake.adjoint.continue_annotation()
J = forward(q_test)
firedrake.adjoint.pause_annotation()

q_control = firedrake.adjoint.Control(q_test)
min_order = firedrake.adjoint.taylor_test(
firedrake.adjoint.ReducedFunctional(J, q_control), q_test,
q_test.assign(0.1))
assert min_order > 1.97

# stats_problem = StatisticsProblem(
# simulation, loss_functional, regularization, q_initial
# )
# estimator = MaximumProbabilityEstimator(stats_problem)
# q = estimator.solve()

# assert firedrake.norm(q - q_true) / firedrake.norm(q_initial - q_true) < 0.25

0 comments on commit 1f0a66d

Please sign in to comment.