Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed May 17, 2023
1 parent b2ff4b1 commit d83293a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 30 deletions.
1 change: 1 addition & 0 deletions firedrake/adjoint/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def wrapper(*args, **kwargs):
even when the user calls the lower-level :py:data:`solve(A, x, b)`.
"""
ad_block_tag = kwargs.pop("ad_block_tag", None)
# ~ breakpoint()
annotate = annotate_tape(kwargs)
with stop_annotating():
output = assemble(*args, **kwargs)
Expand Down
36 changes: 20 additions & 16 deletions firedrake/adjoint/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,18 @@ def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint

def _ad_mul(self, other):
return self._constant_from_values(self.values() * other)
return self._constant_from_values(self.dat.data_ro.reshape(-1) * other)

def _ad_add(self, other):
return self._constant_from_values(self.values() + other.values())
return self._constant_from_values(
self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
)

def _ad_dot(self, other, options=None):
if type(other) is AdjFloat:
return sum(self.values() * other)
return sum(self.dat.data_ro.reshape(-1) * other)
else:
return sum(self.values() * other.values())
return sum(self.dat.data_ro.reshape(-1) * other.dat.data_ro.reshape(-1))

@staticmethod
def _ad_assign_numpy(dst, src, offset):
Expand All @@ -100,37 +102,39 @@ def _ad_assign_numpy(dst, src, offset):

@staticmethod
def _ad_to_list(m):
return m.values().tolist()
return m.dat.data_ro.reshape(-1).tolist()

def _ad_copy(self):
return self._constant_from_values()

def _ad_dim(self):
return numpy.prod(self.values().shape)
return numpy.prod(self.dat.data_ro.cdim)

def _ad_imul(self, other):
self.assign(self._constant_from_values(self.values() * other))
self.assign(self._constant_from_values(self.dat.data_ro.reshape(-1) * other))

def _ad_iadd(self, other):
self.assign(self._constant_from_values(self.values() + other.values()))
self.assign(self._constant_from_values(
self.dat.data_ro.reshape(-1) + other.dat.data_ro.reshape(-1)
))

def _reduce(self, r, r0):
npdata = self.values()
npdata = self.dat.data_ro.reshape(-1)
for i in range(len(npdata)):
r0 = r(npdata[i], r0)
return r0

def _applyUnary(self, f):
npdata = self.values()
npdata = self.dat.data_ro.reshape(-1)
npdatacopy = npdata.copy()
for i in range(len(npdata)):
npdatacopy[i] = f(npdata[i])
self.assign(self._constant_from_values(npdatacopy))

def _applyBinary(self, f, y):
npdata = self.values()
npdatacopy = self.values().copy()
npdatay = y.values()
npdata = self.dat.data_ro.reshape(-1)
npdatacopy = self.dat.data_ro.reshape(-1).copy()
npdatay = y.dat.data_ro.reshape(-1)
for i in range(len(npdata)):
npdatacopy[i] = f(npdata[i], npdatay[i])
self.assign(self._constant_from_values(npdatacopy))
Expand All @@ -139,17 +143,17 @@ def __deepcopy__(self, memodict={}):
return self._constant_from_values()

def _constant_from_values(self, values=None):
"""Returns a new Constant with self.values() while preserving self.ufl_shape.
"""Returns a new Constant with self.dat.data_ro.reshape(-1) while preserving self.ufl_shape.
If the optional argument `values` is provided, then `values` will be the values of the
new Constant instead, still preserving the ufl_shape of self.
Args:
values (numpy.array): An optional argument to use instead of self.values().
values (numpy.array): An optional argument to use instead of ``self.dat.data_ro.reshape(-1)``.
Returns:
Constant: The created Constant
"""
values = self.values() if values is None else values
values = self.dat.data_ro.reshape(-1) if values is None else values
return type(self)(numpy.reshape(values, self.ufl_shape), domain=extract_unique_domain(self))
1 change: 0 additions & 1 deletion firedrake/adjoint/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def wrapper(self, other, *args, **kwargs):
"""To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the
Firedrake assign call."""
ad_block_tag = kwargs.pop("ad_block_tag", None)

# do not annotate in case of self assignment
annotate = annotate_tape(kwargs) and self != other

Expand Down
29 changes: 17 additions & 12 deletions firedrake/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from firedrake.vector import Vector


def _isconstant(expr):
return isinstance(expr, Constant) or \
(isinstance(expr, Function) and expr.ufl_element().family() == "Real")


def _isfunction(expr):
return isinstance(expr, Function) and expr.ufl_element().family() != "Real"


class CoefficientCollector(MultiFunction):
"""Multifunction used for converting an expression into a weighted sum of coefficients.
Expand Down Expand Up @@ -101,11 +110,11 @@ def _is_scalar_equiv(self, weighted_coefficients):
"""Return ``True`` if the sequence of ``(coefficient, weight)`` can be compressed to
a single scalar value.
This is only true when all coefficients are :class:`firedrake.Constant` and have
shape ``(1,)``.
This is only true when all coefficients are :class:`firedrake.Constant` or
are :class:`firedrake.Function` and ``c.ufl_element().family() == "Real"``
in both cases ``c.dat.dim`` must have shape ``(1,)``.
"""
return all(isinstance(c, Constant) and c.dat.dim == (1,)
for (c, _) in weighted_coefficients)
return all(_isconstant(c) and c.dat.dim == (1,) for (c, _) in weighted_coefficients)

def _as_scalar(self, weighted_coefficients):
"""Compress a sequence of ``(coefficient, weight)`` tuples to a single scalar value.
Expand Down Expand Up @@ -210,23 +219,19 @@ def assign(self):

@cached_property
def _constants(self):
return tuple(c for (c, _) in self._weighted_coefficients
if isinstance(c, Constant))
return tuple(c for (c, _) in self._weighted_coefficients if _isconstant(c))

@cached_property
def _constant_weights(self):
return tuple(w for (c, w) in self._weighted_coefficients
if isinstance(c, Constant))
return tuple(w for (c, w) in self._weighted_coefficients if _isconstant(c))

@cached_property
def _functions(self):
return tuple(c for (c, _) in self._weighted_coefficients
if isinstance(c, Function))
return tuple(c for (c, _) in self._weighted_coefficients if _isfunction(c))

@cached_property
def _function_weights(self):
return tuple(w for (c, w) in self._weighted_coefficients
if isinstance(c, Function))
return tuple(w for (c, w) in self._weighted_coefficients if _isfunction(c))

def _assign_single_dat(self, lhs_dat, indices, rvalue, assign_to_halos):
if assign_to_halos:
Expand Down
3 changes: 2 additions & 1 deletion firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __new__(cls, value, domain=None):
element = ufl.TensorElement("R", cell, 0, shape=shape)

R = FunctionSpace(domain, element, name="firedrake.Constant")
return Function(R, val=dat)
# Explicit assign ensures correct taping for adjoint
return Function(R).assign(value)
else:
return object.__new__(cls)

Expand Down

0 comments on commit d83293a

Please sign in to comment.