From 7d912a22fcd4e4cf01c4bfdf3f6a74d5cff151f9 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 9 Mar 2023 13:24:51 +0000 Subject: [PATCH] Remove class-level caches from CheckpointFile --- firedrake/checkpointing.py | 244 ++++++++++++++++--------------------- 1 file changed, 105 insertions(+), 139 deletions(-) diff --git a/firedrake/checkpointing.py b/firedrake/checkpointing.py index 6da9094955..9e1fa6f3c1 100644 --- a/firedrake/checkpointing.py +++ b/firedrake/checkpointing.py @@ -1,6 +1,5 @@ import functools import pickle -import weakref from petsc4py.PETSc import ViewerHDF5 import ufl from pyop2 import op2 @@ -517,10 +516,6 @@ class CheckpointFile(object): One can also use different number of processes for saving and for loading. """ - # Cache for loaded meshes. - _mesh_cache = weakref.WeakValueDictionary() - _tmesh_cache = weakref.WeakValueDictionary() - def __init__(self, filename, mode, comm=COMM_WORLD): self.viewer = ViewerHDF5() self.filename = filename @@ -869,88 +864,67 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters) base_tmesh.init() - tmesh_key = self._generate_mesh_key_from_names(tmesh_name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - if tmesh_key in self._tmesh_cache: - tmesh = self._tmesh_cache[tmesh_key] + periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False + variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") + if variable_layers: + cell = base_tmesh.ufl_cell() + element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) + _ = self._load_function_space_topology(base_tmesh, element) + base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, + base_tmesh._distribution_name, + base_tmesh._permutation_name) + sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) + _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] + nroots, _, _ = lsf.getGraph() + layers_a = np.empty(nroots, dtype=utils.IntType) + layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) + layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) + self.viewer.pushGroup(path) + layers_a_iset.load(self.viewer) + self.viewer.popGroup() + layers_a = layers_a_iset.getIndices() + layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) + unit = MPI._typedict[np.dtype(utils.IntType).char] + lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) + lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) else: - periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False - variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers") - if variable_layers: - cell = base_tmesh.ufl_cell() - element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2) - _ = self._load_function_space_topology(base_tmesh, element) - base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element) - _, _, lsf = self._function_load_utils[base_tmesh_key + sd_key] - nroots, _, _ = lsf.getGraph() - layers_a = np.empty(nroots, dtype=utils.IntType) - layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm) - layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"])) - self.viewer.pushGroup(path) - layers_a_iset.load(self.viewer) - self.viewer.popGroup() - layers_a = layers_a_iset.getIndices() - layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType) - unit = MPI._typedict[np.dtype(utils.IntType).char] - lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE) - lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE) - else: - layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") - tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) - self._tmesh_cache[tmesh_key] = tmesh + layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers") + tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name) # -- Load mesh -- - mesh_key = self._generate_mesh_key_from_names(name, - base_tmesh._distribution_name, - base_tmesh._permutation_name) - if mesh_key in self._mesh_cache: - mesh = self._mesh_cache[mesh_key] - else: - path = self._path_to_mesh(tmesh_name, name) - coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) - coord_name = self.get_attr(path, PREFIX + "_coordinates") - coordinates = self._load_function_topology(tmesh, coord_element, coord_name) - mesh = make_mesh_from_coordinates(coordinates, name) - if self.has_attr(path, PREFIX + "_radial_coordinates"): - radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element")) - radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates") - radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name) - tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element) - V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh) - radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function") - mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name) - # The followings are conceptually redundant, but needed. - path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED) - base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") - mesh._base_mesh = self.load_mesh(base_mesh_name) - self._mesh_cache[mesh_key] = mesh + path = self._path_to_mesh(tmesh_name, name) + coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) + coord_name = self.get_attr(path, PREFIX + "_coordinates") + coordinates = self._load_function_topology(tmesh, coord_element, coord_name) + mesh = make_mesh_from_coordinates(coordinates, name) + if self.has_attr(path, PREFIX + "_radial_coordinates"): + radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element")) + radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates") + radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name) + tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element) + V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh) + radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function") + mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name) + # The followings are conceptually redundant, but needed. + path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED) + base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh") + mesh._base_mesh = self.load_mesh(base_mesh_name) else: utils._init() # -- Load mesh topology -- tmesh = self._load_mesh_topology(tmesh_name, reorder, distribution_parameters) - mesh_key = self._generate_mesh_key_from_names(name, - tmesh._distribution_name, - tmesh._permutation_name) - if mesh_key in self._mesh_cache: - mesh = self._mesh_cache[mesh_key] - else: - # -- Load coordinates -- - # tmesh.topology_dm has already been redistributed. - path = self._path_to_mesh(tmesh_name, name) - # Load firedrake coordinates directly. - # When implementing checkpointing for MeshHierarchy in the future, - # we will need to postpone calling tmesh.init(). - tmesh.init() - coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) - coord_name = self.get_attr(path, PREFIX + "_coordinates") - coordinates = self._load_function_topology(tmesh, coord_element, coord_name) - mesh = make_mesh_from_coordinates(coordinates, name) - # Load plex coordinates for a complete representation of plex. - tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC) - self._mesh_cache[mesh_key] = mesh + # -- Load coordinates -- + # tmesh.topology_dm has already been redistributed. + path = self._path_to_mesh(tmesh_name, name) + # Load firedrake coordinates directly. + # When implementing checkpointing for MeshHierarchy in the future, + # we will need to postpone calling tmesh.init(). + tmesh.init() + coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element")) + coord_name = self.get_attr(path, PREFIX + "_coordinates") + coordinates = self._load_function_topology(tmesh, coord_element, coord_name) + mesh = make_mesh_from_coordinates(coordinates, name) + # Load plex coordinates for a complete representation of plex. + tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC) return mesh @PETSc.Log.EventDecorator("LoadMeshTopology") @@ -989,65 +963,57 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters): distribution_name = None permutation_name = None perm_is = None - # This is only to return the same tmesh object if the same set of arguments are given. - # Multiple tmesh_key might end up having the same value, but it is hard to process - # all distribution and reorder options at this stage (many things happen in MeshTopology constructor). - tmesh_key = self._generate_mesh_key(tmesh_name, distribution_name, permutation_name, reorder, distribution_parameters) - if tmesh_key in self._tmesh_cache: - tmesh = self._tmesh_cache[tmesh_key] + plex = PETSc.DMPlex() + plex.create(comm=self._comm) + plex.setName(tmesh_name) + # Check format + path = os.path.join(self._path_to_topology(tmesh_name), "topology") + if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"), + os.path.join(path, "cones"), + os.path.join(path, "order"), + os.path.join(path, "orientation")]): + raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}") + format = ViewerHDF5.Format.HDF5_PETSC + self.viewer.pushFormat(format=format) + plex.distributionSetName(distribution_name) + sfXB = plex.topologyLoad(self.viewer) + plex.distributionSetName(None) + self.viewer.popFormat() + if load_distribution_permutation: + chart_size = np.empty(1, dtype=utils.IntType) + chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm) + chart_sizes_iset.setName("chart_sizes") + path = self._path_to_distribution(tmesh_name, distribution_name) + self.viewer.pushGroup(path) + chart_sizes_iset.load(self.viewer) + self.viewer.popGroup() + chart_size = chart_sizes_iset.getIndices().item() + perm = np.empty(chart_size, dtype=utils.IntType) + perm_is = PETSc.IS().createGeneral(perm, comm=self._comm) + path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name) + self.viewer.pushGroup(path) + perm_is.setName("permutation") + perm_is.load(self.viewer) + perm_is.setName(None) + self.viewer.popGroup() else: - plex = PETSc.DMPlex() - plex.create(comm=self._comm) - plex.setName(tmesh_name) - # Check format - path = os.path.join(self._path_to_topology(tmesh_name), "topology") - if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"), - os.path.join(path, "cones"), - os.path.join(path, "order"), - os.path.join(path, "orientation")]): - raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}") - format = ViewerHDF5.Format.HDF5_PETSC - self.viewer.pushFormat(format=format) - plex.distributionSetName(distribution_name) - sfXB = plex.topologyLoad(self.viewer) - plex.distributionSetName(None) - self.viewer.popFormat() - if load_distribution_permutation: - chart_size = np.empty(1, dtype=utils.IntType) - chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm) - chart_sizes_iset.setName("chart_sizes") - path = self._path_to_distribution(tmesh_name, distribution_name) - self.viewer.pushGroup(path) - chart_sizes_iset.load(self.viewer) - self.viewer.popGroup() - chart_size = chart_sizes_iset.getIndices().item() - perm = np.empty(chart_size, dtype=utils.IntType) - perm_is = PETSc.IS().createGeneral(perm, comm=self._comm) - path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name) - self.viewer.pushGroup(path) - perm_is.setName("permutation") - perm_is.load(self.viewer) - perm_is.setName(None) - self.viewer.popGroup() - else: - perm_is = None - # -- Construct Mesh (Topology) -- - # Use public API so pass user comm (self.comm) - tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder, - distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is, - distribution_name=distribution_name, permutation_name=permutation_name, - comm=self.comm) - self.viewer.pushFormat(format=format) - # tmesh.topology_dm has already been redistributed. - sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB - plex.labelsLoad(self.viewer, sfXCtemp) - self.viewer.popFormat() - # These labels are distribution dependent. - # We should be able to save/load labels selectively. - plex.removeLabel("pyop2_core") - plex.removeLabel("pyop2_owned") - plex.removeLabel("pyop2_ghost") - self._tmesh_cache[tmesh_key] = tmesh + perm_is = None + # -- Construct Mesh (Topology) -- + # Use public API so pass user comm (self.comm) + tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder, + distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is, + distribution_name=distribution_name, permutation_name=permutation_name, + comm=self.comm) + self.viewer.pushFormat(format=format) + # tmesh.topology_dm has already been redistributed. + sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB + plex.labelsLoad(self.viewer, sfXCtemp) + self.viewer.popFormat() + # These labels are distribution dependent. + # We should be able to save/load labels selectively. + plex.removeLabel("pyop2_core") + plex.removeLabel("pyop2_owned") + plex.removeLabel("pyop2_ghost") return tmesh @PETSc.Log.EventDecorator("LoadFunctionSpace")