Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new Constant #2927

Merged
merged 52 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d143187
WIP
connorjward Apr 19, 2023
98e1b42
Mostly working
connorjward Apr 20, 2023
b772e2e
Fixes
connorjward Apr 20, 2023
62314a0
Remove COFFEE flop counting from notebook
connorjward Apr 28, 2023
0a387d1
Remove refs to coffee
connorjward Apr 28, 2023
397c4d3
Linting
connorjward Apr 28, 2023
25d2755
WIP
connorjward Apr 28, 2023
d1349ff
fixes
connorjward Apr 28, 2023
993c1b4
linting
connorjward Apr 28, 2023
8a93f53
DO NOT MERGE Point at correct branches
connorjward Apr 28, 2023
89cabdf
Remove Eigen
connorjward May 2, 2023
ce08a67
fixes
connorjward May 4, 2023
f672666
Fix complex mode
connorjward May 4, 2023
749c73c
Fix complex dtypes
connorjward May 4, 2023
c49746a
Simply the most ridiculous hack to get complex working
connorjward May 4, 2023
2e5abb9
linting
connorjward May 5, 2023
435d5a3
WIP: Constants with domains are functions in the RealSpace™
JDBetteridge Feb 8, 2023
4588a5a
WIP: Some things work now, needs TSFC modifications
JDBetteridge Feb 10, 2023
f88ce4c
Adjoint constants are functions in the real space
JDBetteridge Mar 23, 2023
bd06298
Add constant values as assign method
JDBetteridge Mar 23, 2023
1a11c79
WIP
connorjward Apr 18, 2023
d633729
WIP
connorjward Apr 19, 2023
ee35896
WIP
connorjward Apr 19, 2023
59c4a66
WIP
connorjward Apr 19, 2023
bcc0fc7
Remove breakpoint
connorjward Apr 19, 2023
144a8bf
Fix bug
connorjward Apr 19, 2023
d55d1f7
Fix projection for unknown in reals
JDBetteridge Apr 19, 2023
c351855
Fix more tests
JDBetteridge Apr 20, 2023
f509e8b
Fix remaining regression tests
JDBetteridge May 10, 2023
d040fe2
Remove later
JDBetteridge May 10, 2023
81bdeab
Minor linitng
JDBetteridge May 10, 2023
ad09744
Unused variables
JDBetteridge May 11, 2023
82e5128
Fix for PatchPC and output test
JDBetteridge May 15, 2023
eb37758
Fixes for Slate
JDBetteridge May 15, 2023
a19a87a
Fix failing notebook
JDBetteridge May 15, 2023
fdd7e8d
Ooops, looks like I messed up the rebase
JDBetteridge May 15, 2023
866cf4e
Fixes
JDBetteridge May 17, 2023
590e313
Fixes for Firedrake adjoint
JDBetteridge May 22, 2023
aafb23a
Refactor
JDBetteridge May 23, 2023
e133fb3
Fix issue in pyadjoint test
JDBetteridge May 23, 2023
5257537
Mixin from TSFC
JDBetteridge May 23, 2023
39f97fa
Differentiate wrt Constant
JDBetteridge May 31, 2023
d49b86c
Fix derivative for coordinates
JDBetteridge May 31, 2023
17a81db
Lint
JDBetteridge Jun 7, 2023
485af36
Suggestions from code review
JDBetteridge Jun 8, 2023
e314fe0
Connor's warning message
JDBetteridge Jun 14, 2023
e4f1077
Connor
JDBetteridge Jun 14, 2023
6f49a04
Connor
JDBetteridge Jun 14, 2023
4ea7d13
Connor
JDBetteridge Jun 14, 2023
c20ef2f
Connor
JDBetteridge Jun 14, 2023
417be55
Code review
JDBetteridge Jun 14, 2023
eda83c2
Update .github/workflows/build.yml
dham Jun 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ jobs:
--install gusto \
--install icepack \
--install irksome \
--install femlium || (cat firedrake-install.log && /bin/false)
--install femlium \
--package-branch tsfc JDBetteridge/constant_update \
--package-branch pyop2 JDBetteridge/global_literal \
--package-branch pyadjoint connorjward/constants-away || (cat firedrake-install.log && /bin/false)
dham marked this conversation as resolved.
Show resolved Hide resolved
- name: Install test dependencies
run: |
. ../firedrake_venv/bin/activate
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/11-extract-adjoint-solutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
"u.assign(u_)\n",
"\n",
"# Set diffusivity constant\n",
"nu = Constant(0.0001)\n",
"nu = Constant(0.0001, domain=mesh)\n",
"\n",
"# Define nonlinear form\n",
"dt = 1.0/n\n",
Expand Down Expand Up @@ -1118,7 +1118,7 @@
],
"source": [
"g = compute_gradient(J, Control(nu))\n",
"print(\"Gradient of J w.r.t. diffusivity = {:.4f}\".format(*g.values()))"
"print(\"Gradient of J w.r.t. diffusivity = {:.4f}\".format(*g.dat.data_ro))"
]
},
{
Expand Down
36 changes: 20 additions & 16 deletions firedrake/adjoint/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,18 @@ def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint

def _ad_mul(self, other):
return self._constant_from_values(self.values() * other)
return self._constant_from_values(self.dat.data_ro.reshape(-1) * other)

def _ad_add(self, other):
return self._constant_from_values(self.values() + other.values())
return self._constant_from_values(
self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
)

def _ad_dot(self, other, options=None):
if type(other) is AdjFloat:
return sum(self.values() * other)
return sum(self.dat.data_ro.reshape(-1) * other)
else:
return sum(self.values() * other.values())
return sum(self.dat.data_ro.reshape(-1) * other.dat.data_ro.reshape(-1))

@staticmethod
def _ad_assign_numpy(dst, src, offset):
Expand All @@ -100,37 +102,39 @@ def _ad_assign_numpy(dst, src, offset):

@staticmethod
def _ad_to_list(m):
return m.values().tolist()
return m.dat.data_ro.reshape(-1).tolist()

def _ad_copy(self):
return self._constant_from_values()

def _ad_dim(self):
return numpy.prod(self.values().shape)
return numpy.prod(self.dat.data_ro.cdim)

def _ad_imul(self, other):
self.assign(self._constant_from_values(self.values() * other))
self.assign(self._constant_from_values(self.dat.data_ro.reshape(-1) * other))

def _ad_iadd(self, other):
self.assign(self._constant_from_values(self.values() + other.values()))
self.assign(self._constant_from_values(
self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
))

def _reduce(self, r, r0):
npdata = self.values()
npdata = self.dat.data_ro.reshape(-1)
for i in range(len(npdata)):
r0 = r(npdata[i], r0)
return r0

def _applyUnary(self, f):
npdata = self.values()
npdata = self.dat.data_ro.reshape(-1)
npdatacopy = npdata.copy()
for i in range(len(npdata)):
npdatacopy[i] = f(npdata[i])
self.assign(self._constant_from_values(npdatacopy))

def _applyBinary(self, f, y):
npdata = self.values()
npdatacopy = self.values().copy()
npdatay = y.values()
npdata = self.dat.data_ro.reshape(-1)
npdatacopy = self.dat.data_ro.reshape(-1).copy()
npdatay = y.dat.data_ro.reshape(-1)
for i in range(len(npdata)):
npdatacopy[i] = f(npdata[i], npdatay[i])
self.assign(self._constant_from_values(npdatacopy))
Expand All @@ -139,17 +143,17 @@ def __deepcopy__(self, memodict={}):
return self._constant_from_values()

def _constant_from_values(self, values=None):
"""Returns a new Constant with self.values() while preserving self.ufl_shape.
"""Returns a new Constant with self.dat.data_ro.reshape(-1) while preserving self.ufl_shape.

If the optional argument `values` is provided, then `values` will be the values of the
new Constant instead, still preserving the ufl_shape of self.

Args:
values (numpy.array): An optional argument to use instead of self.values().
values (numpy.array): An optional argument to use instead of ``self.dat.data_ro.reshape(-1)``.

Returns:
Constant: The created Constant

"""
values = self.values() if values is None else values
values = self.dat.data_ro.reshape(-1) if values is None else values
return type(self)(numpy.reshape(values, self.ufl_shape), domain=extract_unique_domain(self))
4 changes: 3 additions & 1 deletion firedrake/adjoint/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def wrapper(self, other, *args, **kwargs):
"""To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the
Firedrake assign call."""
ad_block_tag = kwargs.pop("ad_block_tag", None)

# do not annotate in case of self assignment
annotate = annotate_tape(kwargs) and self != other

Expand Down Expand Up @@ -348,6 +347,9 @@ def _ad_iadd(self, other):
else:
vec += ovec

def _ad_function_space(self, mesh):
return self.ufl_function_space()

def _reduce(self, r, r0):
vec = self.vector().get_local()
for i in range(len(vec)):
Expand Down
6 changes: 3 additions & 3 deletions firedrake/adjoint/variational_solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from firedrake.constant import Constant
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint.blocks import NonlinearVariationalSolveBlock
Expand Down Expand Up @@ -105,6 +104,7 @@ def _ad_problem_clone(self, problem, dependencies):
expressions, we'll instead create clones of them.
"""
from firedrake import NonlinearVariationalProblem
from firedrake.function import Function
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
F_replace_map = {}
J_replace_map = {}

Expand All @@ -115,7 +115,7 @@ def _ad_problem_clone(self, problem, dependencies):
for block_variable in dependencies:
coeff = block_variable.output
if coeff in F_coefficients and coeff not in F_replace_map:
if isinstance(coeff, Constant):
if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real":
F_replace_map[coeff] = copy.deepcopy(coeff)
else:
F_replace_map[coeff] = coeff.copy(deepcopy=True)
Expand All @@ -124,7 +124,7 @@ def _ad_problem_clone(self, problem, dependencies):
if coeff in J_coefficients and coeff not in J_replace_map:
if coeff in F_replace_map:
J_replace_map[coeff] = F_replace_map[coeff]
elif isinstance(coeff, Constant):
elif isinstance(coeff, Function) and coeff.ufl_element().family() == "Real":
J_replace_map[coeff] = copy.deepcopy(coeff)
else:
J_replace_map[coeff] = coeff.copy()
Expand Down
27 changes: 27 additions & 0 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def __init__(self, form, local_knl, all_integer_subdomain_ids, diagonal=False, u
self._unroll = unroll

self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)

self._map_arg_cache = {}
# Cache for holding :class:`op2.MapKernelArg` instances.
Expand Down Expand Up @@ -984,6 +985,13 @@ def _as_global_kernel_arg_coefficient(_, self):
return self._make_dat_global_kernel_arg(finat_element, index)


@_as_global_kernel_arg.register(kernel_args.ConstantKernelArg)
def _as_global_kernel_arg_constant(_, self):
const = next(self._constants)
value_size = numpy.prod(const.ufl_shape, dtype=int)
return op2.GlobalKernelArg((value_size,))


@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg)
def _as_global_kernel_arg_cell_sizes(_, self):
# this mirrors tsfc.kernel_interface.firedrake_loopy.KernelBuilder.set_cell_sizes
Expand Down Expand Up @@ -1049,6 +1057,7 @@ def __init__(self, form, local_knl, global_knl, tensor,
self._lgmaps = lgmaps

self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)

