Skip to content

Commit

Permalink
Remove class-level caches from CheckpointFile
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Mar 9, 2023
1 parent 4779660 commit 7d912a2
Showing 1 changed file with 105 additions and 139 deletions.
244 changes: 105 additions & 139 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import pickle
import weakref
from petsc4py.PETSc import ViewerHDF5
import ufl
from pyop2 import op2
Expand Down Expand Up @@ -517,10 +516,6 @@ class CheckpointFile(object):
One can also use different number of processes for saving and for loading.
"""
# Cache for loaded meshes.
_mesh_cache = weakref.WeakValueDictionary()
_tmesh_cache = weakref.WeakValueDictionary()

def __init__(self, filename, mode, comm=COMM_WORLD):
self.viewer = ViewerHDF5()
self.filename = filename
Expand Down Expand Up @@ -869,88 +864,67 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
tmesh_key = self._generate_mesh_key_from_names(tmesh_name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
if tmesh_key in self._tmesh_cache:
tmesh = self._tmesh_cache[tmesh_key]
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
self._tmesh_cache[tmesh_key] = tmesh
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
# -- Load mesh --
mesh_key = self._generate_mesh_key_from_names(name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
if mesh_key in self._mesh_cache:
mesh = self._mesh_cache[mesh_key]
else:
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh)
radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function")
mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name)
# The followings are conceptually redundant, but needed.
path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED)
base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
mesh._base_mesh = self.load_mesh(base_mesh_name)
self._mesh_cache[mesh_key] = mesh
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh)
radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function")
mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name)
# The followings are conceptually redundant, but needed.
path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED)
base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
mesh._base_mesh = self.load_mesh(base_mesh_name)
else:
utils._init()
# -- Load mesh topology --
tmesh = self._load_mesh_topology(tmesh_name, reorder, distribution_parameters)
mesh_key = self._generate_mesh_key_from_names(name,
tmesh._distribution_name,
tmesh._permutation_name)
if mesh_key in self._mesh_cache:
mesh = self._mesh_cache[mesh_key]
else:
# -- Load coordinates --
# tmesh.topology_dm has already been redistributed.
path = self._path_to_mesh(tmesh_name, name)
# Load firedrake coordinates directly.
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
self._mesh_cache[mesh_key] = mesh
# -- Load coordinates --
# tmesh.topology_dm has already been redistributed.
path = self._path_to_mesh(tmesh_name, name)
# Load firedrake coordinates directly.
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
return mesh

@PETSc.Log.EventDecorator("LoadMeshTopology")
Expand Down Expand Up @@ -989,65 +963,57 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
distribution_name = None
permutation_name = None
perm_is = None
# This is only to return the same tmesh object if the same set of arguments are given.
# Multiple tmesh_key might end up having the same value, but it is hard to process
# all distribution and reorder options at this stage (many things happen in MeshTopology constructor).
tmesh_key = self._generate_mesh_key(tmesh_name, distribution_name, permutation_name, reorder, distribution_parameters)
if tmesh_key in self._tmesh_cache:
tmesh = self._tmesh_cache[tmesh_key]
plex = PETSc.DMPlex()
plex.create(comm=self._comm)
plex.setName(tmesh_name)
# Check format
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"),
os.path.join(path, "cones"),
os.path.join(path, "order"),
os.path.join(path, "orientation")]):
raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}")
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
sfXB = plex.topologyLoad(self.viewer)
plex.distributionSetName(None)
self.viewer.popFormat()
if load_distribution_permutation:
chart_size = np.empty(1, dtype=utils.IntType)
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm)
chart_sizes_iset.setName("chart_sizes")
path = self._path_to_distribution(tmesh_name, distribution_name)
self.viewer.pushGroup(path)
chart_sizes_iset.load(self.viewer)
self.viewer.popGroup()
chart_size = chart_sizes_iset.getIndices().item()
perm = np.empty(chart_size, dtype=utils.IntType)
perm_is = PETSc.IS().createGeneral(perm, comm=self._comm)
path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name)
self.viewer.pushGroup(path)
perm_is.setName("permutation")
perm_is.load(self.viewer)
perm_is.setName(None)
self.viewer.popGroup()
else:
plex = PETSc.DMPlex()
plex.create(comm=self._comm)
plex.setName(tmesh_name)
# Check format
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"),
os.path.join(path, "cones"),
os.path.join(path, "order"),
os.path.join(path, "orientation")]):
raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}")
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
sfXB = plex.topologyLoad(self.viewer)
plex.distributionSetName(None)
self.viewer.popFormat()
if load_distribution_permutation:
chart_size = np.empty(1, dtype=utils.IntType)
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm)
chart_sizes_iset.setName("chart_sizes")
path = self._path_to_distribution(tmesh_name, distribution_name)
self.viewer.pushGroup(path)
chart_sizes_iset.load(self.viewer)
self.viewer.popGroup()
chart_size = chart_sizes_iset.getIndices().item()
perm = np.empty(chart_size, dtype=utils.IntType)
perm_is = PETSc.IS().createGeneral(perm, comm=self._comm)
path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name)
self.viewer.pushGroup(path)
perm_is.setName("permutation")
perm_is.load(self.viewer)
perm_is.setName(None)
self.viewer.popGroup()
else:
perm_is = None
# -- Construct Mesh (Topology) --
# Use public API so pass user comm (self.comm)
tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder,
distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is,
distribution_name=distribution_name, permutation_name=permutation_name,
comm=self.comm)
self.viewer.pushFormat(format=format)
# tmesh.topology_dm has already been redistributed.
sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB
plex.labelsLoad(self.viewer, sfXCtemp)
self.viewer.popFormat()
# These labels are distribution dependent.
# We should be able to save/load labels selectively.
plex.removeLabel("pyop2_core")
plex.removeLabel("pyop2_owned")
plex.removeLabel("pyop2_ghost")
self._tmesh_cache[tmesh_key] = tmesh
perm_is = None
# -- Construct Mesh (Topology) --
# Use public API so pass user comm (self.comm)
tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder,
distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is,
distribution_name=distribution_name, permutation_name=permutation_name,
comm=self.comm)
self.viewer.pushFormat(format=format)
# tmesh.topology_dm has already been redistributed.
sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB
plex.labelsLoad(self.viewer, sfXCtemp)
self.viewer.popFormat()
# These labels are distribution dependent.
# We should be able to save/load labels selectively.
plex.removeLabel("pyop2_core")
plex.removeLabel("pyop2_owned")
plex.removeLabel("pyop2_ghost")
return tmesh

@PETSc.Log.EventDecorator("LoadFunctionSpace")
Expand Down

0 comments on commit 7d912a2

Please sign in to comment.