Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nbouziani committed Aug 20, 2023
1 parent 23211ce commit 51e2bb1
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 20 deletions.
4 changes: 2 additions & 2 deletions test/test_duals.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ def test_differentiation():
# -- Action -- #
Ac = Action(w, u)
dAcdu = derivative(Ac, u)
assert dAcdu == Action(Adjoint(derivative(w, u)), u) + Action(w, derivative(u, u))
assert dAcdu == action(adjoint(derivative(w, u)), u) + action(w, derivative(u, u))

dAcdu = expand_derivatives(dAcdu)
# Since dw/du = 0
assert dAcdu == 1 * Action(w, v)
assert dAcdu == Action(w, v)

# -- Form sum -- #
uhat = Argument(U, 1)
Expand Down
19 changes: 13 additions & 6 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ufl.core.terminal import Terminal
from ufl.core.base_form_operator import BaseFormOperator
from ufl.argument import BaseArgument
from ufl.argument import BaseArgument, Coargument
from ufl.coefficient import BaseCoefficient
from ufl.constant import Constant
from ufl.form import BaseForm, Form
Expand Down Expand Up @@ -63,10 +63,12 @@ def extract_type(a, ufl_types):
# only contain arguments & coefficients
if isinstance(a, BaseForm) and not isinstance(a, (Form, BaseFormOperator)):
objects = set()
if any(issubclass(t, BaseArgument) for t in ufl_types):
objects.update(a.arguments())
if any(issubclass(t, BaseCoefficient) for t in ufl_types):
objects.update(a.coefficients())
arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument))
if arg_types:
objects.update([e for e in a.arguments() if isinstance(e, arg_types)])
coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient))
if coeff_types:
objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)])
return objects

if all(issubclass(t, Terminal) for t in ufl_types):
Expand Down Expand Up @@ -94,8 +96,13 @@ def extract_type(a, ufl_types):
for o in base_form_ops:
# This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator
# N(u; v*) * v * dx <=> action(v1 * v * dx, N(...; v*))
# where v, v1 are Arguments and v* a Coargument.
# where v, v1 are `Argument`s and v* a `Coargument`.
for ai in tuple(arg for arg in o.argument_slots(isinstance(a, Form))):
# Extracting BaseArguments of an object of which a Coargument is an argument,
# then we just return the dual argument of the Coargument and not its primal argument.
# TODO: There might be a cleaner way to handle that case.
if isinstance(ai, Coargument):
ufl_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types)
base_form_objects += tuple(extract_type(ai, ufl_types))
# Look for BaseArguments in BaseFormOperator's argument slots only since that's where they are by definition.
# Don't look into operands, which is convenient for external operator composition, e.g. N1(N2; v*)
Expand Down
17 changes: 14 additions & 3 deletions ufl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,24 @@ def __init__(self, function_space, number, part=None):
self._repr = "Coargument(%s, %s, %s)" % (
repr(self._ufl_function_space), repr(self._number), repr(self._part))

def _analyze_form_arguments(self):
def arguments(self, outer_form=None):
"Return all ``Argument`` objects found in form."
if self._arguments is None:
self._analyze_form_arguments(outer_form=outer_form)
return self._arguments

def _analyze_form_arguments(self, outer_form=None):
"Analyze which Argument and Coefficient objects can be found in the form."
# Define canonical numbering of arguments and coefficients
self._coefficients = ()
# Coarguments map from V* to V*, i.e. V* -> V*, or equivalently V* x V -> R.
# So they have one argument in the primal space and one in the dual space.
self._arguments = (Argument(self.ufl_function_space().dual(), 0), self)
self._coefficients = ()
# However, when they are composed with linear forms with dual arguments, such as BaseFormOperators,
# the primal argument is discarded when analysing the argument as Coarguments.
if not outer_form:
self._arguments = (Argument(self.ufl_function_space().dual(), 0), self)
else:
self._arguments = (self,)

def ufl_domain(self):
return BaseArgument.ufl_domain(self)
Expand Down
12 changes: 9 additions & 3 deletions ufl/core/base_form_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# Modified by Nacime Bouziani, 2021-2022

from ufl.coefficient import Coefficient
from ufl.argument import Argument
from ufl.argument import Argument, Coargument
from ufl.core.operator import Operator
from ufl.form import BaseForm
from ufl.core.ufl_type import ufl_type
Expand Down Expand Up @@ -112,8 +112,14 @@ def coefficients(self):

def _analyze_form_arguments(self):
"Analyze which Argument and Coefficient objects can be found in the base form."
from ufl.algorithms.analysis import extract_arguments, extract_coefficients
arguments = tuple(a for arg in self.argument_slots() for a in extract_arguments(arg))
from ufl.algorithms.analysis import extract_arguments, extract_coefficients, extract_type
dual_arg, *arguments = self.argument_slots()
# When coarguments are treated as BaseForms, they have two arguments (one primal and one dual)
# as they map from V* to V* => V* x V -> R. However, when they are treated as mere "arguments",
# the primal space argument is discarded and we only have the dual space argument (Coargument).
# This is the exact same situation than BaseFormOperator's arguments which are different depending on
# whether the BaseFormOperator is used in an outer form or not.
arguments = tuple(extract_type(dual_arg, Coargument)) + tuple(a for arg in arguments for a in extract_arguments(arg))
coefficients = tuple(c for op in self.ufl_operands for c in extract_coefficients(op))
# Define canonical numbering of arguments and coefficients
from collections import OrderedDict
Expand Down
6 changes: 0 additions & 6 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,6 @@ def max_subdomain_ids(self):
self._analyze_subdomain_data()
return self._max_subdomain_ids

def arguments(self):
"Return all ``Argument`` objects found in form."
if self._arguments is None:
self._analyze_form_arguments()
return self._arguments

def coefficients(self):
"Return all ``Coefficient`` objects found in form."
if self._coefficients is None:
Expand Down

0 comments on commit 51e2bb1

Please sign in to comment.