def build(self):
"""Construct the parloop."""
Expand Down Expand Up @@ -1179,6 +1188,12 @@ def _as_parloop_arg_coefficient(arg, self):
return op2.DatParloopArg(coeff.dat, m)


@_as_parloop_arg.register(kernel_args.ConstantKernelArg)
def _as_parloop_arg_constant(arg, self):
const = next(self._constants)
return op2.GlobalParloopArg(const.dat)


@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg)
def _as_parloop_arg_cell_orientations(_, self):
func = self._mesh.cell_orientations()
Expand Down Expand Up @@ -1229,6 +1244,18 @@ def iter_active_coefficients(form, kinfo):
for subidx in subidxs:
yield form.coefficients()[idx].subfunctions[subidx]

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
# Is kinfo really needed?
from tsfc.ufl_utils import extract_firedrake_constants
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
else:
for const in extract_firedrake_constants(form):
yield const

@staticmethod
def index_function_spaces(form, indices):
"""Return the function spaces of the form's arguments, indexed
Expand Down
34 changes: 21 additions & 13 deletions firedrake/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from firedrake.vector import Vector


def _isconstant(expr):
return isinstance(expr, Constant) or \
(isinstance(expr, Function) and expr.ufl_element().family() == "Real")


def _isfunction(expr):
return isinstance(expr, Function) and expr.ufl_element().family() != "Real"


class CoefficientCollector(MultiFunction):
"""Multifunction used for converting an expression into a weighted sum of coefficients.

