Skip to content

Commit

Permalink
io: use repr instead of pickle for UFL elements
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Mar 13, 2023
1 parent 520d1f2 commit 2d5aba6
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,15 +590,15 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
if mesh.name not in self.require_group(path):
path = self._path_to_mesh(tmesh.name, mesh.name)
self.require_group(path)
self.set_attr(path, PREFIX + "_coordinate_element", self._pickle(mesh._coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_coordinate_element", mesh._coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_coordinates", mesh._coordinates.name())
self._save_function_topology(mesh._coordinates)
if hasattr(mesh, PREFIX + "_radial_coordinates"):
# Cannot do: self.save_function(mesh.radial_coordinates)
# This will cause infinite recursion.
self.set_attr(path, PREFIX + "_radial_coordinate_function", mesh.radial_coordinates.name())
radial_coordinates = mesh.radial_coordinates.topological
self.set_attr(path, PREFIX + "_radial_coordinate_element", self._pickle(radial_coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_radial_coordinate_element", radial_coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_radial_coordinates", radial_coordinates.name())
self._save_function_topology(radial_coordinates)
self._update_mesh_name_topology_name_map({mesh.name: tmesh.name})
Expand All @@ -616,7 +616,7 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
path = self._path_to_mesh(tmesh.name, mesh.name)
self.require_group(path)
# Save Firedrake coodinates.
self.set_attr(path, PREFIX + "_coordinate_element", self._pickle(mesh._coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_coordinate_element", mesh._coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_coordinates", mesh._coordinates.name())
self._save_function_topology(mesh._coordinates)
# Save DMPlex coordinates for a complete representation of the plex.
Expand Down Expand Up @@ -719,11 +719,11 @@ def _save_function_space(self, V):
# Save UFL element
path = self._path_to_function_space(tmesh.name, mesh.name, V_name)
self.require_group(path)
self.set_attr(path, PREFIX + "_ufl_element", self._pickle(element))
# Test if the pickled UFL element matches the original element
loaded_element = self._unpickle(self.get_attr(path, PREFIX + "_ufl_element"))
self._save_ufl_element(path, PREFIX + "_ufl_element", element)
# Test if the loaded UFL element matches the original element
loaded_element = self._load_ufl_element(path, PREFIX + "_ufl_element")
if loaded_element != element:
raise RuntimeError(f"pickled UFL element ({loaded_element}) does not match the original element ({element})")
raise RuntimeError(f"Loaded UFL element ({loaded_element}) does not match the original element ({element})")

@PETSc.Log.EventDecorator("SaveFunctionSpaceTopology")
def _save_function_space_topology(self, tV):
Expand Down Expand Up @@ -892,12 +892,12 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
# -- Load mesh --
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_element = self._load_ufl_element(path, PREFIX + "_radial_coordinate_element")
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
Expand All @@ -919,7 +919,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def _load_function_space(self, mesh, name):
elif self._is_function_space(tmesh.name, mesh.name, name):
# Load function space data
path = self._path_to_function_space(tmesh.name, mesh.name, name)
element = self._unpickle(self.get_attr(path, PREFIX + "_ufl_element"))
element = self._load_ufl_element(path, PREFIX + "_ufl_element")
tV = self._load_function_space_topology(tmesh, element)
# Construct function space
V = impl.WithGeometry.create(tV, mesh)
Expand Down Expand Up @@ -1353,6 +1353,18 @@ def _update_pickled_dict(self, name, new_item, *args):
the_dict.update(new_item)
getattr(self, "_set_" + name)(*args, the_dict)

def _save_ufl_element(self, path, name, elem):
self.set_attr(path, name + "_repr", repr(elem))

def _load_ufl_element(self, path, name):
if self.has_attr(path, name + "_repr"):
globals = {}
locals = {}
exec("from ufl import *", globals, locals)
return eval(self.get_attr(path, name + "_repr"), globals, locals)
else:
return self._unpickle(self.get_attr(path, name)) # backward compat.

def _set_mesh_name_topology_name_map(self, new_item):
path = self._path_to_topologies()
self._write_pickled_dict(path, PREFIX + "_mesh_name_topology_name_map", new_item)
Expand Down

0 comments on commit 2d5aba6

Please sign in to comment.