Skip to content

Commit

Permalink
Merge pull request #4 from connorjward/local-indices
Browse files Browse the repository at this point in the history
Add local indexing and enumerate
  • Loading branch information
connorjward authored Aug 17, 2023
2 parents b625ee3 + 67dd221 commit e1a8b9c
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 68 deletions.
34 changes: 32 additions & 2 deletions pyop3/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,9 @@ def __init__(
def __getitem__(self, indices):
    """Index this object via its axis-tree form.

    ``self`` is first converted with ``as_axis_tree`` and ``indices`` is then
    forwarded unchanged to the subscript operation of the result.
    """
    return as_axis_tree(self)[indices]

def __call__(self, *args):
    """Call the axis-tree form of this object.

    ``self`` is converted with ``as_axis_tree`` and the positional arguments
    are forwarded unchanged to the result's call operator.
    """
    return as_axis_tree(self)(*args)

def __str__(self) -> str:
    """Return ``ClassName([cpt, ...], label=...)``.

    The bracketed list is ``str()`` of each entry of ``self.components``,
    joined with ``", "``; the label is ``self.label``.
    """
    return f"{self.__class__.__name__}([{', '.join(str(cpt) for cpt in self.components)}], label={self.label})"

Expand All @@ -578,6 +581,30 @@ def count(self):
def index(self):
    """Return ``index()`` of the axis-tree form (``as_axis_tree(self)``) of this object."""
    return as_axis_tree(self).index()

def enumerate(self):
    """Return ``enumerate()`` of the axis-tree form (``as_axis_tree(self)``) of this object."""
    return as_axis_tree(self).enumerate()


# this is supposed to be used in place of an array to represent the offset
# of an axis at a given index
class CalledAxisTree:
    """An axis tree paired with the indices it was called with.

    Meant to be used in place of an array: it represents the offset of an
    axis at a given index.

    Parameters
    ----------
    axes
        The axis tree that was called. Must expose a ``datamap`` attribute.
    indices
        The indices the tree was called with. Must expose a ``datamap``
        attribute.
    """

    def __init__(self, axes, indices):
        self.axes = axes
        self.indices = indices

    # FIXME
    @property
    def name(self):
        """Placeholder name; every instance reports the same fixed string."""
        return "my_called_axis"

    @property
    def dtype(self):
        """The data type of the represented offset, always ``IntType``."""
        return IntType

    @functools.cached_property
    def datamap(self):
        """The axes' datamap merged (``|``) with the indices' datamap.

        Computed on first access and cached on the instance afterwards.
        """
        return self.axes.datamap | self.indices.datamap


class AxisTree(LabelledTree):
def __init__(
Expand All @@ -602,11 +629,14 @@ def __getitem__(self, indices):

return IndexedAxisTree(self, indices)

def __call__(self, *args):
    """Wrap this tree and the call arguments in a ``CalledAxisTree``.

    The positional arguments are unpacked into the ``CalledAxisTree``
    constructor after ``self``.
    """
    # NOTE(review): CalledAxisTree.__init__ takes exactly (axes, indices), so
    # this only succeeds when called with a single positional argument.
    return CalledAxisTree(self, *args)

def index(self):
# cyclic import
from pyop3.index import LoopIndex
from pyop3.index import GlobalLoopIndex

return LoopIndex(self)
return GlobalLoopIndex(self)

def enumerate(self):
# cyclic import
Expand Down
85 changes: 62 additions & 23 deletions pyop3/codegen/loopexpr2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@
from pyrsistent import pmap

from pyop3 import tlang, utils
from pyop3.axis import AffineLayout, Axis, AxisComponent, AxisTree, TabulatedLayout
from pyop3.axis import (
AffineLayout,
Axis,
AxisComponent,
AxisTree,
CalledAxisTree,
TabulatedLayout,
)
from pyop3.distarray import IndexedMultiArray, MultiArray
from pyop3.dtypes import IntType
from pyop3.index import (
AffineMapComponent,
AffineSliceComponent,
CalledMap,
GlobalLoopIndex,
Index,
Indexed,
IndexedAxisTree,
Expand Down Expand Up @@ -493,6 +501,7 @@ def _finalize_parse_loop_rec(
loop.index: (
new_path,
array_expr_per_leaf[current_axis.id, axcpt.label],
new_jnames, # "local index"
)
},
codegen_ctx,
Expand Down Expand Up @@ -542,7 +551,9 @@ def _(call: FunctionCall, loop_indices, ctx: LoopyCodegenContext) -> None:
temporaries.append((arg, indexed_temp, spec.access, shape))

# Register data
ctx.add_argument(arg.name, arg.dtype)
if not isinstance(arg, CalledAxisTree):
ctx.add_argument(arg.name, arg.dtype)

ctx.add_temporary(temporary.name, temporary.dtype, shape)

# subarrayref nonsense/magic
Expand Down Expand Up @@ -663,15 +674,18 @@ def build_assignment(

# unroll the index trees, this should be tidied up
array = assignment.array
itrees = []
while isinstance(array, Indexed):
itrees.insert(0, array.itree)
array = array.obj
if isinstance(array, CalledAxisTree):
itrees = [array.indices]
else:
itrees = []
while isinstance(array, Indexed):
itrees.insert(0, array.itree)
array = array.obj

for indices in itrees:
# get the right index tree given the loop context
loop_context = {}
for loop_index, (path, _) in loop_indices.items():
for loop_index, (path, _, _) in loop_indices.items():
loop_context[loop_index] = pmap(path)
loop_context = pmap(loop_context)
index_tree = indices[loop_context]
Expand Down Expand Up @@ -731,6 +745,10 @@ def _prepare_assignment_rec(
jnames: pmap,
ctx: LoopyCodegenContext,
) -> tuple[pmap, pmap]:
# catch empty axis trees
# if axes.is_empty:
# return pmap(), pmap({None: 0}), pmap({None: ()})

jnames_per_axcpt = {}
insns_per_leaf = {}
array_expr_per_leaf = {}
Expand All @@ -754,9 +772,15 @@ def _prepare_assignment_rec(
insns_per_leaf |= subinsns_per_leaf
array_expr_per_leaf |= subarray_expr_per_leaf
else:
insns, array_expr = _assignment_array_insn(
assignment, axes, new_path, new_jnames, ctx
)
if isinstance(assignment.array, CalledAxisTree):
# just the offset instructions here, no subscript
insns, array_expr = emit_assignment_insn(
axes, new_path, new_jnames, ctx
)
else:
insns, array_expr = _assignment_array_insn(
assignment, axes, new_path, new_jnames, ctx
)
insns_per_leaf[axis.id, axcpt.label] = insns
array_expr_per_leaf[axis.id, axcpt.label] = array_expr

Expand Down Expand Up @@ -899,8 +923,7 @@ def _parse_assignment_final(
insns_per_leaf,
ctx: LoopyCodegenContext,
):
# catch empty axes here
if not axes.root:
if not axes.root: # catch empty axes here
for insn in insns_per_leaf[None]:
ctx.add_assignment(*insn)
array_expr = array_expr_per_leaf[None]
Expand Down Expand Up @@ -962,11 +985,18 @@ def _parse_assignment_final_rec(
for insn in insns_per_leaf[axis.id, axcpt.label]:
ctx.add_assignment(*insn)
array_expr = array_expr_per_leaf[axis.id, axcpt.label]
temp_insns, temp_expr = _assignment_temp_insn(
assignment, new_path, new_jnames, ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
if isinstance(assignment.array, CalledAxisTree):
temp_insns, temp_expr = _assignment_temp_insn(
assignment, pmap(), pmap(), ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
else:
temp_insns, temp_expr = _assignment_temp_insn(
assignment, new_path, new_jnames, ctx
)
for insn in temp_insns:
ctx.add_assignment(*insn)
_shared_assignment_insn(assignment, array_expr, temp_expr, ctx)


Expand All @@ -983,7 +1013,13 @@ def _expand_index(index, *, loop_indices, codegen_ctx, **kwargs):

@_expand_index.register
def _(index: SplitLoopIndex, *, loop_indices, codegen_ctx, **kwargs):
path, jname_exprs = loop_indices[index.loop_index]
global_index = index.loop_index
if isinstance(global_index, GlobalLoopIndex):
path, jname_exprs, _ = loop_indices[global_index]
else:
assert isinstance(global_index, LocalLoopIndex)
global_index = global_index.global_index
path, _, jname_exprs = loop_indices[global_index]
insns = ()
# TODO namedtuple anyone?
return {None: (path, jname_exprs, insns)}, {"axes": AxisTree(), "jnames": pmap()}
Expand Down Expand Up @@ -1232,7 +1268,6 @@ def _assignment_array_insn(assignment, axes, path, jnames, ctx):
"""
offset_insns, array_offset = emit_assignment_insn(
assignment.array.name,
axes,
path,
jnames,
Expand All @@ -1252,7 +1287,6 @@ def _assignment_temp_insn(assignment, path, jnames, ctx):
"""
offset_insns, temp_offset = emit_assignment_insn(
assignment.temporary.name,
assignment.temporary.axes,
path,
jnames,
Expand Down Expand Up @@ -1290,7 +1324,6 @@ def _shared_assignment_insn(assignment, array_expr, temp_expr, ctx):


def emit_assignment_insn(
array_name,
axes,
path,
labels_to_jnames,
Expand Down Expand Up @@ -1469,7 +1502,10 @@ def collect_shape_index_callback(index, *args, **kwargs):

@collect_shape_index_callback.register
def _(loop_index: SplitLoopIndex, *, loop_indices, **kwargs):
path = loop_indices[loop_index.loop_index][0]
global_index = loop_index.loop_index
if isinstance(global_index, LocalLoopIndex):
global_index = global_index.global_index
path = loop_indices[global_index][0]
return {None: (path,)}, {"axes": AxisTree()}


Expand Down Expand Up @@ -1594,10 +1630,13 @@ def _indexed_axes(indexed, loop_indices):
axis tree. No instructions need be emitted.
"""
# offsets are always scalar
if isinstance(indexed, CalledAxisTree):
return AxisTree()

# get the right index tree given the loop context
loop_context = {}
for loop_index, (path, _) in loop_indices.items():
for loop_index, (path, _, _) in loop_indices.items():
loop_context[loop_index] = pmap(path)
loop_context = pmap(loop_context)

Expand Down
79 changes: 55 additions & 24 deletions pyop3/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,37 @@ def __init__(self, index_trees: pmap[pmap[LoopIndex, pmap[str, str]], IndexTree]
# this is terribly unclear
if not is_single_valued([set(key.keys()) for key in index_trees.keys()]):
raise ValueError("Loop contexts must contain the same loop indices")
self.index_trees = index_trees

new_index_trees = {}
for key, itree in index_trees.items():
new_key = {}
for loop_index, path in key.items():
if isinstance(loop_index, LocalLoopIndex):
loop_index = loop_index.global_index
new_key[loop_index] = path
new_index_trees[pmap(new_key)] = itree
self.index_trees = pmap(new_index_trees)

def __getitem__(self, loop_context):
    """Return the index tree stored for ``loop_context``.

    The lookup key is built from ``loop_context`` by replacing every
    ``LocalLoopIndex`` with its ``global_index`` and keeping only the loop
    indices present in ``self.loop_indices``. The resulting mapping is frozen
    with ``pmap`` and used to subscript ``self.index_trees``.
    """
    key = {}
    for loop_index, path in loop_context.items():
        # local indices are looked up under the global index they wrap
        if isinstance(loop_index, LocalLoopIndex):
            loop_index = loop_index.global_index
        # entries for loop indices not in self.loop_indices are left out
        if loop_index in self.loop_indices:
            key |= {loop_index: path}
    key = pmap(key)
    return self.index_trees[key]

@functools.cached_property
def loop_indices(self) -> frozenset[LoopIndex]:
# loop is used just for unpacking
for loop_context in self.index_trees.keys():
return frozenset(loop_context.keys())
indices = set()
for loop_index in loop_context.keys():
if isinstance(loop_index, LocalLoopIndex):
loop_index = loop_index.global_index
indices.add(loop_index)
return frozenset(indices)

@functools.cached_property
def datamap(self):
Expand Down Expand Up @@ -97,10 +114,11 @@ def __init__(self, obj, indices):
indices = as_split_index_tree(indices, axes=axes, loop_context=loop_ctx)
split_indices |= indices.index_trees
for loop_ctx_, itree in indices.index_trees.items():
# nasty hack because _indexed_axes currently expects a 2-tuple per loop index
# nasty hack because _indexed_axes currently expects a 3-tuple per loop index
assert set(loop_ctx.keys()) <= set(loop_ctx_.keys())
my_loop_context = {
idx: (path, "not used") for idx, path in loop_ctx_.items()
idx: (path, "not used", "not used")
for idx, path in loop_ctx_.items()
}
split_axes[loop_ctx_] = _indexed_axes((axes, indices), my_loop_context)

Expand Down Expand Up @@ -253,14 +271,37 @@ def __init__(self, map, from_index):
self.from_index = from_index


class LoopIndex:
class LoopIndex(abc.ABC):
pass


class GlobalLoopIndex(LoopIndex):
    """Loop index that holds the iteration set it was created from.

    ``target_paths`` and ``datamap`` are read straight from that iteration
    set.
    """

    def __init__(self, iterset):
        self.iterset = iterset

    @property
    def target_paths(self):
        """The ``target_paths`` of the stored iteration set."""
        return self.iterset.target_paths

    @property
    def datamap(self):
        """The ``datamap`` of the stored iteration set."""
        return self.iterset.datamap


class LocalLoopIndex(LoopIndex):
    """Class representing a 'local' index.

    Wraps another loop index, stored as ``global_index``, and forwards
    ``target_paths`` and ``datamap`` to it.
    """

    def __init__(self, loop_index: LoopIndex):
        self.global_index = loop_index

    @property
    def target_paths(self):
        """The ``target_paths`` of the wrapped global index."""
        return self.global_index.target_paths

    @property
    def datamap(self):
        """The ``datamap`` of the wrapped global index."""
        return self.global_index.datamap


class Index(LabelledNode):
@property
Expand Down Expand Up @@ -303,22 +344,6 @@ def target_paths(self) -> tuple[pmap]:
# )

@functools.cached_property
def datamap(self):
return self.loop_index.iterset.datamap


class LocalLoopIndex(Index):
"""Class representing a 'local' index."""

def __init__(self, loop_index: LoopIndex, **kwargs):
super().__init__(loop_index.component_labels, **kwargs)
self.loop_index = loop_index

@property
def target_paths(self):
return self.loop_index.target_paths

@property
def datamap(self):
return self.loop_index.datamap

Expand Down Expand Up @@ -403,8 +428,11 @@ def name(self):

class EnumeratedLoopIndex:
def __init__(self, iterset: AxisTree):
self.index = LoopIndex(iterset)
self.count = LocalLoopIndex(self.index)
global_index = GlobalLoopIndex(iterset)
local_index = LocalLoopIndex(global_index)

self.global_index = global_index
self.local_index = local_index


# it is probably a better pattern to give axis trees a "parent" option
Expand All @@ -423,7 +451,10 @@ def target_paths(self):
return tuple(paths.values())

def index(self):
return LoopIndex(self)
return GlobalLoopIndex(self)

def enumerate(self):
    """Return an ``EnumeratedLoopIndex`` constructed with this object as its iteration set."""
    return EnumeratedLoopIndex(self)

@functools.cached_property
def datamap(self):
Expand Down
Loading

0 comments on commit e1a8b9c

Please sign in to comment.