diff --git a/AUTHORS b/AUTHORS index 8b0fd6507..3ad2dfc66 100644 --- a/AUTHORS +++ b/AUTHORS @@ -29,3 +29,4 @@ Contributors: | Jack S. Hale | Tuomas Airaksinen | Reuben W. Hill + | Nacime Bouziani diff --git a/test/test_duals.py b/test/test_duals.py index d29c62639..0f4f46ef2 100644 --- a/test/test_duals.py +++ b/test/test_duals.py @@ -4,7 +4,7 @@ from ufl import (FiniteElement, FunctionSpace, MixedFunctionSpace, Coefficient, Matrix, Cofunction, FormSum, Argument, Coargument, TestFunction, TrialFunction, Adjoint, Action, - action, adjoint, derivative, tetrahedron, triangle, interval, dx) + action, adjoint, derivative, inner, tetrahedron, triangle, interval, dx) from ufl.constantvalue import Zero from ufl.form import ZeroBaseForm @@ -102,8 +102,6 @@ def test_addition(): domain_2d = default_domain(triangle) f_2d = FiniteElement("CG", triangle, 1) V = FunctionSpace(domain_2d, f_2d) - f_2d_2 = FiniteElement("CG", triangle, 2) - V2 = FunctionSpace(domain_2d, f_2d_2) V_dual = V.dual() u = TrialFunction(V) @@ -137,11 +135,6 @@ def test_addition(): res -= ZeroBaseForm((v,)) assert res == L - with pytest.raises(ValueError): - # Raise error for incompatible arguments - v2 = TestFunction(V2) - res = L + ZeroBaseForm((v2, u)) - def test_scalar_mult(): domain_2d = default_domain(triangle) @@ -256,7 +249,7 @@ def test_differentiation(): w = Cofunction(U.dual()) dwdu = expand_derivatives(derivative(w, u)) assert isinstance(dwdu, ZeroBaseForm) - assert dwdu.arguments() == (Argument(u.ufl_function_space(), 0),) + assert dwdu.arguments() == (Argument(w.ufl_function_space().dual(), 0), Argument(u.ufl_function_space(), 1)) # Check compatibility with int/float assert dwdu == 0 @@ -285,24 +278,21 @@ def test_differentiation(): assert dMdu == 0 # -- Action -- # - Ac = Action(M, u) - dAcdu = expand_derivatives(derivative(Ac, u)) - - # Action(dM/du, u) + Action(M, du/du) = Action(M, uhat) since dM/du = 0. - # Multiply by 1 to get a FormSum (type compatibility). - assert dAcdu == 1 * Action(M, v) + Ac = Action(w, u) + dAcdu = derivative(Ac, u) + assert dAcdu == action(adjoint(derivative(w, u)), u) + action(w, derivative(u, u)) - # -- Adjoint -- # - Ad = Adjoint(M) - dAddu = expand_derivatives(derivative(Ad, u)) - # Push differentiation through Adjoint - assert dAddu == 0 + dAcdu = expand_derivatives(dAcdu) + # Since dw/du = 0 + assert dAcdu == 1 * Action(w, v) # -- Form sum -- # - Fs = M + Ac + uhat = Argument(U, 1) + what = Argument(U, 2) + Fs = M + inner(u * uhat, v) * dx dFsdu = expand_derivatives(derivative(Fs, u)) # Distribute differentiation over FormSum components - assert dFsdu == 1 * Action(M, v) + assert dFsdu == FormSum([inner(what * uhat, v) * dx, 1]) def test_zero_base_form_mult(): diff --git a/ufl/action.py b/ufl/action.py index e7ec6ad5e..d2dc2b03e 100644 --- a/ufl/action.py +++ b/ufl/action.py @@ -12,7 +12,8 @@ from ufl.form import BaseForm, FormSum, Form, ZeroBaseForm from ufl.core.ufl_type import ufl_type from ufl.algebra import Sum -from ufl.argument import Argument +from ufl.constantvalue import Zero +from ufl.argument import Argument, Coargument from ufl.coefficient import BaseCoefficient, Coefficient, Cofunction from ufl.differentiation import CoefficientDerivative from ufl.matrix import Matrix @@ -40,6 +41,8 @@ class Action(BaseForm): "ufl_operands", "_repr", "_arguments", + "_coefficients", + "_domains", "_hash") def __new__(cls, *args, **kw): @@ -47,12 +50,17 @@ def __new__(cls, *args, **kw): # Check trivial case if left == 0 or right == 0: - # Check compatibility of function spaces - _check_function_spaces(left, right) # Still need to work out the ZeroBaseForm arguments. - new_arguments = _get_action_form_arguments(left, right) + new_arguments, _ = _get_action_form_arguments(left, right) return ZeroBaseForm(new_arguments) + # Coarguments (resp. Argument) from V* to V* (resp. from V to V) are identity matrices, + # i.e. we have: V* x V -> R (resp. V x V* -> R). + if isinstance(left, (Coargument, Argument)): + return right + if isinstance(right, (Coargument, Argument)): + return left + if isinstance(left, (FormSum, Sum)): # Action distributes over sums on the LHS return FormSum(*[(Action(component, right), 1) @@ -70,6 +78,7 @@ def __init__(self, left, right): self._left = left self._right = right self.ufl_operands = (self._left, self._right) + self._domains = None # Check compatibility of function spaces _check_function_spaces(left, right) @@ -97,14 +106,22 @@ def _analyze_form_arguments(self): The highest number Argument of the left operand and the lowest number Argument of the right operand are consumed by the action. """ - self._arguments = _get_action_form_arguments(self._left, self._right) + self._arguments, self._coefficients = _get_action_form_arguments(self._left, self._right) + + def _analyze_domains(self): + """Analyze which domains can be found in Action.""" + from ufl.domain import join_domains + # Collect unique domains + self._domains = join_domains([e.ufl_domain() for e in self.ufl_operands]) def equals(self, other): if type(other) is not Action: return False if self is other: return True - return self._left == other._left and self._right == other._right + # Make sure we are returning a boolean as left and right equalities can be `ufl.Equation`s + # if the underlying objects are `ufl.BaseForm`. + return bool(self._left == other._left) and bool(self._right == other._right) def __str__(self): return f"Action({self._left}, {self._right})" @@ -127,29 +144,51 @@ def _check_function_spaces(left, right): # right as a consequence of Leibniz formula. right, *_ = right.ufl_operands + # `left` can also be a Coefficient in V (= V**), e.g. `action(Coefficient(V), Cofunction(V.dual()))`. + left_arg = left.arguments()[-1] if not isinstance(left, Coefficient) else left if isinstance(right, (Form, Action, Matrix, ZeroBaseForm)): - if left.arguments()[-1].ufl_function_space().dual() != right.arguments()[0].ufl_function_space(): + if left_arg.ufl_function_space().dual() != right.arguments()[0].ufl_function_space(): raise TypeError("Incompatible function spaces in Action") elif isinstance(right, (Coefficient, Cofunction, Argument)): - if left.arguments()[-1].ufl_function_space() != right.ufl_function_space(): + if left_arg.ufl_function_space() != right.ufl_function_space(): raise TypeError("Incompatible function spaces in Action") - else: + # `Zero` doesn't contain any information about the function space. + # -> Not a problem since Action will get simplified with a `ZeroBaseForm` + # which won't take into account the arguments on the right because of argument contraction. + # This occurs for: + # `derivative(Action(A, B), u)` with B is an `Expr` such that dB/du == 0 + # -> `derivative(B, u)` becomes `Zero` when expanding derivatives since B is an Expr. + elif not isinstance(right, Zero): raise TypeError("Incompatible argument in Action: %s" % type(right)) def _get_action_form_arguments(left, right): """Perform argument contraction to work out the arguments of Action""" - if isinstance(right, CoefficientDerivative): + coefficients = () + # `left` can also be a Coefficient in V (= V**), e.g. `action(Coefficient(V), Cofunction(V.dual()))`. + left_args = left.arguments()[:-1] if not isinstance(left, Coefficient) else () + if isinstance(right, BaseForm): + arguments = left_args + right.arguments()[1:] + coefficients += right.coefficients() + elif isinstance(right, CoefficientDerivative): # Action differentiation pushes differentiation through # right as a consequence of Leibniz formula. - right, *_ = right.ufl_operands - - if isinstance(right, BaseForm): - return left.arguments()[:-1] + right.arguments()[1:] - elif isinstance(right, BaseCoefficient): - return left.arguments()[:-1] + from ufl.algorithms.analysis import extract_arguments_and_coefficients + right_args, right_coeffs = extract_arguments_and_coefficients(right) + arguments = left_args + tuple(right_args) + coefficients += tuple(right_coeffs) + elif isinstance(right, (BaseCoefficient, Zero)): + arguments = left_args + # When right is ufl.Zero, Action gets simplified so updating + # coefficients here doesn't matter + coefficients += (right,) elif isinstance(right, Argument): - return left.arguments()[:-1] + (right,) + arguments = left_args + (right,) else: raise TypeError + + if isinstance(left, BaseForm): + coefficients += left.coefficients() + + return arguments, coefficients diff --git a/ufl/adjoint.py b/ufl/adjoint.py index 29dce30e8..cf76a75e9 100644 --- a/ufl/adjoint.py +++ b/ufl/adjoint.py @@ -10,6 +10,7 @@ # Modified by Nacime Bouziani, 2021-2022. from ufl.form import BaseForm, FormSum, ZeroBaseForm +from ufl.argument import Coargument from ufl.core.ufl_type import ufl_type # --- The Adjoint class represents the adjoint of a numerical object that # needs to be computed at assembly time --- @@ -28,6 +29,8 @@ class Adjoint(BaseForm): "_form", "_repr", "_arguments", + "_coefficients", + "_domains", "ufl_operands", "_hash") @@ -44,6 +47,14 @@ def __new__(cls, *args, **kw): # Adjoint distributes over sums return FormSum(*[(Adjoint(component), 1) for component in form.components()]) + elif isinstance(form, Coargument): + # The adjoint of a coargument `c: V* -> V*` is the identity matrix mapping from V to V (i.e. V x V* -> R). + # Equivalently, the adjoint of `c` is its first argument, which is a ufl.Argument defined on the + # primal space of `c`. + primal_arg, _ = form.arguments() + # Returning the primal argument avoids explicit argument reconstruction, making it + # a robust strategy for handling subclasses of `ufl.Coargument`. + return primal_arg return super(Adjoint, cls).__new__(cls) @@ -55,6 +66,7 @@ def __init__(self, form): self._form = form self.ufl_operands = (self._form,) + self._domains = None self._hash = None self._repr = "Adjoint(%s)" % repr(self._form) @@ -68,13 +80,22 @@ def form(self): def _analyze_form_arguments(self): """The arguments of adjoint are the reverse of the form arguments.""" self._arguments = self._form.arguments()[::-1] + self._coefficients = self._form.coefficients() + + def _analyze_domains(self): + """Analyze which domains can be found in Adjoint.""" + from ufl.domain import join_domains + # Collect unique domains + self._domains = join_domains([e.ufl_domain() for e in self.ufl_operands]) def equals(self, other): if type(other) is not Adjoint: return False if self is other: return True - return self._form == other._form + # Make sure we are returning a boolean as the equality can result in a `ufl.Equation` + # if the underlying objects are `ufl.BaseForm`. + return bool(self._form == other._form) def __str__(self): return f"Adjoint({self._form})" diff --git a/ufl/algorithms/__init__.py b/ufl/algorithms/__init__.py index 687e23db2..727973437 100644 --- a/ufl/algorithms/__init__.py +++ b/ufl/algorithms/__init__.py @@ -21,6 +21,7 @@ "estimate_total_polynomial_degree", "sort_elements", "compute_form_data", + "preprocess_form", "apply_transformer", "ReuseTransformer", "load_ufl_file", @@ -79,7 +80,7 @@ # Preprocessing a form to extract various meta data # from ufl.algorithms.formdata import FormData -from ufl.algorithms.compute_form_data import compute_form_data +from ufl.algorithms.compute_form_data import compute_form_data, preprocess_form # Utilities for checking properties of forms from ufl.algorithms.signature import compute_form_signature diff --git a/ufl/algorithms/ad.py b/ufl/algorithms/ad.py index 4755070ae..20e102f39 100644 --- a/ufl/algorithms/ad.py +++ b/ufl/algorithms/ad.py @@ -11,7 +11,6 @@ import warnings -from ufl.adjoint import Adjoint from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering from ufl.algorithms.apply_derivatives import apply_derivatives @@ -29,13 +28,6 @@ def expand_derivatives(form, **kwargs): if kwargs: warnings("Deprecation: expand_derivatives no longer takes any keyword arguments") - if isinstance(form, Adjoint): - dform = expand_derivatives(form._form) - if dform == 0: - return dform - # Adjoint is taken on a 3-form which can't happen - raise NotImplementedError('Adjoint derivative is not supported.') - # Lower abstractions for tensor-algebra types into index notation form = apply_algebra_lowering(form) diff --git a/ufl/algorithms/analysis.py b/ufl/algorithms/analysis.py index 6c2c2190d..b17ec4d2c 100644 --- a/ufl/algorithms/analysis.py +++ b/ufl/algorithms/analysis.py @@ -50,12 +50,16 @@ def extract_type(a, ufl_types): if not isinstance(ufl_types, (list, tuple)): ufl_types = (ufl_types,) - # BaseForms that aren't forms only have arguments + # BaseForms that aren't forms only contain arguments & coefficients if isinstance(a, BaseForm) and not isinstance(a, Form): - if any(issubclass(t, BaseArgument) for t in ufl_types): - return set(a.arguments()) - else: - return set() + objects = set() + arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument)) + if arg_types: + objects.update([e for e in a.arguments() if isinstance(e, arg_types)]) + coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient)) + if coeff_types: + objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)]) + return objects if all(issubclass(t, Terminal) for t in ufl_types): # Optimization diff --git a/ufl/algorithms/apply_derivatives.py b/ufl/algorithms/apply_derivatives.py index befb4067b..3fced9e3a 100644 --- a/ufl/algorithms/apply_derivatives.py +++ b/ufl/algorithms/apply_derivatives.py @@ -24,7 +24,7 @@ from ufl.core.terminal import Terminal from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction -from ufl.differentiation import CoordinateDerivative +from ufl.differentiation import CoordinateDerivative, BaseFormCoordinateDerivative from ufl.domain import extract_unique_domain from ufl.operators import (bessel_I, bessel_J, bessel_K, bessel_Y, cell_avg, conditional, cos, cosh, exp, facet_avg, ln, sign, @@ -1032,7 +1032,7 @@ def cofunction(self, o): dc = self.coefficient(o) if dc == 0: # Convert ufl.Zero into ZeroBaseForm - return ZeroBaseForm(self._v) + return ZeroBaseForm(o.arguments() + self._v) return dc def coargument(self, o): @@ -1102,6 +1102,14 @@ def coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd): rcache=self.rcaches[key]), o_[1], o_[2], o_[3]) + def base_form_coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd): + o_ = o.ufl_operands + key = (BaseFormCoordinateDerivative, o_[0]) + return BaseFormCoordinateDerivative(map_expr_dag(self, o_[0], + vcache=self.vcaches[key], + rcache=self.rcaches[key]), + o_[1], o_[2], o_[3]) + def indexed(self, o, Ap, ii): # TODO: (Partially) duplicated in generic rules # Reuse if untouched if Ap is o.ufl_operands[0]: diff --git a/ufl/algorithms/compute_form_data.py b/ufl/algorithms/compute_form_data.py index 348e27e17..afa092e07 100644 --- a/ufl/algorithms/compute_form_data.py +++ b/ufl/algorithms/compute_form_data.py @@ -206,31 +206,7 @@ def attach_estimated_degrees(form): return Form(new_integrals) -def compute_form_data(form, - # Default arguments configured to behave the way old FFC expects it: - do_apply_function_pullbacks=False, - do_apply_integral_scaling=False, - do_apply_geometry_lowering=False, - preserve_geometry_types=(), - do_apply_default_restrictions=True, - do_apply_restrictions=True, - do_estimate_degrees=True, - do_append_everywhere_integrals=True, - complex_mode=False, - ): - - # TODO: Move this to the constructor instead - self = FormData() - - # --- Store untouched form for reference. - # The user of FormData may get original arguments, - # original coefficients, and form signature from this object. - # But be aware that the set of original coefficients are not - # the same as the ones used in the final UFC form. - # See 'reduced_coefficients' below. - self.original_form = form - - # --- Pass form integrands through some symbolic manipulation +def preprocess_form(form, complex_mode): # Note: Default behaviour here will process form the way that is # currently expected by vanilla FFC @@ -258,6 +234,37 @@ def compute_form_data(form, # user-defined coefficient relations it just gets too messy form = apply_derivatives(form) + return form + + +def compute_form_data(form, + # Default arguments configured to behave the way old FFC expects it: + do_apply_function_pullbacks=False, + do_apply_integral_scaling=False, + do_apply_geometry_lowering=False, + preserve_geometry_types=(), + do_apply_default_restrictions=True, + do_apply_restrictions=True, + do_estimate_degrees=True, + do_append_everywhere_integrals=True, + complex_mode=False, + ): + + # TODO: Move this to the constructor instead + self = FormData() + + # --- Store untouched form for reference. + # The user of FormData may get original arguments, + # original coefficients, and form signature from this object. + # But be aware that the set of original coefficients are not + # the same as the ones used in the final UFC form. + # See 'reduced_coefficients' below. + self.original_form = form + + # --- Pass form integrands through some symbolic manipulation + + form = preprocess_form(form, complex_mode) + # --- Group form integrals # TODO: Refactor this, it's rather opaque what this does # TODO: Is self.original_form.ufl_domains() right here? diff --git a/ufl/algorithms/map_integrands.py b/ufl/algorithms/map_integrands.py index 568e65ddb..65ede8f0b 100644 --- a/ufl/algorithms/map_integrands.py +++ b/ufl/algorithms/map_integrands.py @@ -39,9 +39,14 @@ def map_integrands(function, form, only_integral_type=None): elif isinstance(form, FormSum): mapped_components = [map_integrands(function, component, only_integral_type) for component in form.components()] - nonzero_components = [(component, 1) for component in mapped_components + nonzero_components = [(component, w) for component, w in zip(mapped_components, form.weights()) # Catch ufl.Zero and ZeroBaseForm if component != 0] + if all(not isinstance(component, BaseForm) for component, _ in nonzero_components): + # Simplification of `BaseForm` objects may turn `FormSum` into a sum of `Expr` objects + # that are not `BaseForm`, i.e. into a `Sum` object. + # Example: `Action(Adjoint(c*), u)` with `c*` a `Coargument` and u a `Coefficient`. + return sum([component for component, _ in nonzero_components]) return FormSum(*nonzero_components) elif isinstance(form, Adjoint): # Zeros are caught inside `Adjoint.__new__` diff --git a/ufl/argument.py b/ufl/argument.py index 74358d503..30dd815b0 100644 --- a/ufl/argument.py +++ b/ufl/argument.py @@ -180,6 +180,7 @@ class Coargument(BaseForm, BaseArgument): "_ufl_function_space", "_ufl_shape", "_arguments", + "_coefficients", "ufl_operands", "_number", "_part", @@ -208,7 +209,13 @@ def __init__(self, function_space, number, part=None): def _analyze_form_arguments(self): "Analyze which Argument and Coefficient objects can be found in the form." # Define canonical numbering of arguments and coefficients - self._arguments = (Argument(self._ufl_function_space, 0),) + # Coarguments map from V* to V*, i.e. V* -> V*, or equivalently V* x V -> R. + # So they have one argument in the primal space and one in the dual space. + self._arguments = (Argument(self.ufl_function_space().dual(), 0), self) + self._coefficients = () + + def ufl_domain(self): + return BaseArgument.ufl_domain(self) def equals(self, other): if type(other) is not Coargument: diff --git a/ufl/coefficient.py b/ufl/coefficient.py index 704c3912d..5ea15ce09 100644 --- a/ufl/coefficient.py +++ b/ufl/coefficient.py @@ -14,6 +14,7 @@ from ufl.core.ufl_type import ufl_type from ufl.core.terminal import FormArgument +from ufl.argument import Argument from ufl.finiteelement import FiniteElementBase from ufl.domain import default_domain from ufl.functionspace import AbstractFunctionSpace, FunctionSpace, MixedFunctionSpace @@ -109,6 +110,7 @@ class Cofunction(BaseCoefficient, BaseForm): "_count", "_counted_class", "_arguments", + "_coefficients", "_ufl_function_space", "ufl_operands", "_repr", @@ -149,7 +151,9 @@ def __hash__(self): def _analyze_form_arguments(self): "Analyze which Argument and Coefficient objects can be found in the form." # Define canonical numbering of arguments and coefficients - self._arguments = () + # Cofunctions have one argument in primal space as they map from V to R. + self._arguments = (Argument(self._ufl_function_space.dual(), 0),) + self._coefficients = (self,) @ufl_type() diff --git a/ufl/differentiation.py b/ufl/differentiation.py index 75c4197db..d8828e37e 100644 --- a/ufl/differentiation.py +++ b/ufl/differentiation.py @@ -8,6 +8,7 @@ from ufl.checks import is_cellwise_constant from ufl.coefficient import Coefficient +from ufl.argument import Argument, Coargument from ufl.constantvalue import Zero from ufl.core.expr import Expr from ufl.core.operator import Operator @@ -91,8 +92,39 @@ def __init__(self, base_form, coefficients, arguments, def _analyze_form_arguments(self): """Collect the arguments of the corresponding BaseForm""" - base_form = self.ufl_operands[0] - self._arguments = base_form.arguments() + from ufl.algorithms.analysis import extract_type, extract_coefficients + base_form, _, arguments, _ = self.ufl_operands + + def arg_type(x): + if isinstance(x, BaseForm): + return Coargument + return Argument + # Each derivative arguments can either be a: + # - `ufl.BaseForm`: if it contains a `ufl.Coargument` + # - or a `ufl.Expr`: if it contains a `ufl.Argument` + # When a `Coargument` is encountered, it is treated as an argument (i.e. as V* -> V* and not V* x V -> R) + # and should result in one single argument (in the dual space). + base_form_args = base_form.arguments() + tuple(arg for a in arguments.ufl_operands + for arg in extract_type(a, arg_type(a))) + # BaseFormDerivative's arguments don't necessarily contain BaseArgument objects only + # -> e.g. `derivative(u ** 2, u, u)` with `u` a Coefficient. + base_form_coeffs = base_form.coefficients() + tuple(arg for a in arguments.ufl_operands + for arg in extract_coefficients(a)) + # Reconstruct arguments for correct numbering + self._arguments = tuple(type(arg)(arg.ufl_function_space(), arg.number(), arg.part()) for arg in base_form_args) + self._coefficients = base_form_coeffs + + +@ufl_type(num_ops=4, inherit_shape_from_operand=0, + inherit_indices_from_operand=0) +class BaseFormCoordinateDerivative(BaseFormDerivative, CoordinateDerivative): + """Derivative of a base form w.r.t. the SpatialCoordinates.""" + _ufl_noslots_ = True + + def __init__(self, base_form, coefficients, arguments, + coefficient_derivatives): + BaseFormDerivative.__init__(self, base_form, coefficients, arguments, + coefficient_derivatives) @ufl_type(num_ops=2) diff --git a/ufl/form.py b/ufl/form.py index d4370b65a..601e4d13a 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -86,11 +86,12 @@ class BaseForm(object, metaclass=UFLType): # classes __slots__ = () _ufl_is_abstract_ = True - _ufl_required_methods_ = ('_analyze_form_arguments', "ufl_domains") + _ufl_required_methods_ = ('_analyze_form_arguments', '_analyze_domains', "ufl_domains") def __init__(self): - # Internal variables for caching form argument data + # Internal variables for caching form argument/coefficient data self._arguments = None + self._coefficients = None # --- Accessor interface --- def arguments(self): @@ -99,6 +100,24 @@ def arguments(self): self._analyze_form_arguments() return self._arguments + def coefficients(self): + "Return all ``Coefficient`` objects found in form." + if self._coefficients is None: + self._analyze_form_arguments() + return self._coefficients + + def ufl_domain(self): + """Return the single geometric integration domain occuring in the + base form. Fails if multiple domains are found. + """ + if self._domains is None: + self._analyze_domains() + + if len(self._domains) > 1: + raise ValueError("%s must have exactly one domain." % type(self).__name__) + # Return the single geometric domain + return self._domains[0] + # --- Operator implementations --- def __eq__(self, other): @@ -123,7 +142,6 @@ def __add__(self, other): return self elif isinstance(other, ZeroBaseForm): - self._check_arguments_sum(other) # Simplify addition with ZeroBaseForm return self @@ -131,7 +149,6 @@ def __add__(self, other): # We could overwrite ZeroBaseForm.__add__ but that implies # duplicating cases with `0` and `ufl.Zero`. elif isinstance(self, ZeroBaseForm): - self._check_arguments_sum(other) # Simplify addition with ZeroBaseForm return other @@ -143,18 +160,6 @@ def __add__(self, other): # Let python protocols do their job if we don't handle it return NotImplemented - def _check_arguments_sum(self, other): - # Get component with the highest number of arguments - a = max((self, other), key=lambda x: len(x.arguments())) - b = self if a is other else other - # Components don't necessarily have the exact same arguments - # but the first argument(s) need to match as for `a + L` - # where a and L are a bilinear and linear form respectively. - a_args = sorted(a.arguments(), key=lambda x: x.number()) - b_args = sorted(b.arguments(), key=lambda x: x.number()) - if b_args != a_args[:len(b_args)]: - raise ValueError('Mismatching arguments when summing:\n %s\n and\n %s' % (self, other)) - def __sub__(self, other): "Subtract other form from this one." return self + (-other) @@ -506,7 +511,6 @@ def __add__(self, other): return Form(list(chain(self.integrals(), other.integrals()))) if isinstance(other, ZeroBaseForm): - self._check_arguments_sum(other) # Simplify addition with ZeroBaseForm return self @@ -730,6 +734,7 @@ class FormSum(BaseForm): arg_weights is a list of tuples of component index and weight""" __slots__ = ("_arguments", + "_coefficients", "_weights", "_components", "ufl_operands", @@ -738,20 +743,37 @@ class FormSum(BaseForm): "_hash") _ufl_required_methods_ = ('_analyze_form_arguments') + def __new__(cls, *args, **kwargs): + # All the components are `ZeroBaseForm` + if all(component == 0 for component, _ in args): + # Assume that the arguments of all the components have consistent with each other and select + # the first one to define the arguments of `ZeroBaseForm`. + # This might not always be true but `ZeroBaseForm`'s arguments are not checked anywhere + # because we can't reliably always infer them. + ((arg, _), *_) = args + arguments = arg.arguments() + return ZeroBaseForm(arguments) + + return super(FormSum, cls).__new__(cls) + def __init__(self, *components): BaseForm.__init__(self) + # Remove `ZeroBaseForm` components + filtered_components = [(component, w) for component, w in components if component != 0] + weights = [] full_components = [] - for (component, w) in components: + for (component, w) in filtered_components: if isinstance(component, FormSum): full_components.extend(component.components()) - weights.extend(w * component.weights()) + weights.extend([w * wc for wc in component.weights()]) else: full_components.append(component) weights.append(w) self._arguments = None + self._coefficients = None self._domains = None self._domain_numbering = None self._hash = None @@ -788,9 +810,18 @@ def _sum_variational_components(self): def _analyze_form_arguments(self): "Return all ``Argument`` objects found in form." arguments = [] + coefficients = [] for component in self._components: arguments.extend(component.arguments()) + coefficients.extend(component.coefficients()) self._arguments = tuple(set(arguments)) + self._coefficients = tuple(set(coefficients)) + + def _analyze_domains(self): + """Analyze which domains can be found in FormSum.""" + from ufl.domain import join_domains + # Collect unique domains + self._domains = join_domains([component.ufl_domain() for component in self._components]) def __hash__(self): "Hash code for use in dicts (includes incidental numbering of indices etc.)" @@ -849,7 +880,8 @@ def __init__(self, arguments): self.form = None def _analyze_form_arguments(self): - return self._arguments + # `self._arguments` is already set in `BaseForm.__init__` + self._coefficients = () def __ne__(self, other): # Overwrite BaseForm.__neq__ which relies on `equals` diff --git a/ufl/formoperators.py b/ufl/formoperators.py index 355bc9713..40caeaf68 100644 --- a/ufl/formoperators.py +++ b/ufl/formoperators.py @@ -11,7 +11,7 @@ # Modified by Massimiliano Leoni, 2016 # Modified by Cecile Daversin-Catty, 2018 -from ufl.form import Form, FormSum, BaseForm, as_form +from ufl.form import Form, FormSum, BaseForm, ZeroBaseForm, as_form from ufl.core.expr import Expr, ufl_err_str from ufl.split_functions import split from ufl.exprcontainers import ExprList, ExprMapping @@ -21,7 +21,8 @@ from ufl.coefficient import Coefficient, Cofunction from ufl.adjoint import Adjoint from ufl.action import Action -from ufl.differentiation import CoefficientDerivative, BaseFormDerivative, CoordinateDerivative +from ufl.differentiation import (CoefficientDerivative, BaseFormDerivative, + CoordinateDerivative, BaseFormCoordinateDerivative) from ufl.constantvalue import is_true_ufl_scalar, as_ufl from ufl.indexed import Indexed from ufl.core.multiindex import FixedIndex, MultiIndex @@ -285,15 +286,23 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None): return FormSum(*[(derivative(component, coefficient, argument, coefficient_derivatives), 1) for component in form.components()]) elif isinstance(form, Adjoint): - # Push derivative through Adjoint - return adjoint(derivative(form._form, coefficient, argument, coefficient_derivatives)) + # Is `derivative(Adjoint(A), ...)` with A a 2-form even legal ? + # -> If yes, what's the right thing to do here ? + raise NotImplementedError('Adjoint derivative is not supported.') elif isinstance(form, Action): # Push derivative through Action slots left, right = form.ufl_operands - dleft = derivative(left, coefficient, argument, coefficient_derivatives) - dright = derivative(right, coefficient, argument, coefficient_derivatives) - # Leibniz formula - return action(dleft, right) + action(left, dright) + # Eagerly simplify spatial derivatives when Action results in a scalar. + if not len(form.arguments()) and isinstance(coefficient, SpatialCoordinate): + return ZeroBaseForm(()) + + if len(left.arguments()) == 1: + dleft = derivative(left, coefficient, argument, coefficient_derivatives) + dright = derivative(right, coefficient, argument, coefficient_derivatives) + # Leibniz formula + return action(adjoint(dleft), right) + action(left, dright) + else: + raise NotImplementedError('Action derivative not supported when the left argument is not a 1-form.') coefficients, arguments = _handle_derivative_arguments(form, coefficient, argument) @@ -309,18 +318,24 @@ def derivative(form, coefficient, argument=None, coefficient_derivatives=None): if isinstance(form, Form): integrals = [] for itg in form.integrals(): - if not isinstance(coefficient, SpatialCoordinate): - fd = CoefficientDerivative(itg.integrand(), coefficients, - arguments, coefficient_derivatives) - else: + if isinstance(coefficient, SpatialCoordinate): fd = CoordinateDerivative(itg.integrand(), coefficients, arguments, coefficient_derivatives) + elif isinstance(coefficient, BaseForm): + # Make the `ZeroBaseForm` arguments + arguments = form.arguments() + coefficient.arguments() + return ZeroBaseForm(arguments) + else: + fd = CoefficientDerivative(itg.integrand(), coefficients, + arguments, coefficient_derivatives) integrals.append(itg.reconstruct(fd)) return Form(integrals) elif isinstance(form, BaseForm): if not isinstance(coefficient, SpatialCoordinate): return BaseFormDerivative(form, coefficients, arguments, coefficient_derivatives) + else: + return BaseFormCoordinateDerivative(form, coefficients, arguments, coefficient_derivatives) elif isinstance(form, Expr): # What we got was in fact an integrand diff --git a/ufl/matrix.py b/ufl/matrix.py index 0b120f414..23cefcc6f 100644 --- a/ufl/matrix.py +++ b/ufl/matrix.py @@ -30,7 +30,9 @@ class Matrix(BaseForm, Counted): "_repr", "_hash", "_ufl_shape", - "_arguments") + "_arguments", + "_coefficients", + "_domains") def __getnewargs__(self): return (self._ufl_function_spaces[0], self._ufl_function_spaces[1], @@ -49,6 +51,7 @@ def __init__(self, row_space, column_space, count=None): self._ufl_function_spaces = (row_space, column_space) self.ufl_operands = () + self._domains = None self._hash = None self._repr = f"Matrix({self._ufl_function_spaces[0]!r}, {self._ufl_function_spaces[1]!r}, {self._count!r})" @@ -60,6 +63,13 @@ def _analyze_form_arguments(self): "Define arguments of a matrix when considered as a form." self._arguments = (Argument(self._ufl_function_spaces[0], 0), Argument(self._ufl_function_spaces[1], 1)) + self._coefficients = () + + def _analyze_domains(self): + """Analyze which domains can be found in a Matrix.""" + from ufl.domain import join_domains + # Collect unique domains + self._domains = join_domains([fs.ufl_domain() for fs in self._ufl_function_spaces]) def __str__(self): count = str(self._count)