Codegen fixups
connorjward committed Oct 3, 2024
1 parent 26c301c commit 86cbfc4
Showing 4 changed files with 120 additions and 40 deletions.
2 changes: 2 additions & 0 deletions pyop3/array/petsc.py
@@ -451,7 +451,9 @@ def maps(self):
dropped_ckeys = set()

# TODO: are dropped_rkeys and dropped_ckeys still needed?
# FIXME: this whole thing falls apart if we have multiple loop contexts
loop_index = just_one(self.block_raxes.outer_loops)

iterset = AxisTree(loop_index.iterset.node_map)

rmap_axes = iterset.add_subtree(self.block_raxes, *iterset.leaf)
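
For context on the FIXME above: `just_one` (from pyop3.utils) is assumed to return the sole element of an iterable and raise otherwise, which is why the code breaks down as soon as there is more than one loop context. A minimal sketch of that assumed behaviour:

```python
# Hypothetical sketch of just_one (assumption, not the actual pyop3.utils code):
# return the single element of an iterable, erroring on zero or many.
def just_one(iterable):
    items = tuple(iterable)
    if len(items) != 1:
        raise ValueError(f"expected exactly one item, found {len(items)}")
    return items[0]

# With several loop contexts, block_raxes.outer_loops would hold more than one
# entry and the call in the hunk above would raise -- hence the FIXME.
```
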
76 changes: 50 additions & 26 deletions pyop3/ir/lower.py
Expand Up @@ -43,6 +43,7 @@
ContextAwareLoop,
DummyKernelArgument,
Loop,
LoopList,
PetscMatAdd,
PetscMatInstruction,
PetscMatLoad,
@@ -307,6 +308,7 @@ def set_temporary_shapes(self, shapes):

class CodegenResult:
def __init__(self, expr, ir, arg_replace_map, *, compiler_parameters):
# NOTE: should this be iterable?
self.expr = as_tuple(expr)
self.ir = ir
self.arg_replace_map = arg_replace_map
@@ -315,7 +317,7 @@ def __init__(self, expr, ir, arg_replace_map, *, compiler_parameters):

@cached_property
def datamap(self):
return merge_dicts(e.datamap for e in self.expr)
return merge_dicts(e.preprocessed.datamap for e in self.expr)

def __call__(self, **kwargs):
data_args = []
@@ -399,35 +401,41 @@ def parse_compiler_parameters(compiler_parameters) -> CompilerParameters:
def compile(expr: Instruction, compiler_parameters=None):
compiler_parameters = parse_compiler_parameters(compiler_parameters)

# preprocess expr before lowering
from pyop3.transform import expand_implicit_pack_unpack, expand_loop_contexts

function_name = expr.name

cs_expr = expand_loop_contexts(expr)
if isinstance(expr, LoopList):
cs_expr = expr.loops
else:
assert isinstance(expr, Loop), "other types not handled yet"
cs_expr = (expr,)

ctx = LoopyCodegenContext()
for context, ex in cs_expr:
ex = expand_implicit_pack_unpack(ex)
# NOTE: so I think LoopCollection is a better abstraction here - don't want to be
# explicitly dealing with contexts at this point. Can always sniff them out again.
# for context, ex in cs_expr:
for ex in cs_expr:
# ex = expand_implicit_pack_unpack(ex)

# add external loop indices as kernel arguments
# FIXME: removed because cs_expr needs to sniff the context now
loop_indices = {}
for index, (path, _) in context.items():
if len(path) > 1:
raise NotImplementedError("needs to be sorted")

# dummy = HierarchicalArray(index.iterset, data=NullBuffer(IntType))
dummy = HierarchicalArray(Axis(1), dtype=IntType)
# this is dreadful, pass an integer array instead
ctx.add_argument(dummy)
myname = ctx.actual_to_kernel_rename_map[dummy.name]
replace_map = {
axis: pym.subscript(pym.var(myname), (i,))
for i, axis in enumerate(path.keys())
}
# FIXME currently assume that source and target exprs are the same, they are not!
loop_indices[index] = (replace_map, replace_map)

for e in as_tuple(ex):
# for index, (path, _) in context.items():
# if len(path) > 1:
# raise NotImplementedError("needs to be sorted")
#
# # dummy = HierarchicalArray(index.iterset, data=NullBuffer(IntType))
# dummy = HierarchicalArray(Axis(1), dtype=IntType)
# # this is dreadful, pass an integer array instead
# ctx.add_argument(dummy)
# myname = ctx.actual_to_kernel_rename_map[dummy.name]
# replace_map = {
# axis: pym.subscript(pym.var(myname), (i,))
# for i, axis in enumerate(path.keys())
# }
# # FIXME currently assume that source and target exprs are the same, they are not!
# loop_indices[index] = (replace_map, replace_map)

for e in as_tuple(ex): # TODO: get rid of this loop
# context manager?
ctx.set_temporary_shapes(_collect_temporary_shapes(e))
_compile(e, loop_indices, ctx)
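
To summarise the new control flow at the top of `compile`: a `LoopList` (produced by `expand_loop_contexts`; see pyop3/lang.py and pyop3/transform.py below) already carries one context-specialised loop per loop context, while a bare `Loop` is wrapped in a one-element tuple. A simplified sketch, not verbatim code:

```python
# Simplified sketch of the dispatch in compile() after this change.
if isinstance(expr, LoopList):
    cs_expr = expr.loops            # one Loop per loop context, already expanded
elif isinstance(expr, Loop):
    cs_expr = (expr,)               # single, un-expanded loop
else:
    raise NotImplementedError("other Instruction types not handled yet")

for ex in cs_expr:
    ...  # each loop is lowered into the same LoopyCodegenContext
```
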
@@ -497,6 +505,7 @@ def _collect_temporary_shapes(expr):
raise TypeError(f"No handler defined for {type(expr).__name__}")


# TODO: get rid of this type
@_collect_temporary_shapes.register
def _(expr: ContextAwareLoop):
shapes = {}
@@ -512,6 +521,20 @@ def _(expr: ContextAwareLoop):
return shapes


@_collect_temporary_shapes.register
def _(expr: Loop):
shapes = {}
for stmt in expr.statements:
for temp, shape in _collect_temporary_shapes(stmt).items():
if shape is None:
continue
if temp in shapes:
assert shapes[temp] == shape
else:
shapes[temp] = shape
return shapes


@_collect_temporary_shapes.register
def _(expr: Assignment):
return pmap()
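
The intent of the merge in the new `Loop` handler above, with made-up temporaries and shapes:

```python
# Illustrative only: shapes gathered from two statements of the same loop.
shapes_a = {"t0": (3,)}
shapes_b = {"t0": (3,), "t1": (4, 2)}

merged = {}
for shapes in (shapes_a, shapes_b):
    for temp, shape in shapes.items():
        if temp in merged:
            # a temporary shared between statements must agree on its shape
            assert merged[temp] == shape
        else:
            merged[temp] = shape

assert merged == {"t0": (3,), "t1": (4, 2)}
```
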
@@ -539,9 +562,10 @@ def _compile(expr: Any, loop_indices, ctx: LoopyCodegenContext) -> None:
raise TypeError(f"No handler defined for {type(expr).__name__}")


@_compile.register
@_compile.register(ContextAwareLoop) # remove
@_compile.register(Loop)
def _(
loop: ContextAwareLoop,
loop,
loop_indices,
codegen_context: LoopyCodegenContext,
) -> None:
65 changes: 55 additions & 10 deletions pyop3/lang.py
@@ -62,6 +62,12 @@ class Intent(enum.Enum):
NA = Intent.NA


class ExpressionState(enum.Enum):
"""Enum indicating the state of an expression (preprocessed or not)."""
INITIAL = "initial"
PREPROCESSED = "preprocessed"


# TODO: This exception is not actually ever raised. We should check the
# intents of the kernel arguments and complain if something illegal is
# happening.
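
A small note on the `ExpressionState` enum added above: because the members carry string values, standard enum behaviour lets call sites pass either a member or its string and normalise through the constructor, which is what `LoopList.__init__` below and the `state="preprocessed"` argument in transform.py rely on:

```python
# Standard enum round-trip behaviour (no pyop3-specific assumptions here).
assert ExpressionState("preprocessed") is ExpressionState.PREPROCESSED
assert ExpressionState(ExpressionState.INITIAL) is ExpressionState.INITIAL
```
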
@@ -101,7 +107,32 @@ def kernel_dtype(self):


class Instruction(UniqueRecord, abc.ABC):
pass
fields = UniqueRecord.fields | {"state"}

@cached_property
def preprocessed(self):
from pyop3.transform import expand_implicit_pack_unpack, expand_loop_contexts

if self.state == ExpressionState.PREPROCESSED:
return self
else:
insn = self
insn = expand_loop_contexts(insn)
insn = expand_implicit_pack_unpack(insn)
# TODO: marking things as preprocessed should be a separate stage; currently it happens inside expand_loop_contexts, which is a bug
return insn

@property
def is_preprocessed(self):
return self.state == ExpressionState.PREPROCESSED

@cached_property
def loopy_code(self):
from pyop3.ir.lower import compile

return compile(self.preprocessed)




class ContextAwareInstruction(Instruction):
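
The `preprocessed`/`loopy_code` machinery above amounts to idempotent, state-guarded, cached preprocessing. A standalone distillation of the pattern (illustration only, not pyop3 code):

```python
import enum
from functools import cached_property

class State(enum.Enum):
    INITIAL = "initial"
    PREPROCESSED = "preprocessed"

class Expr:
    def __init__(self, state=State.INITIAL):
        self.state = State(state)

    @cached_property
    def preprocessed(self):
        # Idempotent: an already-preprocessed expression is returned as-is;
        # otherwise a stand-in for the real transform pipeline runs once and
        # the result is cached by cached_property.
        if self.state is State.PREPROCESSED:
            return self
        return Expr(State.PREPROCESSED)

e = Expr(State.PREPROCESSED)
assert e.preprocessed is e                          # no-op when already preprocessed
assert Expr().preprocessed.state is State.PREPROCESSED
```
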
@@ -132,22 +163,25 @@ def __init__(
*,
name: str = _DEFAULT_LOOP_NAME,
compiler_parameters=None,
state: ExpressionState = ExpressionState.INITIAL,
**kwargs,
):
super().__init__(**kwargs)
self.index = index
self.statements = as_tuple(statements)
self.name = name
self.compiler_parameters = compiler_parameters
self.state = state

def __call__(self, **kwargs):
# TODO just parse into ContextAwareLoop and call that
from pyop3.ir.lower import compile
from pyop3.itree.tree import partition_iterset

code = compile(self, compiler_parameters=self.compiler_parameters)
code = compile(self.preprocessed, compiler_parameters=self.compiler_parameters)

if self.is_parallel:
if False:
# if self.is_parallel:
# FIXME: The partitioning code does not seem to always run properly
# so for now do all the transfers in advance.
# interleave computation and communication
@@ -205,12 +239,6 @@ def __call__(self, **kwargs):
with PETSc.Log.Event(f"compute_{self.name}_serial"):
code(**kwargs)

@cached_property
def loopy_code(self):
from pyop3.ir.lower import compile

return compile(self)

@cached_property
def is_parallel(self):
from pyop3.buffer import DistributedBuffer
@@ -355,7 +383,10 @@ def _init_nil():

@cached_property
def datamap(self):
return self.index.datamap | merge_dicts(stmt.datamap for stmt in self.statements)
if self.is_preprocessed:
return self.index.datamap | merge_dicts(stmt.datamap for stmt in self.statements)
else:
return self.preprocessed.datamap


class ContextAwareLoop(ContextAwareInstruction):
@@ -379,6 +410,20 @@ def loopy_code(self):
return compile(self)


class LoopList(Instruction):
fields = Instruction.fields | {"loops"}

def __init__(self, loops, *, name=_DEFAULT_LOOP_NAME, state=ExpressionState.INITIAL, **kwargs):
super().__init__(**kwargs)
self.loops = loops
self.name = name
self.state = ExpressionState(state)

@cached_property
def datamap(self):
return merge_dicts(l.datamap for l in self.loops)


# TODO singledispatch
# TODO perhaps this is simply "has non unit stride"?
def _has_nontrivial_stencil(array):
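
A hedged usage sketch of the new `LoopList` container (the individual loops are placeholders; only the constructor signature and the datamap merge come from this diff):

```python
# Hypothetical: loops produced by expand_loop_contexts, one per loop context.
# loop_list = LoopList((loop_ctx0, loop_ctx1), name="form0", state="preprocessed")
# loop_list.is_preprocessed   # True, inherited from Instruction
# loop_list.datamap           # union of loop_ctx0.datamap and loop_ctx1.datamap
```
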
17 changes: 13 additions & 4 deletions pyop3/transform.py
@@ -25,6 +25,7 @@
DummyKernelArgument,
Instruction,
Loop,
LoopList,
Pack,
PetscMatAdd,
PetscMatInstruction,
@@ -33,7 +34,7 @@
ReplaceAssignment,
Terminal,
)
from pyop3.utils import UniqueNameGenerator, checked_zip, just_one
from pyop3.utils import UniqueNameGenerator, checked_zip, just_one, single_valued


# TODO Is this generic for other parsers/transformers? Esp. lower.py
@@ -127,12 +128,16 @@ def _(self, loop: Loop, *, context):
statements[source_path].append(mystmt)

# FIXME this does not propagate inner outer contexts
loop = ContextAwareLoop(
# NOTE: also I think this is redundant, just use a Loop!!!
csloop = ContextAwareLoop(
loop.index.copy(iterset=cf_iterset),
statements,
)
loops.append((octx, loop))
return tuple(loops)
# NOTE: outer context now needs sniffing out, makes the objects nicer
# loops.append((octx, loop))
loops.append(csloop)

return LoopList(loops, name=loop.name, state="preprocessed")

@_apply.register
def _(self, terminal: CalledFunction, *, context):
@@ -247,6 +252,10 @@ def _(self, loop: ContextAwareLoop):
),
)

@_apply.register
def _(self, loop_list: LoopList):
return loop_list.copy(loops=[loop_ for loop in loop_list.loops for loop_ in self._apply(loop)])

# TODO: Should be the same as Assignment
@_apply.register
def _(self, assignment: PetscMatInstruction):
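
The nested comprehension in the `LoopList` handler above flattens one level: `self._apply(loop)` is iterated over, so each input loop may yield several replacement loops while the result stays a flat sequence. A generic illustration of that pattern with made-up values (a stand-in for `self._apply`):

```python
# Stand-in for self._apply: each input expands to a list of outputs.
expand = lambda loop: [loop, loop + 100]

loops = [1, 2]
flat = [new for loop in loops for new in expand(loop)]
assert flat == [1, 101, 2, 102]
```
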
