From 51e2bb1cc3d4c6a8b10bbc81ad1d2e15ee331255 Mon Sep 17 00:00:00 2001 From: nbouziani Date: Sun, 20 Aug 2023 01:44:55 +0100 Subject: [PATCH] Fix tests --- test/test_duals.py | 4 ++-- ufl/algorithms/analysis.py | 19 +++++++++++++------ ufl/argument.py | 17 ++++++++++++++--- ufl/core/base_form_operator.py | 12 +++++++++--- ufl/form.py | 6 ------ 5 files changed, 38 insertions(+), 20 deletions(-) diff --git a/test/test_duals.py b/test/test_duals.py index 64b84f430..be3238a04 100644 --- a/test/test_duals.py +++ b/test/test_duals.py @@ -280,11 +280,11 @@ def test_differentiation(): # -- Action -- # Ac = Action(w, u) dAcdu = derivative(Ac, u) - assert dAcdu == Action(Adjoint(derivative(w, u)), u) + Action(w, derivative(u, u)) + assert dAcdu == action(adjoint(derivative(w, u)), u) + action(w, derivative(u, u)) dAcdu = expand_derivatives(dAcdu) # Since dw/du = 0 - assert dAcdu == 1 * Action(w, v) + assert dAcdu == Action(w, v) # -- Form sum -- # uhat = Argument(U, 1) diff --git a/ufl/algorithms/analysis.py b/ufl/algorithms/analysis.py index e6072910a..ec64b3f89 100644 --- a/ufl/algorithms/analysis.py +++ b/ufl/algorithms/analysis.py @@ -16,7 +16,7 @@ from ufl.core.terminal import Terminal from ufl.core.base_form_operator import BaseFormOperator -from ufl.argument import BaseArgument +from ufl.argument import BaseArgument, Coargument from ufl.coefficient import BaseCoefficient from ufl.constant import Constant from ufl.form import BaseForm, Form @@ -63,10 +63,12 @@ def extract_type(a, ufl_types): # only contain arguments & coefficients if isinstance(a, BaseForm) and not isinstance(a, (Form, BaseFormOperator)): objects = set() - if any(issubclass(t, BaseArgument) for t in ufl_types): - objects.update(a.arguments()) - if any(issubclass(t, BaseCoefficient) for t in ufl_types): - objects.update(a.coefficients()) + arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument)) + if arg_types: + objects.update([e for e in a.arguments() if isinstance(e, arg_types)]) + coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient)) + if coeff_types: + objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)]) return objects if all(issubclass(t, Terminal) for t in ufl_types): @@ -94,8 +96,13 @@ def extract_type(a, ufl_types): for o in base_form_ops: # This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator # N(u; v*) * v * dx <=> action(v1 * v * dx, N(...; v*)) - # where v, v1 are Arguments and v* a Coargument. + # where v, v1 are `Argument`s and v* a `Coargument`. for ai in tuple(arg for arg in o.argument_slots(isinstance(a, Form))): + # Extracting BaseArguments of an object of which a Coargument is an argument, + # then we just return the dual argument of the Coargument and not its primal argument. + # TODO: There might be a cleaner way to handle that case. + if isinstance(ai, Coargument): + ufl_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types) base_form_objects += tuple(extract_type(ai, ufl_types)) # Look for BaseArguments in BaseFormOperator's argument slots only since that's where they are by definition. # Don't look into operands, which is convenient for external operator composition, e.g. N1(N2; v*) diff --git a/ufl/argument.py b/ufl/argument.py index aaac97d56..696a5e30d 100644 --- a/ufl/argument.py +++ b/ufl/argument.py @@ -205,13 +205,24 @@ def __init__(self, function_space, number, part=None): self._repr = "Coargument(%s, %s, %s)" % ( repr(self._ufl_function_space), repr(self._number), repr(self._part)) - def _analyze_form_arguments(self): + def arguments(self, outer_form=None): + "Return all ``Argument`` objects found in form." + if self._arguments is None: + self._analyze_form_arguments(outer_form=outer_form) + return self._arguments + + def _analyze_form_arguments(self, outer_form=None): "Analyze which Argument and Coefficient objects can be found in the form." # Define canonical numbering of arguments and coefficients + self._coefficients = () # Coarguments map from V* to V*, i.e. V* -> V*, or equivalently V* x V -> R. # So they have one argument in the primal space and one in the dual space. - self._arguments = (Argument(self.ufl_function_space().dual(), 0), self) - self._coefficients = () + # However, when they are composed with linear forms with dual arguments, such as BaseFormOperators, + # the primal argument is discarded when analysing the argument as Coarguments. + if not outer_form: + self._arguments = (Argument(self.ufl_function_space().dual(), 0), self) + else: + self._arguments = (self,) def ufl_domain(self): return BaseArgument.ufl_domain(self) diff --git a/ufl/core/base_form_operator.py b/ufl/core/base_form_operator.py index 3ec18ed28..bdb689231 100644 --- a/ufl/core/base_form_operator.py +++ b/ufl/core/base_form_operator.py @@ -11,7 +11,7 @@ # Modified by Nacime Bouziani, 2021-2022 from ufl.coefficient import Coefficient -from ufl.argument import Argument +from ufl.argument import Argument, Coargument from ufl.core.operator import Operator from ufl.form import BaseForm from ufl.core.ufl_type import ufl_type @@ -112,8 +112,14 @@ def coefficients(self): def _analyze_form_arguments(self): "Analyze which Argument and Coefficient objects can be found in the base form." - from ufl.algorithms.analysis import extract_arguments, extract_coefficients - arguments = tuple(a for arg in self.argument_slots() for a in extract_arguments(arg)) + from ufl.algorithms.analysis import extract_arguments, extract_coefficients, extract_type + dual_arg, *arguments = self.argument_slots() + # When coarguments are treated as BaseForms, they have two arguments (one primal and one dual) + # as they map from V* to V* => V* x V -> R. However, when they are treated as mere "arguments", + # the primal space argument is discarded and we only have the dual space argument (Coargument). + # This is the exact same situation than BaseFormOperator's arguments which are different depending on + # whether the BaseFormOperator is used in an outer form or not. + arguments = tuple(extract_type(dual_arg, Coargument)) + tuple(a for arg in arguments for a in extract_arguments(arg)) coefficients = tuple(c for op in self.ufl_operands for c in extract_coefficients(op)) # Define canonical numbering of arguments and coefficients from collections import OrderedDict diff --git a/ufl/form.py b/ufl/form.py index e1b673bc0..b4d86252e 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -399,12 +399,6 @@ def max_subdomain_ids(self): self._analyze_subdomain_data() return self._max_subdomain_ids - def arguments(self): - "Return all ``Argument`` objects found in form." - if self._arguments is None: - self._analyze_form_arguments() - return self._arguments - def coefficients(self): "Return all ``Coefficient`` objects found in form." if self._coefficients is None: