From 01cfcd0060ea0d5fd61c7c63ba1d1d1be31bceca Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 2 May 2024 10:51:45 +0100 Subject: [PATCH] Add context sensitive dat and hashable trees --- pyop3/array/harray.py | 20 +++++++++++++++--- pyop3/array/petsc.py | 23 +++++++++++++++------ pyop3/axtree/tree.py | 16 +++++++-------- pyop3/lang.py | 11 +++++++++- pyop3/tree.py | 48 ++++++++++++++++++++++++++++++++----------- 5 files changed, 88 insertions(+), 30 deletions(-) diff --git a/pyop3/array/harray.py b/pyop3/array/harray.py index 0db1e64..4af43f4 100644 --- a/pyop3/array/harray.py +++ b/pyop3/array/harray.py @@ -19,6 +19,7 @@ from pyop3.array.base import Array from pyop3.axtree import ( Axis, + ContextSensitive, AxisTree, ContextFree, as_axis_tree, @@ -138,7 +139,7 @@ def __init__( # always deal with flattened data if len(data.shape) > 1: data = data.flatten() - if data.size != axes.alloc_size: + if data.size != axes.unindexed.global_size: raise ValueError("Data shape does not match axes") # IndexedAxisTrees do not currently have SFs, so create a dummy one here @@ -147,10 +148,10 @@ def __init__( else: assert isinstance(axes, (ContextSensitiveAxisTree, IndexedAxisTree)) # not sure this is the right thing to do - sf = serial_forest(axes.alloc_size) + sf = serial_forest(axes.unindexed.global_size) data = DistributedBuffer( - axes.alloc_size, # not a useful property anymore + axes.unindexed.global_size, # not a useful property anymore sf, dtype, name=self.name, @@ -528,3 +529,16 @@ class MultiArray(HierarchicalArray): @deprecated("HierarchicalArray") def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + + +class ContextSensitiveDat(ContextSensitive): + """Class for describing arrays that are different within different loop contexts. + + This is useful for the case where one wants to pass a small array through as + part of a context-sensitive assignment. + + """ + + @property + def dtype(self): + return self._shared_attr("dtype") diff --git a/pyop3/array/petsc.py b/pyop3/array/petsc.py index 20cbc6c..f1e2b6b 100644 --- a/pyop3/array/petsc.py +++ b/pyop3/array/petsc.py @@ -12,7 +12,7 @@ from pyrsistent import freeze, pmap from pyop3.array.base import Array -from pyop3.array.harray import HierarchicalArray +from pyop3.array.harray import HierarchicalArray, ContextSensitiveDat from pyop3.axtree.tree import ( AxisTree, ContextSensitiveAxisTree, @@ -261,11 +261,22 @@ def assign(self, other, *, eager=True): # TODO: Check axes match between self and other expr = PetscMatStore(self, other) elif isinstance(other, numbers.Number): - static = HierarchicalArray( - self.axes, - data=np.full(self.axes.alloc_size, other, dtype=self.dtype), - constant=True, - ) + if isinstance(self.axes, ContextSensitiveAxisTree): + cs_dats = {} + for context, axes in self.axes.context_map.items(): + cs_dat = HierarchicalArray( + axes, + data=np.full(axes.size, other, dtype=self.dtype), + constant=True, + ) + cs_dats[context] = cs_dat + static = ContextSensitiveDat(cs_dats) + else: + static = HierarchicalArray( + self.axes, + data=np.full(self.axes.alloc_size, other, dtype=self.dtype), + constant=True, + ) expr = PetscMatStore(self, static) else: raise NotImplementedError diff --git a/pyop3/axtree/tree.py b/pyop3/axtree/tree.py index e51db40..3819fb9 100644 --- a/pyop3/axtree/tree.py +++ b/pyop3/axtree/tree.py @@ -114,6 +114,8 @@ def filter_context(self, context): key.update({loop_index: freeze(path)}) return freeze(key) + def _shared_attr(self, attr: str): + return single_valued(getattr(a, attr) for a in self.context_map.values()) # this is basically just syntactic sugar, might not be needed # avoids the need for @@ -1164,9 +1166,9 @@ def __init__( self._layout_exprs = pmap(layout_exprs) self._outer_loops = tuple(outer_loops) - # @cached_property - # def _hash_key(self): - # return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops) + @cached_property + def _hash_key(self): + return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops) @property def unindexed(self): @@ -1354,11 +1356,9 @@ def datamap(self): def sf(self): return single_valued([ax.sf for ax in self.context_map.values()]) - # @cached_property - # def unindexed(self): - # this does not work because unindexed may have different IDs, so just return - # the first one. - # return single_valued([ax.unindexed for ax in self.context_map.values()]) + @cached_property + def unindexed(self): + return single_valued([ax.unindexed for ax in self.context_map.values()]) @cached_property def context_free(self): diff --git a/pyop3/lang.py b/pyop3/lang.py index dde8aa4..1573d72 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -242,6 +242,7 @@ def _array_updates(self): """ from pyop3 import DistributedBuffer, HierarchicalArray, Mat + from pyop3.array.harray import ContextSensitiveDat from pyop3.array.petsc import Sparsity initializers = [] @@ -260,6 +261,9 @@ def _array_updates(self): initializers.extend(inits) reductions.extend(reds) broadcasts.extend(bcasts) + elif isinstance(arg, ContextSensitiveDat): + # assumed to not be distributed + pass else: assert isinstance(arg, (Mat, Sparsity)) # just in case @@ -589,7 +593,12 @@ def __init__(self, mat_arg, array_arg): @property def kernel_arguments(self): - return (self.mat_arg.mat, self.array_arg.buffer) + args = (self.mat_arg.mat,) + if isinstance(self.array_arg, ContextSensitive): + args += tuple(dat.buffer for dat in self.array_arg.context_map.values()) + else: + args += (self.array_arg.buffer,) + return args @property def datamap(self): diff --git a/pyop3/tree.py b/pyop3/tree.py index 814b35b..495cfca 100644 --- a/pyop3/tree.py +++ b/pyop3/tree.py @@ -3,6 +3,7 @@ import abc import collections import functools +import itertools import operator from collections import defaultdict from collections.abc import Hashable, Sequence @@ -251,18 +252,41 @@ def __init__(self, node_map=None): # post-init checks self._check_node_labels_unique_in_paths(self.node_map) - # This is arguably over-specific. Otherwise equivalent trees will currently - # fail this check if their IDs do not match. One way to resolve this would be - # to re-ID all of the nodes with a pre-order traversal. This is not a priority. - # def __eq__(self, other): - # return type(other) is type(self) and other._hash_key == self._hash_key - # - # def __hash__(self): - # return hash(self._hash_key) - # - # @cached_property - # def _hash_key(self): - # return (self.node_map,) + def __eq__(self, other): + return type(other) is type(self) and other._hash_key == self._hash_key + + def __hash__(self): + return hash(self._hash_key) + + @cached_property + def _hash_key(self): + return (self._hash_node_map,) + + @cached_property + def _hash_node_map(self): + if self.is_empty: + return pmap() + + counter = itertools.count() + return self._collect_hash_node_map(None, None, counter) + + def _collect_hash_node_map(self, old_parent_id, new_parent_id, counter): + if old_parent_id not in self.node_map: + return pmap() + + nodes = [] + node_map = {} + for old_node in self.node_map[old_parent_id]: + if old_node is not None: + new_node = old_node.copy(id=f"id_{next(counter)}") + node_map.update(self._collect_hash_node_map(old_node.id, new_node.id, counter)) + else: + new_node = None + + nodes.append(new_node) + + node_map[new_parent_id] = freeze(nodes) + return freeze(node_map) @classmethod def _check_node_labels_unique_in_paths(