Skip to content

Commit

Permalink
JDBetteridge/pyop3 sprint more (#3527)
Browse files Browse the repository at this point in the history
* WIP

* All tests in test_constant passing

* Apply suggestions from code review

Co-authored-by: Jack Betteridge <[email protected]>

* Matnest broke whilst I fixed constant

* Mark tests Connor doesn't want to run

* Skip some more tests

* Fix test_zero_forms, except parallel

* Re-enable tests in test_multiple_domains

* Enable some tests, disable others

---------

Co-authored-by: Connor Ward <[email protected]>
  • Loading branch information
JDBetteridge and connorjward authored Apr 26, 2024
1 parent 25dd5ad commit 4d0b2d1
Show file tree
Hide file tree
Showing 25 changed files with 95 additions and 66 deletions.
2 changes: 1 addition & 1 deletion firedrake/adjoint_utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _ad_copy(self):
return self._constant_from_values()

def _ad_dim(self):
return self.dat.cdim
return self.dat.data_ro.size

def _ad_imul(self, other):
self.assign(self._constant_from_values(self.dat.data_ro.reshape(-1) * other))
Expand Down
4 changes: 2 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,8 +1323,8 @@ def _get_mat_type(mat_type, sub_mat_type, arguments):
if sub_mat_type is None:
sub_mat_type = parameters.parameters["default_sub_matrix_type"]

if has_real_subspace and mat_type != "nest":
raise ValueError
if has_real_subspace and mat_type not in ["nest", "matfree"]:
raise ValueError("Matrices containing real space arguments must have type 'nest' or 'matfree'")
if sub_mat_type not in {"aij", "baij"}:
raise ValueError(
f"Invalid submatrix type, '{sub_mat_type}' (not 'aij' or 'baij')"
Expand Down
6 changes: 4 additions & 2 deletions firedrake/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import finat.ufl
import numpy as np
import pyop3 as op3
from pyop3.exceptions import DataValueError
import pytools
from pyadjoint.tape import annotate_tape
from pyop2.utils import cached_property
Expand Down Expand Up @@ -256,8 +257,9 @@ def _assign_single_dat(self, lhs, subset, rvalue, assign_to_halos):
if isinstance(rvalue, numbers.Number) or rvalue.shape in {(1,), assignee.shape}:
assignee[...] = rvalue
else:
cdim = self._assignee.function_space()._cdim
assert rvalue.shape == (cdim,)
cdim = self._assignee.function_space().value_size
if rvalue.shape != (cdim,):
raise DataValueError("Assignee and assignment values are different shapes")
assignee.reshape((-1, cdim))[...] = rvalue

def _compute_rvalue(self, func_data):
Expand Down
18 changes: 7 additions & 11 deletions firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import finat.ufl

from tsfc.ufl_utils import TSFCConstantMixin
from pyop2.exceptions import DataTypeError, DataValueError
import pyop3 as op3
from pyop3.exceptions import DataValueError
from firedrake.petsc import PETSc
from firedrake.utils import ScalarType
from ufl.classes import all_ufl_classes, ufl_classes, terminal_classes
Expand All @@ -29,15 +29,12 @@ def _create_const(value, comm):
shape = data.shape
rank = len(shape)

if comm is not None:
raise NotImplementedError("Won't be a back door for real space here, do elsewhere")

if rank == 0:
axes = op3.AxisTree(op3.Axis(1))
axes = op3.AxisTree()
else:
axes = op3.AxisTree(op3.Axis(shape[0]))
for size in shape[1:]:
axes = axes.add_axis(op3.Axis(size), *axes.leaf)
axes = op3.AxisTree(op3.Axis({"XXX": shape[0]}, label="dim0"))
for i, s in enumerate(shape[1:]):
axes = axes.add_axis(op3.Axis({"XXX": s}, label=f"dim{i+1}"), *axes.leaf)
dat = op3.HierarchicalArray(axes, data=data.flatten())
return dat, rank, shape

Expand Down Expand Up @@ -198,11 +195,10 @@ def assign(self, value):
self
"""
if self.ufl_shape() and np.array(value).shape != self.ufl_shape():
raise DataValueError("Cannot assign to constant, value has incorrect shape")
self.dat.data_wo[...] = value
return self
# TODO pyop3
# except (DataTypeError, DataValueError) as e:
# raise ValueError(e)

def __iadd__(self, o):
raise NotImplementedError("Augmented assignment to Constant not implemented")
Expand Down
4 changes: 4 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,10 @@ def _local_ises(self):
@utils.cached_property
def local_section(self):
section = PETSc.Section().create(comm=self.comm)
if self._ufl_function_space.ufl_element().family() == "Real":
# If real we don't need to populate the section
return section

points = self._mesh.points
section.setChart(0, points.size)
perm = PETSc.IS().createGeneral(points.numbering.data_ro, comm=self.comm)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"RED", "GREEN", "BLUE")


packages = ("pyop2", "tsfc", "firedrake", "UFL")
packages = ("pyop2", "pyop3", "tsfc", "firedrake", "UFL")


logger = logging.getLogger("firedrake")
Expand Down
53 changes: 28 additions & 25 deletions firedrake/parloops.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def _(
):
plex = V.mesh().topology

if V.ufl_element().family() == "Real":
return array

if integral_type == "cell":
# TODO ideally the FIAT permutation would not need to be known
# about by the mesh topology and instead be handled here. This
Expand Down Expand Up @@ -518,30 +521,24 @@ def _(
def _cell_integral_pack_indices(V: WithGeometry, cell: op3.LoopIndex) -> op3.IndexTree:
plex = V.mesh().topology

if V.ufl_element().family() == "Real":
indices = op3.IndexTree(op3.Slice("dof", [op3.AffineSliceComponent("XXX")]))
else:
indices = op3.IndexTree.from_nest({
plex._fiat_closure(cell): [
op3.Slice("dof", [op3.AffineSliceComponent("XXX")])
for _ in range(plex.dimension+1)
]
})
indices = op3.IndexTree.from_nest({
plex._fiat_closure(cell): [
op3.Slice("dof", [op3.AffineSliceComponent("XXX")])
for _ in range(plex.dimension+1)
]
})
return _with_shape_indices(V, indices)


def _facet_integral_pack_indices(V: WithGeometry, facet: op3.LoopIndex) -> op3.IndexTree:
plex = V.ufl_domain().topology

if V.ufl_element().family() == "Real":
indices = op3.IndexTree(op3.ScalarIndex(plex.name, "XXX", 0))
else:
indices = op3.IndexTree.from_nest({
plex._fiat_closure(plex.support(facet)): [
op3.Slice("dof", [op3.AffineSliceComponent("XXX")])
for _ in range(plex.dimension+1)
]
})
indices = op3.IndexTree.from_nest({
plex._fiat_closure(plex.support(facet)): [
op3.Slice("dof", [op3.AffineSliceComponent("XXX")])
for _ in range(plex.dimension+1)
]
})
# don't add support as an extra axis here, done already
return _with_shape_indices(V, indices, and_support=False)

Expand Down Expand Up @@ -624,13 +621,19 @@ def _with_shape_axes(V, axes, target_paths, index_exprs, integral_type):
trees_ = []
for space, tree in zip(spaces, trees):
if space.shape:
for leaf in tree.leaves:
for i, dim in enumerate(space.shape):
label = f"dim{i}"
subaxis = op3.Axis({"XXX": dim}, label)
tree = tree.add_axis(subaxis, *leaf)
new_target_paths[subaxis.id, "XXX"] = pmap({label: "XXX"})
new_index_exprs[subaxis.id, "XXX"] = pmap({label: op3.AxisVariable(label)})
for parent, component in tree.leaves:
axis_list = [
op3.Axis({"XXX": dim}, f"dim{ii}")
for ii, dim in enumerate(space.shape)
]
tree = tree.add_subtree(
op3.AxisTree.from_iterable(axis_list),
parent=parent,
component=component
)
for axis in axis_list:
new_target_paths[axis.id, "XXX"] = pmap({axis.label: "XXX"})
new_index_exprs[axis.id, "XXX"] = pmap({axis.label: op3.AxisVariable(axis.label)})

trees_.append(tree)
trees = tuple(trees_)
Expand Down
4 changes: 4 additions & 0 deletions tests/regression/test_appctx_cleanup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy
from firedrake import *
import pytest


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


class NonePC(PCBase):
Expand Down
4 changes: 2 additions & 2 deletions tests/regression/test_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_init_bcs_illegal(mesh, v):
DirichletBC(FunctionSpace(mesh, "CG", 1), v, 0)


@pytest.mark.skip(reason="pyop3 TODO")
@pytest.mark.parametrize('measure', [dx, ds])
def test_assemble_bcs_wrong_fs(V, measure):
"Assemble a Matrix with a DirichletBC on an incompatible FunctionSpace."
Expand Down Expand Up @@ -241,7 +240,6 @@ def test_preassembly_doesnt_modify_assembled_rhs(V, f):
assert np.allclose(b_vals, b.dat.data_ro)


@pytest.mark.skip(reason="pyop3 TODO")
def test_preassembly_bcs_caching(V):
bc1 = DirichletBC(V, 0, 1)
bc2 = DirichletBC(V, 1, 2)
Expand All @@ -268,6 +266,7 @@ def test_preassembly_bcs_caching(V):
assert not any(Aneither.M.values.diagonal() == 0)


@pytest.mark.skip(reason="pyop3 TODO")
def test_assemble_mass_bcs_2d(V):
if V.value_size > 1:
pytest.skip(reason="pyop3 TODO")
Expand Down Expand Up @@ -295,6 +294,7 @@ def test_assemble_mass_bcs_2d(V):
assert assemble(inner((w - f), (w - f))*dx) < 1e-12


@pytest.mark.skip(reason="pyop3 TODO")
@pytest.mark.parametrize("quad",
[False, True],
ids=["triangle", "quad"])
Expand Down
16 changes: 8 additions & 8 deletions tests/regression/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import pytest


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


def test_scalar_constant():
for m in [UnitIntervalMesh(5), UnitSquareMesh(2, 2), UnitCubeMesh(2, 2, 2)]:
c = Constant(1, domain=m)
Expand Down Expand Up @@ -122,8 +119,8 @@ def test_constant_vector_assign_works():

f.assign(c)

assert np.allclose(f.dat.data_ro[:, 0], 10)
assert np.allclose(f.dat.data_ro[:, 1], 11)
assert np.allclose(f.sub(0).dat.data_ro, 10)
assert np.allclose(f.sub(1).dat.data_ro, 11)


def test_constant_vector_assign_to_scalar_error():
Expand Down Expand Up @@ -162,9 +159,11 @@ def test_constant_assign_to_mixed():
f.sub(0).assign(c)
f.sub(1).assign(c)

for d in f.dat.data_ro:
assert np.allclose(d[:, 0], 10)
assert np.allclose(d[:, 1], 11)

assert np.allclose(f.sub(0).sub(0).dat.data_ro, 10)
assert np.allclose(f.sub(0).sub(1).dat.data_ro, 11)
assert np.allclose(f.sub(1).sub(0).dat.data_ro, 10)
assert np.allclose(f.sub(0).sub(1).dat.data_ro, 11)


def test_constant_multiplies_function():
Expand Down Expand Up @@ -207,6 +206,7 @@ def test_constant_names_are_not_used_in_generated_code():


@pytest.mark.skipcomplex
@pytest.mark.xfail(reason="requires matnest")
def test_correct_constants_are_used_in_split_form():
# see https://github.com/firedrakeproject/firedrake/issues/3091
mesh = UnitSquareMesh(3, 3)
Expand Down
4 changes: 4 additions & 0 deletions tests/regression/test_custom_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from firedrake import *
from firedrake.utils import ScalarType
import numpy as np
import pytest


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


def test_callbacks():
Expand Down
2 changes: 2 additions & 0 deletions tests/regression/test_facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,15 @@ def test_internal_integral_unit_tet():
assert abs(assemble(u('+') * dS)) < 1.0e-14


@pytest.mark.xfail(reason="pyop3 TODO")
def test_facet_map_no_reshape():
m = UnitSquareMesh(1, 1)
V = FunctionSpace(m, "DG", 0)
efnm = V.exterior_facet_node_map()
assert efnm.values_with_halo.shape == (4, 1)


@pytest.mark.skip(reason="pyop3 TODO")
def test_mesh_with_no_facet_markers():
mesh = UnitTriangleMesh()
mesh.init()
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/test_interior_facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
PETSc.Sys.popErrorHandler()


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


def run_test():
# mesh = UnitSquareMesh(10, 10)
mesh = UnitSquareMesh(2, 2)
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/test_interp_dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import ufl


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


@pytest.fixture(scope='module')
def mesh():
return UnitSquareMesh(5, 5)
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
cwd = abspath(dirname(__file__))


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


def test_constant():
cg1 = FunctionSpace(UnitSquareMesh(5, 5), "CG", 1)
f = assemble(interpolate(Constant(1.0), cg1))
Expand Down
1 change: 0 additions & 1 deletion tests/regression/test_multiple_domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_mismatching_meshes_indexed_function(mesh1, mesh3):
assemble(inner(d1, TestFunction(V2))*dx(domain=mesh1))


@pytest.mark.skip(reason="pyop3 TODO")
def test_mismatching_meshes_constant(mesh1, mesh3):
V2 = FunctionSpace(mesh3, "CG", 1)

Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_point_eval_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
cwd = abspath(dirname(__file__))


pytest.mark.skip(allow_module_level=True, reason="pyop3 point location")
pytest.skip(allow_module_level=True, reason="pyop3 point location")


def test_1d_args():
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_point_eval_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
cwd = abspath(dirname(__file__))


pytest.mark.skip(allow_module_level=True, reason="pyop3 point location")
pytest.skip(allow_module_level=True, reason="pyop3 point location")


@pytest.fixture(params=[False, True])
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_point_eval_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
cwd = abspath(dirname(__file__))


pytest.mark.skip(allow_module_level=True, reason="pyop3 point location")
pytest.skip(allow_module_level=True, reason="pyop3 point location")


@pytest.fixture
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/test_real_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from firedrake.__future__ import *


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


@pytest.mark.skipcomplex
def test_real_assembly():
mesh = UnitIntervalMesh(3)
Expand Down
4 changes: 4 additions & 0 deletions tests/regression/test_serendipity_biharmonic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from firedrake import *
import numpy
import pytest


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


def test_serendipity_biharmonic():
Expand Down
3 changes: 3 additions & 0 deletions tests/regression/test_vfs_component_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np


pytest.skip(allow_module_level=True, reason="pyop3 TODO")


@pytest.fixture
def m():
return UnitSquareMesh(4, 4)
Expand Down
Loading

0 comments on commit 4d0b2d1

Please sign in to comment.