Skip to content

Commit

Permalink
Include center (#66)
Browse files Browse the repository at this point in the history
* new dawn proto

* centered reductions basically working

* adding tests, checks enforcing that include center is consistent

* make linter work

* remove pprinter

* Applied black to all files

* PR review

Co-authored-by: Ben Weber <[email protected]>
  • Loading branch information
mroethlin and BenWeber42 authored Jan 18, 2021
1 parent 1fac2ea commit d6bbfaf
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 41 deletions.
74 changes: 48 additions & 26 deletions dusk/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ def temporary_field_declaration(self, name: str, field_type: expr):
def add_field_declaration(
self, name: str, field_type: expr, is_temporary: bool = False
):
field_type, hindex, vindex = self.field_type(field_type)

field_type, (include_center, hindex), vindex = self.field_type(field_type)
assert field_type in {"Field", "IndexField"}
DuskFieldType = DuskField if field_type == "Field" else DuskIndexField

if hindex is not None:
dimensions = make_field_dimensions_unstructured(hindex, vindex)
dimensions = make_field_dimensions_unstructured(
hindex, vindex, include_center
)
else:
dimensions = make_field_dimensions_vertical()

Expand Down Expand Up @@ -196,22 +197,39 @@ def add_field_declaration(
def field_type(self, field_type: str, hindex: expr = None, vindex: str = None):
return (
field_type,
self.location_chain(hindex) if hindex is not None else None,
self.location_chain(hindex) if hindex is not None else (False, None),
1 if vindex is not None else 0,
)

@transform(
OneOf(
name(Capture(str).append("locations")),
Compare(
left=name(Capture(str).append("locations")),
left=OneOf(
name(Capture(str).append("locations")),
BinOp(
# TODO: hardcoded strings
left=name(Capture("Origin").to("include_center")),
op=Add,
right=name(Capture(str).append("locations")),
),
),
ops=Repeat(Gt),
comparators=Repeat(name(Capture(str).append("locations"))),
),
)
)
def location_chain(self, locations: t.List):
return [self.location_type(location) for location in locations]
def location_chain(
self, locations: t.List, include_center: t.Optional[t.Literal["Origin"]] = None
):
does_include_center = include_center is not None
locations = [self.location_type(location) for location in locations]

if does_include_center and not self.ctx.location.is_ambiguous(locations):
raise DuskSyntaxError(
f"including the center is only allowed if start equals end location of the neighbor chain!"
)
return does_include_center, locations

@transform(Capture(str).to("name"))
def location_type(self, name: str):
Expand Down Expand Up @@ -406,12 +424,12 @@ def vertical_interval_bound(self, bound):
)
)
def loop_stmt(self, neighborhood, body: t.List):
neighborhood = self.location_chain(neighborhood)
include_center, neighborhood = self.location_chain(neighborhood)

with self.ctx.location.loop_stmt(neighborhood):
with self.ctx.location.loop_stmt(neighborhood, include_center):
body = self.statements(body)

return make_loop_stmt(body, neighborhood)
return make_loop_stmt(body, neighborhood, include_center)

@transform(Capture(expr).to("expr"))
def expression(self, expr: expr):
Expand Down Expand Up @@ -509,7 +527,9 @@ def field_index(self, field: DuskField, vindex=None, hindex=None):
voffset, vbase = (
self.relative_vertical_offset(vindex) if vindex is not None else (0, None)
)
hindex = self.location_chain(hindex) if hindex is not None else None
include_center, hindex = (
self.location_chain(hindex) if hindex is not None else (False, None)
)

if not self.ctx.location.in_neighbor_iteration:
if hindex is not None:
Expand All @@ -518,8 +538,7 @@ def field_index(self, field: DuskField, vindex=None, hindex=None):
"outside of neighbor iteration!"
)
return make_unstructured_offset(False), voffset, vbase

neighbor_iteration = self.ctx.location.current_neighbor_iteration
neighbor_chain = self.ctx.location.current_neighbor_iteration.chain
field_dimension = self.ctx.location.get_field_dimension(field.sir)

if hindex is not None and not self.ctx.location.is_dense(field_dimension):
Expand All @@ -535,16 +554,14 @@ def field_index(self, field: DuskField, vindex=None, hindex=None):

