Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nbouziani committed Sep 12, 2023
1 parent 9e7e10e commit 3d00bf4
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 106 deletions.
11 changes: 3 additions & 8 deletions test/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_differentiation(V1, V2):
F = Iu * v * dx
Ihat = TrialFunction(Iu.ufl_function_space())
dFdu = expand_derivatives(derivative(F, u, uhat))
# Compute dFdu = \partial F/\partial u + Action(dFdIu, dIu/du)
# Compute dFdu = ∂F/∂u + Action(dFdIu, dIu/du)
# = Action(dFdIu, Iu(uhat, v*))
dFdIu = expand_derivatives(derivative(F, Iu, Ihat))
assert dFdIu == Ihat * v * dx
Expand All @@ -113,8 +113,8 @@ def test_differentiation(V1, V2):
# -- Differentiate: u * I(u, V2) * v * dx -- #
F = u * Iu * v * dx
dFdu = expand_derivatives(derivative(F, u, uhat))
# Compute dFdu = \partial F/\partial u + Action(dFdIu, dIu/du)
# = \partial F/\partial u + Action(dFdIu, Iu(uhat, v*))
# Compute dFdu = ∂F/∂u + Action(dFdIu, dIu/du)
# = ∂F/∂u + Action(dFdIu, Iu(uhat, v*))
dFdu_partial = uhat * Iu * v * dx
dFdIu = Ihat * u * v * dx
assert dFdu == dFdu_partial + Action(dFdIu, dIu)
Expand Down Expand Up @@ -143,27 +143,22 @@ def test_extract_base_form_operators(V1, V2):
# -- Interp(u, V2) -- #
Iu = Interp(u, V2)
assert extract_arguments(Iu) == [vstar]
# assert extract_arguments_and_coefficients(Iu) == ([vstar], [u, Iu.result_coefficient()])
assert extract_arguments_and_coefficients(Iu) == ([vstar], [u])

F = Iu * dx
# Form composition: Iu * dx <=> Action(v * dx, Iu(u; v*))
assert extract_arguments(F) == []
assert extract_arguments_and_coefficients(F) == ([], [u])
# assert extract_arguments_and_coefficients(F) == ([], [u, Iu.result_coefficient()])

for e in [Iu, F]:
# assert extract_coefficients(e) == [u, Iu.result_coefficient()]
assert extract_coefficients(e) == [u]
assert extract_base_form_operators(e) == [Iu]

# -- Interp(u, V2) -- #
Iv = Interp(uhat, V2)
assert extract_arguments(Iv) == [vstar, uhat]
assert extract_arguments_and_coefficients(Iv) == ([vstar, uhat], [])
# assert extract_arguments_and_coefficients(Iv) == ([vstar, uhat], [Iv.result_coefficient()])
assert extract_coefficients(Iv) == []
# assert extract_coefficients(Iv) == [Iv.result_coefficient()]
assert extract_base_form_operators(Iv) == [Iv]

# -- Action(v * v2 * dx, Iv) -- #
Expand Down
23 changes: 12 additions & 11 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def extract_type(a, ufl_types):
if not isinstance(ufl_types, (list, tuple)):
ufl_types = (ufl_types,)

if all(t is not BaseFormOperator for t in ufl_types):
remove_base_form_ops = True
ufl_types += (BaseFormOperator,)
else:
remove_base_form_ops = False

# BaseForms that aren't forms or base form operators
# only contain arguments & coefficients
if isinstance(a, BaseForm) and not isinstance(a, (Form, BaseFormOperator)):
Expand All @@ -74,15 +80,7 @@ def extract_type(a, ufl_types):
if any(isinstance(o, t) for t in ufl_types))

# Need to extract objects contained in base form operators whose type is in ufl_types
if all(t is not BaseFormOperator for t in ufl_types):
# For the case where there are no base form operators in `a`, this effectively doubles the time
# since we traverse the DAG twice.
# -> A solution would be to have a mechanism to collect these operators as we traverse
# the DAG for the first time and use that information.
base_form_ops = extract_type(a, BaseFormOperator)
else:
base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator))

base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator))
ufl_types_no_args = tuple(t for t in ufl_types if not issubclass(t, BaseArgument))
base_form_objects = ()
for o in base_form_ops:
Expand All @@ -92,17 +90,20 @@ def extract_type(a, ufl_types):
for ai in tuple(arg for arg in o.argument_slots(isinstance(a, Form))):
# Extracting BaseArguments of an object of which a Coargument is an argument,
# then we just return the dual argument of the Coargument and not its primal argument.
# TODO: There might be a cleaner way to handle that case.
if isinstance(ai, Coargument):
ufl_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types)
base_form_objects += tuple(extract_type(ai, ufl_types))
# Look for BaseArguments in BaseFormOperator's argument slots only since that's where they are by definition.
# Don't look into operands, which is convenient for external operator composition, e.g. N1(N2; v*)
# where N2 is seen as an operator and not a form.
slots = o.ufl_operands # + (o.result_coefficient(),)
slots = o.ufl_operands
for ai in slots:
base_form_objects += tuple(extract_type(ai, ufl_types_no_args))
objects.update(base_form_objects)

# `Remove BaseFormOperator` objects if there were initially not in `ufl_types`
if remove_base_form_ops:
objects -= base_form_ops
return objects


Expand Down
7 changes: 3 additions & 4 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def apply_derivatives(expression):
# - a BaseFormOperator → Return `d(expression)/dw` where `w` is the coefficient produced by the bfo `var`.
# - else → Record the bfo on the MultiFunction object and returns 0.
# Example:
# → If derivative(F(u, N(u); v), u) was taken the following line would compute `\frac{\partial F}{\partial u}`.
# → If derivative(F(u, N(u); v), u) was taken the following line would compute `∂F/∂u`.
dexpression_dvar = map_integrand_dags(rules, expression)

# Get the recorded delayed operations
Expand All @@ -1262,12 +1262,11 @@ def apply_derivatives(expression):
var, der_kwargs, *base_form_ops = pending_operations
for N in sorted(set(base_form_ops), key=lambda x: x.count()):
# Replace dexpr/dvar by dexpr/dN. We don't use `apply_derivatives` since
# the differentiation is done via `\partial` and not `d`.
# the differentiation is done via `` and not `d`.
dexpr_dN = map_integrand_dags(rules, replace_derivative_nodes(expression, {var.ufl_operands[0]: N}))
# Add the BaseFormOperatorDerivative node
# TODO: Should we use `derivative` here to take into account Extop/Interp node?
dN_dvar = apply_derivatives(BaseFormOperatorDerivative(N, var, **der_kwargs))
# Sum the Action: dF/du = \partial F/\partial u + \sum_{i=1,...} Action(dF/dNi, dNi/du)
# Sum the Action: dF/du = ∂F/∂u + \sum_{i=1,...} Action(dF/dNi, dNi/du)
if not (isinstance(dexpr_dN, Form) and len(dexpr_dN.integrals()) == 0):
# In this case: Action <=> ufl.action since `dN_var` has 2 arguments.
# We use Action to handle the trivial case dN_dvar = 0.
Expand Down
67 changes: 16 additions & 51 deletions ufl/core/base_form_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,25 @@
#
# Modified by Nacime Bouziani, 2021-2022

from ufl.coefficient import Coefficient
from collections import OrderedDict

from ufl.argument import Argument, Coargument
from ufl.core.operator import Operator
from ufl.form import BaseForm
from ufl.core.ufl_type import ufl_type
from ufl.constantvalue import as_ufl
from ufl.finiteelement import FiniteElementBase
from ufl.domain import default_domain
from ufl.functionspace import AbstractFunctionSpace, FunctionSpace
from ufl.referencevalue import ReferenceValue
from ufl.functionspace import AbstractFunctionSpace
from ufl.utils.counted import Counted


@ufl_type(num_ops="varying", is_differential=True)
class BaseFormOperator(Operator, BaseForm):
class BaseFormOperator(Operator, BaseForm, Counted):

# Slots are disabled here because they cause trouble in PyDOLFIN
# multiple inheritance pattern:
_ufl_noslots_ = True

