From f55e9687c59ea8aeed9078d3745bfd71318579ea Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 10 Mar 2023 08:44:22 +0000 Subject: [PATCH] Remove class-level caches from CheckpointFile (#2810) * Remove class-level caches from CheckpointFile * Refactor backward compat IO tests --- firedrake/checkpointing.py | 244 ++++++++++-------------- tests/output/test_io_backward_compat.py | 118 +++++++----- 2 files changed, 170 insertions(+), 192 deletions(-) diff --git a/firedrake/checkpointing.py b/firedrake/checkpointing.py index 6da9094955..9e1fa6f3c1 100644 --- a/firedrake/checkpointing.py +++ b/firedrake/checkpointing.py @@ -1,6 +1,5 @@ import functools import pickle -import weakref from petsc4py.PETSc import ViewerHDF5 import ufl from pyop2 import op2 @@ -517,10 +516,6 @@ class CheckpointFile(object): One can also use different number of processes for saving and for loading. """ - # Cache for loaded meshes. - _mesh_cache = weakref.WeakValueDictionary() - _tmesh_cache = weakref.WeakValueDictionary() - def __init__(self, filename, mode, comm=COMM_WORLD): self.viewer = ViewerHDF5() self.filename = filename @@ -869,88 +864,67 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters) base_tmesh.init() - tmesh_key = self._generate_mesh_key_from_names(tmesh_name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - if tmesh_key in self._tmesh_cache: - tmesh = self._tmesh_cache[tmesh_key] + periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False + variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") + if variable_layers: + cell = base_tmesh.ufl_cell() + element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) + _ = self._load_function_space_topology(base_tmesh, element) + base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, + base_tmesh._distribution_name, + base_tmesh._permutation_name) + sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) + _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] + nroots, _, _ = lsf.getGraph() + layers_a = np.empty(nroots, dtype=utils.IntType) + layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) + layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) + self.viewer.pushGroup(path) + layers_a_iset.load(self.viewer) + self.viewer.popGroup() + layers_a = layers_a_iset.getIndices() + layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) + unit = MPI._typedict[np.dtype(utils.IntType).char] + lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) + lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) else: - periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False - variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") - if variable_layers: - cell = base_tmesh.ufl_cell() - element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) - _ = self._load_function_space_topology(base_tmesh, element) - base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) - _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] - nroots, _, _ = lsf.getGraph() - layers_a = np.empty(nroots, dtype=utils.IntType) - layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) - layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) - self.viewer.pushGroup(path) - layers_a_iset.load(self.viewer) - self.viewer.popGroup() - layers_a = layers_a_iset.getIndices() - layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) - unit = MPI._typedict[np.dtype(utils.IntType).char] - lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) - lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) - else: - layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") - tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) - self._tmesh_cache[tmesh_key] = tmesh + layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") + tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) # -- Load mesh -- - mesh_key = self._generate_mesh_key_from_names(name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - if mesh_key in self._mesh_cache: - mesh = self._mesh_cache[mesh_key] - else: - path = self._path_to_mesh(tmesh_name, name) - coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) - coord_name = self.get_attr(path, PREFIX + "_coordinates") - coordinates = self._load_function_topology(tmesh, coord_element, coord_name) - mesh = make_mesh_from_coordinates(coordinates, name) - if self.has_attr(path, PREFIX + "_radial_coordinates"): - radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element")) - radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates") - radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name) - tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element) - V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh) - radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function") - mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name) - # The followings are conceptually redundant, but needed. - path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED) - base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") - mesh._base_mesh = self.load_mesh(base_mesh_name) - self._mesh_cache[mesh_key] = mesh + path = self._path_to_mesh(tmesh_name, name) + coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) + coord_name = self.get_attr(path, PREFIX + "_coordinates") + coordinates = self._load_function_topology(tmesh, coord_element, coord_name) + mesh = make_mesh_from_coordinates(coordinates, name) + if self.has_attr(path, PREFIX + "_radial_coordinates"): + radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element")) + radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates") + radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name) + tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element) + V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh) + radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function") + mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name) + # The followings are conceptually redundant, but needed. + path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED) + base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") + mesh._base_mesh = self.load_mesh(base_mesh_name) else: utils._init() # -- Load mesh topology -- tmesh = self._load_mesh_topology(tmesh_name, reorder, distribution_parameters) - mesh_key = self._generate_mesh_key_from_names(name, - tmesh._distribution_name, - tmesh._permutation_name) - if mesh_key in self._mesh_cache: - mesh = self._mesh_cache[mesh_key] - else: - # -- Load coordinates -- - # tmesh.topology_dm has already been redistributed. - path = self._path_to_mesh(tmesh_name, name) - # Load firedrake coordinates directly. - # When implementing checkpointing for MeshHierarchy in the future, - # we will need to postpone calling tmesh.init(). - tmesh.init() - coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) - coord_name = self.get_attr(path, PREFIX + "_coordinates") - coordinates = self._load_function_topology(tmesh, coord_element, coord_name) - mesh = make_mesh_from_coordinates(coordinates, name) - # Load plex coordinates for a complete representation of plex. - tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC) - self._mesh_cache[mesh_key] = mesh + # -- Load coordinates -- + # tmesh.topology_dm has already been redistributed. + path = self._path_to_mesh(tmesh_name, name) + # Load firedrake coordinates directly. + # When implementing checkpointing for MeshHierarchy in the future, + # we will need to postpone calling tmesh.init(). + tmesh.init() + coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) + coord_name = self.get_attr(path, PREFIX + "_coordinates") + coordinates = self._load_function_topology(tmesh, coord_element, coord_name) + mesh = make_mesh_from_coordinates(coordinates, name) + # Load plex coordinates for a complete representation of plex. + tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC) return mesh @PETSc.Log.EventDecorator("LoadMeshTopology") @@ -989,65 +963,57 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters): distribution_name = None permutation_name = None perm_is = None - # This is only to return the same tmesh object if the same set of arguments are given. - # Multiple tmesh_key might end up having the same value, but it is hard to process - # all distribution and reorder options at this stage (many things happen in MeshTopology constructor). - tmesh_key = self._generate_mesh_key(tmesh_name, distribution_name, permutation_name, reorder, distribution_parameters) - if tmesh_key in self._tmesh_cache: - tmesh = self._tmesh_cache[tmesh_key] + plex = PETSc.DMPlex() + plex.create(comm=self._comm) + plex.setName(tmesh_name) + # Check format + path = os.path.join(self._path_to_topology(tmesh_name), "topology") + if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"), + os.path.join(path, "cones"), + os.path.join(path, "order"), + os.path.join(path, "orientation")]): + raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}") + format = ViewerHDF5.Format.HDF5_PETSC + self.viewer.pushFormat(format=format) + plex.distributionSetName(distribution_name) + sfXB = plex.topologyLoad(self.viewer) + plex.distributionSetName(None) + self.viewer.popFormat() + if load_distribution_permutation: + chart_size = np.empty(1, dtype=utils.IntType) + chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm) + chart_sizes_iset.setName("chart_sizes") + path = self._path_to_distribution(tmesh_name, distribution_name) + self.viewer.pushGroup(path) + chart_sizes_iset.load(self.viewer) + self.viewer.popGroup() + chart_size = chart_sizes_iset.getIndices().item() + perm = np.empty(chart_size, dtype=utils.IntType) + perm_is = PETSc.IS().createGeneral(perm, comm=self._comm) + path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name) + self.viewer.pushGroup(path) + perm_is.setName("permutation") + perm_is.load(self.viewer) + perm_is.setName(None) + self.viewer.popGroup() else: - plex = PETSc.DMPlex() - plex.create(comm=self._comm) - plex.setName(tmesh_name) - # Check format - path = os.path.join(self._path_to_topology(tmesh_name), "topology") - if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"), - os.path.join(path, "cones"), - os.path.join(path, "order"), - os.path.join(path, "orientation")]): - raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}") - format = ViewerHDF5.Format.HDF5_PETSC - self.viewer.pushFormat(format=format) - plex.distributionSetName(distribution_name) - sfXB = plex.topologyLoad(self.viewer) - plex.distributionSetName(None) - self.viewer.popFormat() - if load_distribution_permutation: - chart_size = np.empty(1, dtype=utils.IntType) - chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm) - chart_sizes_iset.setName("chart_sizes") - path = self._path_to_distribution(tmesh_name, distribution_name) - self.viewer.pushGroup(path) - chart_sizes_iset.load(self.viewer) - self.viewer.popGroup() - chart_size = chart_sizes_iset.getIndices().item() - perm = np.empty(chart_size, dtype=utils.IntType) - perm_is = PETSc.IS().createGeneral(perm, comm=self._comm) - path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name) - self.viewer.pushGroup(path) - perm_is.setName("permutation") - perm_is.load(self.viewer) - perm_is.setName(None) - self.viewer.popGroup() - else: - perm_is = None - # -- Construct Mesh (Topology) -- - # Use public API so pass user comm (self.comm) - tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder, - distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is, - distribution_name=distribution_name, permutation_name=permutation_name, - comm=self.comm) - self.viewer.pushFormat(format=format) - # tmesh.topology_dm has already been redistributed. - sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB - plex.labelsLoad(self.viewer, sfXCtemp) - self.viewer.popFormat() - # These labels are distribution dependent. - # We should be able to save/load labels selectively. - plex.removeLabel("pyop2_core") - plex.removeLabel("pyop2_owned") - plex.removeLabel("pyop2_ghost") - self._tmesh_cache[tmesh_key] = tmesh + perm_is = None + # -- Construct Mesh (Topology) -- + # Use public API so pass user comm (self.comm) + tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder, + distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is, + distribution_name=distribution_name, permutation_name=permutation_name, + comm=self.comm) + self.viewer.pushFormat(format=format) + # tmesh.topology_dm has already been redistributed. + sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB + plex.labelsLoad(self.viewer, sfXCtemp) + self.viewer.popFormat() + # These labels are distribution dependent. + # We should be able to save/load labels selectively. + plex.removeLabel("pyop2_core") + plex.removeLabel("pyop2_owned") + plex.removeLabel("pyop2_ghost") return tmesh @PETSc.Log.EventDecorator("LoadFunctionSpace") diff --git a/tests/output/test_io_backward_compat.py b/tests/output/test_io_backward_compat.py index b0b9ced40a..5c9dbf2692 100644 --- a/tests/output/test_io_backward_compat.py +++ b/tests/output/test_io_backward_compat.py @@ -57,9 +57,7 @@ def _get_expr(V): raise ValueError(f"Invalid shape {shape}") -@pytest.mark.skipcomplex -@pytest.mark.parallel(nprocs=3) -def test_io_backward_compat_load_146397af52673c7adffbc12b4e0492d4b357069a(): +def _old_mesh_filename(): """ --------------------------------------------------------------------------- |Package |Branch |Revision |Modified | @@ -81,57 +79,71 @@ def test_io_backward_compat_load_146397af52673c7adffbc12b4e0492d4b357069a(): |ufl |master |0c592ec5 |False | --------------------------------------------------------------------------- """ - filename = join(cwd, "test_io_backward_compat_files", "test_io_backward_compat_146397af52673c7adffbc12b4e0492d4b357069a.h5") - filename = COMM_WORLD.bcast(filename, root=0) - afile = CheckpointFile(filename, 'r', comm=COMM_WORLD) - # Base - for cell_type, family, degree in [("triangle", "P", 5), - ("triangle", "RTE", 4), - ("triangle", "RTF", 4), - ("triangle", "BDME", 4), - ("triangle", "BDMF", 4), - ("triangle", "DP", 6), - ("tetrahedra", "P", 6), - ("tetrahedra", "N1E", 2), # slow if high order - ("tetrahedra", "N1F", 5), - ("tetrahedra", "N2E", 2), # slow if high order - ("tetrahedra", "N2F", 5), - ("tetrahedra", "DP", 5), - ("quadrilateral", "Q", 7), - ("quadrilateral", "RTCE", 5), - ("quadrilateral", "RTCF", 5), - ("quadrilateral", "DQ", 7), - ("quadrilateral", "S", 5), - ("quadrilateral", "DPC", 5)]: - # meshes and functions have been saved as (in 'w' mode using 2 processes): - # >>> mesh = _get_mesh(cell_type, _generate_mesh_name(cell_type), COMM_WORLD) - # >>> V = FunctionSpace(mesh, family, degree) - # >>> f = Function(V, name=_generate_func_name(mesh.name, family, degree)) - # >>> _initialise_function(f, _get_expr(V)) - # >>> afile.save_function(f) + fname = join(cwd, "test_io_backward_compat_files", + "test_io_backward_compat_146397af52673c7adffbc12b4e0492d4b357069a.h5") + fname = COMM_WORLD.bcast(fname, root=0) + return fname + + +@pytest.mark.skipcomplex +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize(("cell_type", "family", "degree"), + [("triangle", "P", 5), + ("triangle", "RTE", 4), + ("triangle", "RTF", 4), + ("triangle", "BDME", 4), + ("triangle", "BDMF", 4), + ("triangle", "DP", 6), + ("tetrahedra", "P", 6), + ("tetrahedra", "N1E", 2), # slow if high order + ("tetrahedra", "N1F", 5), + ("tetrahedra", "N2E", 2), # slow if high order + ("tetrahedra", "N2F", 5), + ("tetrahedra", "DP", 5), + ("quadrilateral", "Q", 7), + ("quadrilateral", "RTCE", 5), + ("quadrilateral", "RTCF", 5), + ("quadrilateral", "DQ", 7), + ("quadrilateral", "S", 5), + ("quadrilateral", "DPC", 5)]) +def test_io_backward_compat_load(cell_type, family, degree): + # meshes and functions have been saved as (in 'w' mode using 2 processes): + # >>> mesh = _get_mesh(cell_type, _generate_mesh_name(cell_type), COMM_WORLD) + # >>> V = FunctionSpace(mesh, family, degree) + # >>> f = Function(V, name=_generate_func_name(mesh.name, family, degree)) + # >>> _initialise_function(f, _get_expr(V)) + # >>> afile.save_function(f) + filename = _old_mesh_filename() + with CheckpointFile(filename, "r", comm=COMM_WORLD) as afile: mesh = afile.load_mesh(_generate_mesh_name(cell_type)) f = afile.load_function(mesh, _generate_func_name(mesh.name, family, degree)) - V = f.function_space() - fe = Function(V) - _initialise_function(fe, _get_expr(V)) - assert assemble(inner(f - fe, f - fe) * dx) < 5.e-12 - # Extrusion - for cell_type, family, degree, vfamily, vdegree in [("triangle", "BDMF", 4, "DG", 3), - ("quadrilateral", "RTCF", 4, "DG", 3)]: - # meshes and functions have been saved as (in 'w' mode using 2 processes): - # >>> mesh = _get_mesh(cell_type, _generate_mesh_name(cell_type), COMM_WORLD) - # >>> extm = ExtrudedMesh(mesh, 4, layer_height=[0.2, 0.3, 0.5, 0.7], name=_generate_extruded_mesh_name(cell_type)) - # >>> helem = FiniteElement(family, cell_type, degree) - # >>> velem = FiniteElement(vfamily, "interval", vdegree) - # >>> elem = HDiv(TensorProductElement(helem, velem)) - # >>> V = FunctionSpace(extm, elem) - # >>> f = Function(V, name=_generate_func_name(extm.name, family, degree)) - # >>> _initialise_function(f, _get_expr(V)) - # >>> afile.save_function(f) + V = f.function_space() + fe = Function(V) + _initialise_function(fe, _get_expr(V)) + assert assemble(inner(f - fe, f - fe) * dx) < 5.e-12 + + +@pytest.mark.skipcomplex +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize(("cell_type", "family", "degree", "vfamily", "vdegree"), + [("triangle", "BDMF", 4, "DG", 3), + ("quadrilateral", "RTCF", 4, "DG", 3)]) +def test_io_backward_compat_load_extruded(cell_type, family, degree, vfamily, vdegree): + # meshes and functions have been saved as (in 'w' mode using 2 processes): + # >>> mesh = _get_mesh(cell_type, _generate_mesh_name(cell_type), COMM_WORLD) + # >>> extm = ExtrudedMesh(mesh, 4, layer_height=[0.2, 0.3, 0.5, 0.7], name=_generate_extruded_mesh_name(cell_type)) + # >>> helem = FiniteElement(family, cell_type, degree) + # >>> velem = FiniteElement(vfamily, "interval", vdegree) + # >>> elem = HDiv(TensorProductElement(helem, velem)) + # >>> V = FunctionSpace(extm, elem) + # >>> f = Function(V, name=_generate_func_name(extm.name, family, degree)) + # >>> _initialise_function(f, _get_expr(V)) + # >>> afile.save_function(f) + filename = _old_mesh_filename() + with CheckpointFile(filename, "r", comm=COMM_WORLD) as afile: extm = afile.load_mesh(_generate_extruded_mesh_name(cell_type)) f = afile.load_function(extm, _generate_func_name(extm.name, family, degree)) - V = f.function_space() - fe = Function(V) - _initialise_function(fe, _get_expr(V)) - assert assemble(inner(f - fe, f - fe) * dx) < 5.e-12 - afile.close() + V = f.function_space() + fe = Function(V) + _initialise_function(fe, _get_expr(V)) + assert assemble(inner(f - fe, f - fe) * dx) < 5.e-12