Skip to content

Commit

Permalink
Lots of tests passing with better loop context logic
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Sep 27, 2023
1 parent 7eb9b8d commit 2e1a1ed
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 113 deletions.
32 changes: 14 additions & 18 deletions pyop3/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def layout_exprs_per_component(self):
pass


class ContextSensitive(pytools.ImmutableRecord, abc.ABC):
class ContextSensitive(abc.ABC):
# """Container of `IndexTree`s distinguished by outer loop information.
#
# This class is required because multi-component outer loops can lead to
Expand All @@ -119,39 +119,34 @@ class ContextSensitive(pytools.ImmutableRecord, abc.ABC):
# """
#
# def __init__(self, index_trees: pmap[pmap[LoopIndex, pmap[str, str]], IndexTree]):
fields = {"values"} # bad name

def __init__(self, values):
super().__init__()
# this is terribly unclear
if not is_single_valued([set(key.keys()) for key in values.keys()]):
raise ValueError("Loop contexts must contain the same loop indices")

assert all(isinstance(v, ContextFree) for v in values.values())

self.values = pmap(values)

@functools.cached_property
def keys(self):
# loop is used just for unpacking
for context in self.values.keys():
for context in self.context_map.keys():
indices = set()
for loop_index in context.keys():
indices.add(loop_index)
return frozenset(indices)

@property
@abc.abstractmethod
def context_map(self):
pass

def with_context(self, context):
key = {}
for loop_index, path in context.items():
if loop_index in self.keys:
key |= {loop_index: path}
key = pmap(key)
return self.values[key]
return self.context_map[key]


class ContextFree(ContextSensitive, abc.ABC):
def with_context(self, context):
return self
@property
def context_map(self):
return pmap({pmap(): self})


class ExpressionEvaluator(pym.mapper.evaluator.EvaluationMapper):
Expand Down Expand Up @@ -778,19 +773,20 @@ def __init__(

def __getitem__(self, indices):
if indices is Ellipsis:
raise NotImplementedError("TODO needs to return a full slice, not self")
return self
# FIXME
from pyop3.distarray.multiarray import IndexExpressionReplacer
from pyop3.index import (
IndexedAxisTree,
as_index_forest,
collect_loop_context,
collect_loop_contexts,
index_axes,
)

# FIXME I have a weird double loop here over loop contexts
axis_trees = {}
loop_contexts = collect_loop_context(indices)
loop_contexts = collect_loop_contexts(indices)
if not loop_contexts:
loop_contexts = [pmap()]
for loop_context in loop_contexts:
Expand Down
4 changes: 2 additions & 2 deletions pyop3/codegen/loopexpr2loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def _(
codegen_context: LoopyCodegenContext,
) -> None:
loop_context = {}
for loop_index, (path, _) in loop_indices.items():
loop_context[loop_index] = pmap(path)
for loop_index, (source_path, target_path, _, _) in loop_indices.items():
loop_context[loop_index] = source_path, target_path
loop_context = pmap(loop_context)

iterset = loop.index.iterset.with_context(loop_context)
Expand Down
Loading

0 comments on commit 2e1a1ed

Please sign in to comment.