def __init__(self, *operands, function_space, derivatives=None, result_coefficient=None, argument_slots=()):
def __init__(self, *operands, function_space, derivatives=None, argument_slots=()):
r"""
:param operands: operands on which acts the operator.
:param function_space: the :class:`.FunctionSpace`,
Expand All @@ -43,15 +42,10 @@ def __init__(self, *operands, function_space, derivatives=None, result_coefficie
ufl_operands = tuple(map(as_ufl, operands))
argument_slots = tuple(map(as_ufl, argument_slots))
Operator.__init__(self, ufl_operands)
Counted.__init__(self, counted_class=BaseFormOperator)

# -- Function space -- #
if isinstance(function_space, FiniteElementBase):
# For legacy support for .ufl files using cells, we map
# the cell to The Default Mesh
element = function_space
domain = default_domain(element.cell())
function_space = FunctionSpace(domain, element)
elif not isinstance(function_space, AbstractFunctionSpace):
if not isinstance(function_space, AbstractFunctionSpace):
raise ValueError("Expecting a FunctionSpace or FiniteElement.")

# -- Derivatives -- #
Expand All @@ -60,13 +54,6 @@ def __init__(self, *operands, function_space, derivatives=None, result_coefficie
# argument slots (e.g. Interp)
self.derivatives = derivatives

# Produce the resulting Coefficient: Is that really needed?
if result_coefficient is None:
result_coefficient = Coefficient(function_space)
elif not isinstance(result_coefficient, (Coefficient, ReferenceValue)):
raise TypeError('Expecting a Coefficient and not %s', type(result_coefficient))
self._result_coefficient = result_coefficient

# -- Argument slots -- #
if len(argument_slots) == 0:
# Make v*
Expand All @@ -81,27 +68,20 @@ def __init__(self, *operands, function_space, derivatives=None, result_coefficie
ufl_free_indices = ()
ufl_index_dimensions = ()

def result_coefficient(self, unpack_reference=True):
"Returns the coefficient produced by the base form operator"
result_coefficient = self._result_coefficient
if unpack_reference and isinstance(result_coefficient, ReferenceValue):
return result_coefficient.ufl_operands[0]
return result_coefficient

def argument_slots(self, outer_form=False):
r"""Returns a tuple of expressions containing argument and coefficient based expressions.
We get an argument uhat when we take the Gateaux derivative in the direction uhat:
-> d/du N(u; v*) = dNdu(u; uhat, v*) where uhat is a ufl.Argument and v* a ufl.Coargument
Applying the action replace the last argument by coefficient:
-> action(dNdu(u; uhat, v*), w) = dNdu(u; w, v*) where du is a ufl.Coefficient
"""
from ufl.algorithms.analysis import extract_arguments
if not outer_form:
return self._argument_slots
# Takes into account argument contraction when a base form operator is in an outer form:
# For example:
# F = N(u; v*) * v * dx can be seen as Action(v1 * v * dx, N(u; v*))
# => F.arguments() should return (v,)!
from ufl.algorithms.analysis import extract_arguments
return tuple(a for a in self._argument_slots[1:] if len(extract_arguments(a)) != 0)

def coefficients(self):
Expand All @@ -123,7 +103,6 @@ def _analyze_form_arguments(self):
+ tuple(a for arg in arguments for a in extract_arguments(arg)))
coefficients = tuple(c for op in self.ufl_operands for c in extract_coefficients(op))
# Define canonical numbering of arguments and coefficients
from collections import OrderedDict
# 1) Need concept of order since we may have arguments with the same number
# because of form composition (`argument_slots(outer_form=True)`):
# Example: Let u \in V1 and N \in V2 and F = N(u; v*) * dx, then
Expand All @@ -135,36 +114,22 @@ def _analyze_form_arguments(self):
self._coefficients = tuple(sorted(set(coefficients), key=lambda x: x.count()))

def count(self):
"Returns the count associated to the coefficient produced by the base form operator"
return self._count

@property
def _count(self):
return self.result_coefficient()._count
"Returns the count associated to the base form operator"
return self.count

@property
def ufl_shape(self):
"Returns the UFL shape of the coefficient.produced by the operator"
return self.result_coefficient()._ufl_shape
return self.arguments()[0]._ufl_shape

def ufl_function_space(self):
"Returns the ufl function space associated to the operator"
return self.result_coefficient()._ufl_function_space
"Returns the function space associated to the operator, i.e. the dual of the base form operator's `Coargument`"
return self.arguments()[0]._ufl_function_space.dual()

def _ufl_expr_reconstruct_(self, *operands, function_space=None, derivatives=None,
result_coefficient=None, argument_slots=None):
def _ufl_expr_reconstruct_(self, *operands, function_space=None, derivatives=None, argument_slots=None):
"Return a new object of the same type with new operands."
deriv_multiindex = derivatives or self.derivatives

if deriv_multiindex != self.derivatives:
# If we are constructing a derivative
corresponding_coefficient = None
else:
corresponding_coefficient = result_coefficient or self._result_coefficient

return type(self)(*operands, function_space=function_space or self.ufl_function_space(),
derivatives=deriv_multiindex,
result_coefficient=corresponding_coefficient,
derivatives=derivatives or self.derivatives,
argument_slots=argument_slots or self.argument_slots())

def __repr__(self):
Expand Down
24 changes: 5 additions & 19 deletions ufl/core/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

from ufl.core.ufl_type import ufl_type
from ufl.constantvalue import as_ufl
from ufl.finiteelement import FiniteElementBase
from ufl.functionspace import AbstractFunctionSpace, FunctionSpace
from ufl.functionspace import AbstractFunctionSpace
from ufl.argument import Coargument, Argument
from ufl.coefficient import Cofunction
from ufl.form import Form
Expand All @@ -27,24 +26,18 @@ class Interp(BaseFormOperator):
# multiple inheritance pattern:
_ufl_noslots_ = True

def __init__(self, expr, v, result_coefficient=None):
def __init__(self, expr, v):
r""" Symbolic representation of the interpolation operator.
:arg expr: a UFL expression to interpolate.
:arg v: the :class:`.FunctionSpace` to interpolate into or the :class:`.Coargument`
defined on the dual of the :class:`.FunctionSpace` to interpolate into.
:param result_coefficient: the :class:`.Coefficient` representing what is produced by the operator
"""

# This check could be more rigorous.
dual_args = (Coargument, Cofunction, Form)

if isinstance(v, FiniteElementBase):
element = v
domain = element.cell()
function_space = FunctionSpace(domain, element)
v = Argument(function_space.dual(), 0)
elif isinstance(v, AbstractFunctionSpace):
if isinstance(v, AbstractFunctionSpace):
if is_dual(v):
raise ValueError('Expecting a primal function space.')
v = Argument(v.dual(), 0)
Expand All @@ -63,17 +56,12 @@ def __init__(self, expr, v, result_coefficient=None):
# Set the operand as `expr` for DAG traversal purpose.
operand = expr
BaseFormOperator.__init__(self, operand, function_space=function_space,
result_coefficient=result_coefficient,
argument_slots=argument_slots)

def _ufl_expr_reconstruct_(self, expr, v=None, result_coefficient=None, **add_kwargs):
def _ufl_expr_reconstruct_(self, expr, v=None, **add_kwargs):
"Return a new object of the same type with new operands."
v = v or self.argument_slots()[0]
# This should check if we need a new coefficient, i.e. if we need
# to pass `self._result_coefficient` when `result_coefficient` is None.
# -> `result_coefficient` is deprecated so it shouldn't be a problem!
result_coefficient = result_coefficient or self._result_coefficient
return type(self)(expr, v, result_coefficient=result_coefficient, **add_kwargs)
return type(self)(expr, v, **add_kwargs)

def __repr__(self):
"Default repr string construction for Interp."
Expand All @@ -88,8 +76,6 @@ def __str__(self):
return s

def __eq__(self, other):
if type(other) is not Interp:
return False
if self is other:
return True
return (type(self) is type(other) and
Expand Down
18 changes: 5 additions & 13 deletions ufl/formoperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,11 @@ def adjoint(form, reordered_arguments=None, derivatives_expanded=None):
# Allow BaseForm objects that are not BaseForm such as Adjoint since there are cases
# where we need to expand derivatives: e.g. to get the number of arguments
# => For example: Adjoint(Action(2-form, derivative(u,u)))
try:
if not derivatives_expanded:
# For external operators differentiation may turn a Form into a FormSum
form = expand_derivatives(form)
if isinstance(form, Form):
return compute_form_adjoint(form, reordered_arguments)
except NotImplementedError:
# Catch cases where expand derivatives is not implemented
# e.g. `adjoint(Adjoint(M))` where M is a ufl.Matrix,
# expand_derivatives(M) will be M (no derivatives taken)
# and expand derivatives of Adjoint only works if when we push it through the Adjoint
# we get 0.
pass
if not derivatives_expanded:
# For external operators differentiation may turn a Form into a FormSum
form = expand_derivatives(form)
if isinstance(form, Form):
return compute_form_adjoint(form, reordered_arguments)
return Adjoint(form)


Expand Down

0 comments on commit 3d00bf4

Please sign in to comment.