if hindex is None:
if self.ctx.location.is_dense(field_dimension):
if self.ctx.location.is_ambiguous(neighbor_iteration):
if self.ctx.location.is_ambiguous(neighbor_chain):
raise DuskSyntaxError(
f"Field '{field.sir.name}' requires a horizontal index "
"inside of ambiguous neighbor iteration!"
)

return (
make_unstructured_offset(
field_dimension[0] == neighbor_iteration[-1]
),
make_unstructured_offset(field_dimension[0] == neighbor_chain[-1]),
voffset,
vbase,
)
Expand All @@ -553,14 +570,23 @@ def field_index(self, field: DuskField, vindex=None, hindex=None):

# TODO: check if `hindex` is valid for this field's location type

if (
self.ctx.location.current_neighbor_iteration.include_center
!= include_center
):
raise DuskSyntaxError(
f"Invalid horizontal offset for field '{field.sir.name}'! "
"inconsistent center inclusion"
)

if len(hindex) == 1:
if neighbor_iteration[0] != hindex[0]:
if neighbor_chain[0] != hindex[0]:
raise DuskSyntaxError(
f"Invalid horizontal offset for field '{field.sir.name}'!"
)
return make_unstructured_offset(False), voffset, vbase

if hindex != neighbor_iteration:
if hindex != neighbor_chain:
raise DuskSyntaxError(
f"Invalid horizontal offset for field '{field.sir.name}'!"
)
Expand Down Expand Up @@ -835,8 +861,8 @@ def reduction(
if 0 < len(wrong_kwargs):
raise DuskSyntaxError(f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

neighborhood = self.location_chain(neighborhood)
with self.ctx.location.reduction(neighborhood):
include_center, neighborhood = self.location_chain(neighborhood)
with self.ctx.location.reduction(neighborhood, include_center):
expr = self.expression(expr)

op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
Expand Down Expand Up @@ -871,9 +897,5 @@ def reduction(
weights = [self.expression(weight) for weight in kwargs["weights"].elts]

return make_reduction_over_neighbor_expr(
op,
expr,
init,
neighborhood,
weights,
op, expr, init, neighborhood, weights, include_center
)
5 changes: 5 additions & 0 deletions dusk/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"Edge",
"Cell",
"Vertex",
"Origin",
"K",
"Field",
"IndexField",
Expand Down Expand Up @@ -39,6 +40,10 @@ class Vertex(metaclass=internal.LocationType):
pass


class Origin(metaclass=internal.LocationType):
pass


class K:
pass

Expand Down
5 changes: 4 additions & 1 deletion dusk/script/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ def __new__(cls, name, bases, dict):
return super().__new__(cls, name, bases, dict)

def __gt__(cls, other_cls):
pass
return cls

def __add__(cls, other_cls):
return cls


class Slicable:
Expand Down
19 changes: 10 additions & 9 deletions dusk/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from contextlib import contextmanager
from itertools import chain
from collections import namedtuple

from dawn4py.serialization import SIR as sir

Expand Down Expand Up @@ -93,15 +94,15 @@ def new_scope(self):


LocationTypeValue = NewType("LocationTypeValue", int)
LocationChain = List[LocationTypeValue]
IterationSpace = namedtuple("IterationSpace", "chain, include_center")


class LocationHelper:

in_vertical_region: bool
in_loop_stmt: bool
in_reduction: bool
neighbor_iterations: List[LocationChain]
neighbor_iterations: List[IterationSpace]

@staticmethod
def is_dense(location_chain: LocationChain) -> bool:
Expand All @@ -128,7 +129,7 @@ def __init__(self):
self.neighbor_iterations = []

@property
def current_neighbor_iteration(self) -> LocationChain:
def current_neighbor_iteration(self) -> IterationSpace:
assert self.in_neighbor_iteration
return self.neighbor_iterations[-1]

Expand All @@ -149,7 +150,7 @@ def vertical_region(self):
self.in_vertical_region = False

@contextmanager
def _neighbor_iteration(self, location_chain: LocationChain):
def _neighbor_iteration(self, location_chain: LocationChain, include_center: bool):

if not self.in_vertical_region:
raise DuskSyntaxError(
Expand All @@ -162,27 +163,27 @@ def _neighbor_iteration(self, location_chain: LocationChain):
"length longer than 1!"
)

self.neighbor_iterations.append(location_chain)
self.neighbor_iterations.append(IterationSpace(location_chain, include_center))
yield
self.neighbor_iterations.pop()

@contextmanager
def loop_stmt(self, location_chain: LocationChain):
def loop_stmt(self, location_chain: LocationChain, include_center: bool):

if self.in_loop_stmt:
raise DuskSyntaxError("Nested loop statements aren't allowed!")
if self.in_reduction:
raise DuskSyntaxError("Loop statements can't occur inside reductions!")

self.in_loop_stmt = True
with self._neighbor_iteration(location_chain):
with self._neighbor_iteration(location_chain, include_center):
yield
self.in_loop_stmt = False

@contextmanager
def reduction(self, location_chain: LocationChain):
def reduction(self, location_chain: LocationChain, include_center: bool):
self.in_reduction = True
with self._neighbor_iteration(location_chain):
with self._neighbor_iteration(location_chain, include_center):
yield
self.in_reduction = False

Expand Down
1 change: 1 addition & 0 deletions tests/examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from laplacian_fvm import laplacian_fvm
from interpolation_sph import interpolation_sph


def test_examples():
validate(pyast_to_sir(callable_to_pyast(laplacian_fd)))
validate(pyast_to_sir(callable_to_pyast(laplacian_fvm)))
Expand Down
11 changes: 10 additions & 1 deletion tests/stencils/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
def test_reduce():
validate(pyast_to_sir(callable_to_pyast(various_reductions)))
validate(pyast_to_sir(callable_to_pyast(kw_args)))
validate(pyast_to_sir(callable_to_pyast(reductions_with_center)))


@stencil
Expand Down Expand Up @@ -79,4 +80,12 @@ def kw_args(
b = reduce_over(Edge > Cell, d * 3, mul, weights=[-1, 1], init=1.0)
a = sum_over(Edge > Cell, c * 3, init=10.0, weights=[-1, 1])
b = min_over(Edge > Cell, c * 3, weights=[-1, 1], init=-100)
a = max_over(Edge > Cell, d * 3, init=723, weights=[-1, 1])
a = max_over(Edge > Cell, d * 3, init=723, weights=[-1, 1])


@stencil
def reductions_with_center(
a: Field[Edge], b: Field[Origin + Edge > Cell > Edge], c: Field[Edge]
):
with levels_downward:
a = sum_over(Origin + Edge > Cell > Edge, b * c[Origin + Edge > Cell > Edge])
11 changes: 11 additions & 0 deletions tests/stencils/test_sparse_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_sparse_fill():
validate(pyast_to_sir(callable_to_pyast(longer_fills)))
validate(pyast_to_sir(callable_to_pyast(fill_with_reduction)))
validate(pyast_to_sir(callable_to_pyast(ambiguous_fill)))
validate(pyast_to_sir(callable_to_pyast(fill_with_center)))


@stencil
Expand Down Expand Up @@ -98,3 +99,13 @@ def ambiguous_fill(
with levels_downward:
with sparse[Edge > Vertex > Edge]:
sparse2 = edge2[Edge] - 4 * edge1[Edge > Vertex > Edge]


@stencil
def fill_with_center(
sparse1: Field[Origin + Edge > Cell > Edge],
edge: Field[Edge],
):
with levels_downward:
with sparse[Origin + Edge > Cell > Edge]:
sparse1 = edge[Origin + Edge > Cell > Edge]
8 changes: 4 additions & 4 deletions tests/stencils/test_vertical_index_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def various_expression(

edge_3d_field2 = sum_over(
Cell > Vertex > Cell > Edge,
tan(
sparse_3d_field3[sparse_3d_index_field3 + 3]
)
tan(sparse_3d_field3[sparse_3d_index_field3 + 3])
/ edge_3d_field1[edge_3d_index_field + 1]
+ reduce_over(
Edge > Cell > Vertex > Cell,
Expand Down Expand Up @@ -237,7 +235,9 @@ def sparse_index_fields(
sparse_3d_field1 = edge_2d_field[sparse_3d_index_field1]

edge_2d_field = sum_over(
Edge > Vertex > Cell, sparse_3d_field2[sparse_3d_index_field2], init=-10,
Edge > Vertex > Cell,
sparse_3d_field2[sparse_3d_index_field2],
init=-10,
)

with sparse[Edge > Vertex > Cell]:
Expand Down

0 comments on commit d6bbfaf

Please sign in to comment.