From 3bbbe79bd2d896c577f223c2d2105300ccf1591c Mon Sep 17 00:00:00 2001 From: mloubout Date: Mon, 18 Nov 2024 10:12:02 -0500 Subject: [PATCH] api: fix equation pickling --- devito/types/equation.py | 4 ++-- tests/test_pickle.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/devito/types/equation.py b/devito/types/equation.py index 662cdd0d34..ad3e62f0e2 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -2,13 +2,13 @@ import sympy from devito.deprecations import deprecations -from devito.tools import as_tuple, frozendict +from devito.tools import as_tuple, frozendict, Pickable from devito.types.lazy import Evaluable __all__ = ['Eq', 'Inc', 'ReduceMax', 'ReduceMin'] -class Eq(sympy.Eq, Evaluable): +class Eq(sympy.Eq, Evaluable, Pickable): """ An equal relation between two objects, the left-hand side and the diff --git a/tests/test_pickle.py b/tests/test_pickle.py index c2a676fde6..ef47e917fb 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -557,6 +557,24 @@ def test_derivative(self, pickle, transpose, side, deriv_order, assert new_dfdx.method == dfdx.method assert new_dfdx.weights == dfdx.weights + def test_equation(self, pickle): + grid = Grid(shape=(3,)) + x = grid.dimensions[0] + f = Function(name='f', grid=grid) + + # Some implicit dim + xs = ConditionalDimension(name='xs', parent=x, factor=4) + + eq = Eq(f, f+1, implicit_dims=xs) + + pkl_eq = pickle0.dumps(eq) + new_eq = pickle0.loads(pkl_eq) + + assert new_eq.lhs.name == f.name + assert str(new_eq.rhs) == 'f(x) + 1' + assert new_eq.implicit_dims[0].name == 'xs' + assert new_eq.implicit_dims[0].factor.data == 4 + class TestAdvanced: