Add context sensitive dat and hashable trees
connorjward committed May 2, 2024
1 parent f8625bf commit 01cfcd0
Showing 5 changed files with 88 additions and 30 deletions.
20 changes: 17 additions & 3 deletions pyop3/array/harray.py
@@ -19,6 +19,7 @@
 from pyop3.array.base import Array
 from pyop3.axtree import (
     Axis,
+    ContextSensitive,
     AxisTree,
     ContextFree,
     as_axis_tree,
@@ -138,7 +139,7 @@ def __init__(
         # always deal with flattened data
         if len(data.shape) > 1:
             data = data.flatten()
-        if data.size != axes.alloc_size:
+        if data.size != axes.unindexed.global_size:
             raise ValueError("Data shape does not match axes")

         # IndexedAxisTrees do not currently have SFs, so create a dummy one here
@@ -147,10 +148,10 @@
         else:
             assert isinstance(axes, (ContextSensitiveAxisTree, IndexedAxisTree))
             # not sure this is the right thing to do
-            sf = serial_forest(axes.alloc_size)
+            sf = serial_forest(axes.unindexed.global_size)

         data = DistributedBuffer(
-            axes.alloc_size,  # not a useful property anymore
+            axes.unindexed.global_size,  # not a useful property anymore
             sf,
             dtype,
             name=self.name,
@@ -528,3 +529,16 @@ class MultiArray(HierarchicalArray):
     @deprecated("HierarchicalArray")
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+
+class ContextSensitiveDat(ContextSensitive):
+    """Class for describing arrays that differ between loop contexts.
+
+    This is useful when one wants to pass a small array through as part of
+    a context-sensitive assignment.
+
+    """
+
+    @property
+    def dtype(self):
+        return self._shared_attr("dtype")
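The new class relies on ContextSensitive.context_map and the _shared_attr helper added to pyop3/axtree/tree.py below. A minimal usage sketch, in which the loop contexts and axis trees (context_a, context_b, axes_a, axes_b) are hypothetical stand-ins:

    import numpy as np

    from pyop3 import HierarchicalArray
    from pyop3.array.harray import ContextSensitiveDat

    # One concrete dat per loop context (the context keys are made up here).
    dat_a = HierarchicalArray(axes_a, data=np.zeros(axes_a.size, dtype=np.float64))
    dat_b = HierarchicalArray(axes_b, data=np.zeros(axes_b.size, dtype=np.float64))

    cs_dat = ContextSensitiveDat({context_a: dat_a, context_b: dat_b})

    # dtype is well-defined only because it agrees across every context.
    assert cs_dat.dtype == np.float64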
23 changes: 17 additions & 6 deletions pyop3/array/petsc.py
@@ -12,7 +12,7 @@
 from pyrsistent import freeze, pmap

 from pyop3.array.base import Array
-from pyop3.array.harray import HierarchicalArray
+from pyop3.array.harray import HierarchicalArray, ContextSensitiveDat
 from pyop3.axtree.tree import (
     AxisTree,
     ContextSensitiveAxisTree,
@@ -261,11 +261,22 @@ def assign(self, other, *, eager=True):
             # TODO: Check axes match between self and other
             expr = PetscMatStore(self, other)
         elif isinstance(other, numbers.Number):
-            static = HierarchicalArray(
-                self.axes,
-                data=np.full(self.axes.alloc_size, other, dtype=self.dtype),
-                constant=True,
-            )
+            if isinstance(self.axes, ContextSensitiveAxisTree):
+                cs_dats = {}
+                for context, axes in self.axes.context_map.items():
+                    cs_dat = HierarchicalArray(
+                        axes,
+                        data=np.full(axes.size, other, dtype=self.dtype),
+                        constant=True,
+                    )
+                    cs_dats[context] = cs_dat
+                static = ContextSensitiveDat(cs_dats)
+            else:
+                static = HierarchicalArray(
+                    self.axes,
+                    data=np.full(self.axes.alloc_size, other, dtype=self.dtype),
+                    constant=True,
+                )
             expr = PetscMatStore(self, static)
         else:
             raise NotImplementedError
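With this branch in place, assigning a scalar to a matrix whose axes are context sensitive builds one constant dat per loop context rather than a single flat array. A sketch of the calling pattern, where mat is a hypothetical Mat whose axes is a ContextSensitiveAxisTree:

    # mat.axes is assumed to be a ContextSensitiveAxisTree here; assign()
    # then constructs a ContextSensitiveDat of constant arrays, one per
    # context, and wraps it in a PetscMatStore expression.
    mat.assign(0.0)  # eager by default: the store is executed immediately
    expr = mat.assign(0.0, eager=False)  # assumed to return the deferred expression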
16 changes: 8 additions & 8 deletions pyop3/axtree/tree.py
@@ -114,6 +114,8 @@ def filter_context(self, context):
             key.update({loop_index: freeze(path)})
         return freeze(key)

+    def _shared_attr(self, attr: str):
+        return single_valued(getattr(a, attr) for a in self.context_map.values())

     # this is basically just syntactic sugar, might not be needed
     # avoids the need for
@@ -1164,9 +1166,9 @@ def __init__(
         self._layout_exprs = pmap(layout_exprs)
         self._outer_loops = tuple(outer_loops)

-    # @cached_property
-    # def _hash_key(self):
-    #     return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops)
+    @cached_property
+    def _hash_key(self):
+        return super()._hash_key + (self.unindexed, self.target_paths, self.index_exprs, self.layout_exprs, self.outer_loops)

     @property
     def unindexed(self):
@@ -1354,11 +1356,9 @@ def datamap(self):
     def sf(self):
         return single_valued([ax.sf for ax in self.context_map.values()])

-    # @cached_property
-    # def unindexed(self):
-    #     this does not work because unindexed may have different IDs, so just return
-    #     the first one.
-    #     return single_valued([ax.unindexed for ax in self.context_map.values()])
+    @cached_property
+    def unindexed(self):
+        return single_valued([ax.unindexed for ax in self.context_map.values()])

     @cached_property
     def context_free(self):
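_shared_attr (and the re-enabled unindexed property) only make sense when every per-context value agrees; pyop3's single_valued helper is assumed to enforce exactly that. A self-contained toy version of the pattern, for illustration only:

    def single_valued(iterable):
        # Toy stand-in for pyop3's single_valued: return the common value,
        # raising if the inputs disagree.
        items = list(iterable)
        first = items[0]
        if any(item != first for item in items[1:]):
            raise ValueError("expected a single value")
        return first

    class ToyContextSensitive:
        def __init__(self, context_map):
            self.context_map = context_map

        def _shared_attr(self, attr):
            return single_valued(getattr(a, attr) for a in self.context_map.values())

Note that unindexed can now be computed with single_valued because trees compare structurally rather than by node ID, thanks to the pyop3/tree.py change below.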
11 changes: 10 additions & 1 deletion pyop3/lang.py
@@ -242,6 +242,7 @@ def _array_updates(self):
         """
         from pyop3 import DistributedBuffer, HierarchicalArray, Mat
+        from pyop3.array.harray import ContextSensitiveDat
         from pyop3.array.petsc import Sparsity

         initializers = []
@@ -260,6 +261,9 @@
                 initializers.extend(inits)
                 reductions.extend(reds)
                 broadcasts.extend(bcasts)
+            elif isinstance(arg, ContextSensitiveDat):
+                # assumed to not be distributed
+                pass
             else:
                 assert isinstance(arg, (Mat, Sparsity))
                 # just in case
@@ -589,7 +593,12 @@ def __init__(self, mat_arg, array_arg):

     @property
     def kernel_arguments(self):
-        return (self.mat_arg.mat, self.array_arg.buffer)
+        args = (self.mat_arg.mat,)
+        if isinstance(self.array_arg, ContextSensitive):
+            args += tuple(dat.buffer for dat in self.array_arg.context_map.values())
+        else:
+            args += (self.array_arg.buffer,)
+        return args

     @property
     def datamap(self):
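Two behavioural points fall out of these changes: context-sensitive dats are skipped by the parallel update machinery (they are assumed not to be distributed), and kernel_arguments now contributes one buffer per loop context. A toy mirror of the flattening logic, with hypothetical names:

    def flatten_kernel_arguments(mat, array_arg):
        # A context-sensitive argument expands to one buffer per loop
        # context (in context_map iteration order); a plain dat contributes
        # just its own buffer.
        args = (mat,)
        if hasattr(array_arg, "context_map"):
            args += tuple(dat.buffer for dat in array_arg.context_map.values())
        else:
            args += (array_arg.buffer,)
        return args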
48 changes: 36 additions & 12 deletions pyop3/tree.py
@@ -3,6 +3,7 @@
 import abc
 import collections
 import functools
+import itertools
 import operator
 from collections import defaultdict
 from collections.abc import Hashable, Sequence
@@ -251,18 +252,41 @@ def __init__(self, node_map=None):
         # post-init checks
         self._check_node_labels_unique_in_paths(self.node_map)

-    # This is arguably over-specific. Otherwise equivalent trees will currently
-    # fail this check if their IDs do not match. One way to resolve this would be
-    # to re-ID all of the nodes with a pre-order traversal. This is not a priority.
-    # def __eq__(self, other):
-    #     return type(other) is type(self) and other._hash_key == self._hash_key
-    #
-    # def __hash__(self):
-    #     return hash(self._hash_key)
-    #
-    # @cached_property
-    # def _hash_key(self):
-    #     return (self.node_map,)
+    def __eq__(self, other):
+        return type(other) is type(self) and other._hash_key == self._hash_key
+
+    def __hash__(self):
+        return hash(self._hash_key)
+
+    @cached_property
+    def _hash_key(self):
+        return (self._hash_node_map,)
+
+    @cached_property
+    def _hash_node_map(self):
+        if self.is_empty:
+            return pmap()
+
+        counter = itertools.count()
+        return self._collect_hash_node_map(None, None, counter)
+
+    def _collect_hash_node_map(self, old_parent_id, new_parent_id, counter):
+        if old_parent_id not in self.node_map:
+            return pmap()
+
+        nodes = []
+        node_map = {}
+        for old_node in self.node_map[old_parent_id]:
+            if old_node is not None:
+                new_node = old_node.copy(id=f"id_{next(counter)}")
+                node_map.update(self._collect_hash_node_map(old_node.id, new_node.id, counter))
+            else:
+                new_node = None
+
+            nodes.append(new_node)
+
+        node_map[new_parent_id] = freeze(nodes)
+        return freeze(node_map)

     @classmethod
     def _check_node_labels_unique_in_paths(
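The comment removed above records the design rationale: structurally equal trees used to compare and hash unequal because their node IDs differed, and the fix is to re-ID every node deterministically with a pre-order traversal before hashing. A self-contained toy demonstration of the idea (not pyop3's actual node type):

    import itertools
    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class ToyNode:
        label: str
        id: str

    def hash_node_map(node_map, old_parent_id=None, new_parent_id=None, counter=None):
        # Rebuild the node map with deterministic IDs ("id_0", "id_1", ...)
        # assigned in traversal order, so equal structures hash equally.
        counter = counter if counter is not None else itertools.count()
        if old_parent_id not in node_map:
            return {}
        result = {}
        children = []
        for old_node in node_map[old_parent_id]:
            new_node = replace(old_node, id=f"id_{next(counter)}")
            result.update(hash_node_map(node_map, old_node.id, new_node.id, counter))
            children.append(new_node)
        result[new_parent_id] = tuple(children)
        return result

    # Two trees with the same shape but different raw node IDs...
    tree1 = {None: (ToyNode("root", "abc"),), "abc": (ToyNode("leaf", "def"),)}
    tree2 = {None: (ToyNode("root", "xyz"),), "xyz": (ToyNode("leaf", "uvw"),)}

    # ...normalise to identical hashable structures.
    assert hash_node_map(tree1) == hash_node_map(tree2)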