Expand Down Expand Up @@ -90,18 +99,21 @@ def component_tensor(self, o, a, _):
def coefficient(self, o):
return ((o, 1),)

def constant_value(self, o):
return ((o, 1),)

def expr(self, o, *operands):
raise NotImplementedError(f"Handler not defined for {type(o)}")

def _is_scalar_equiv(self, weighted_coefficients):
"""Return ``True`` if the sequence of ``(coefficient, weight)`` can be compressed to
a single scalar value.

This is only true when all coefficients are :class:`firedrake.Constant` and have
shape ``(1,)``.
This is only true when all coefficients are :class:`firedrake.Constant` or
are :class:`firedrake.Function` and ``c.ufl_element().family() == "Real"``
in both cases ``c.dat.dim`` must have shape ``(1,)``.
"""
return all(isinstance(c, Constant) and c.dat.dim == (1,)
for (c, _) in weighted_coefficients)
return all(_isconstant(c) and c.dat.dim == (1,) for (c, _) in weighted_coefficients)

def _as_scalar(self, weighted_coefficients):
"""Compress a sequence of ``(coefficient, weight)`` tuples to a single scalar value.
Expand Down Expand Up @@ -134,7 +146,7 @@ def __init__(self, assignee, expression, subset=None):
expression = as_ufl(expression)

for coeff in extract_coefficients(expression):
if isinstance(coeff, Function):
if isinstance(coeff, Function) and coeff.ufl_element().family() != "Real":
if coeff.ufl_element() != assignee.ufl_element():
raise ValueError("All functions in the expression must have the same "
"element as the assignee")
Expand Down Expand Up @@ -206,23 +218,19 @@ def assign(self):

@cached_property
def _constants(self):
return tuple(c for (c, _) in self._weighted_coefficients
if isinstance(c, Constant))
return tuple(c for (c, _) in self._weighted_coefficients if _isconstant(c))

@cached_property
def _constant_weights(self):
return tuple(w for (c, w) in self._weighted_coefficients
if isinstance(c, Constant))
return tuple(w for (c, w) in self._weighted_coefficients if _isconstant(c))

@cached_property
def _functions(self):
return tuple(c for (c, _) in self._weighted_coefficients
if isinstance(c, Function))
return tuple(c for (c, _) in self._weighted_coefficients if _isfunction(c))

@cached_property
def _function_weights(self):
return tuple(w for (c, w) in self._weighted_coefficients
if isinstance(c, Function))
return tuple(w for (c, w) in self._weighted_coefficients if _isfunction(c))

def _assign_single_dat(self, lhs_dat, indices, rvalue, assign_to_halos):
if assign_to_halos:
Expand Down
2 changes: 1 addition & 1 deletion firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal
@function_arg.setter
def function_arg(self, g):
'''Set the value of this boundary condition.'''
if isinstance(g, firedrake.Function):
if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real":
if g.function_space() != self.function_space():
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
Expand Down
Loading