Skip to content

Commit

Permalink
Merge pull request #1369 from spcl/sym-attr
Browse files Browse the repository at this point in the history
Support attributes in symbolic expressions
  • Loading branch information
alexnick83 authored Sep 28, 2023
2 parents a582261 + 306d7a9 commit 3e73304
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 62 deletions.
8 changes: 5 additions & 3 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ class ForScope(ControlFlow):
init_edges: List[InterstateEdge] #: All initialization edges

def as_cpp(self, codegen, symbols) -> str:

sdfg = self.guard.parent

# Initialize to either "int i = 0" or "i = 0" depending on whether
# the type has been defined
defined_vars = codegen.dispatcher.defined_vars
Expand All @@ -369,9 +372,8 @@ def as_cpp(self, codegen, symbols) -> str:
init = self.itervar
else:
init = f'{symbols[self.itervar]} {self.itervar}'
init += ' = ' + self.init

sdfg = self.guard.parent
init += ' = ' + unparse_interstate_edge(self.init_edges[0].data.assignments[self.itervar],
sdfg, codegen=codegen)

preinit = ''
if self.init_edges:
Expand Down
8 changes: 6 additions & 2 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode=None) -> str:
from dace.codegen.targets.framecode import DaCeCodeGenerator # Avoid import loop
framecode: DaCeCodeGenerator = framecode

if '.' in name:
root = name.split('.')[0]
if root in sdfg.arrays and isinstance(sdfg.arrays[root], data.Structure):
name = name.replace('.', '->')

# Special case: If memory is persistent and defined in this SDFG, add state
# struct to name
if (desc.transient and desc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External)):
Expand Down Expand Up @@ -992,8 +997,7 @@ def _Name(self, t: ast.Name):
if t.id not in self.sdfg.arrays:
return super()._Name(t)

# Replace values with their code-generated names (for example,
# persistent arrays)
# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[t.id]
self.write(ptr(t.id, desc, self.sdfg, self.codegen))

Expand Down
18 changes: 6 additions & 12 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def __init__(self, frame_codegen, sdfg):
def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
for k, v in struct.members.items():
if isinstance(v, data.Structure):
_visit_structure(v, args, f'{prefix}.{k}')
_visit_structure(v, args, f'{prefix}->{k}')
elif isinstance(v, data.StructArray):
_visit_structure(v.stype, args, f'{prefix}.{k}')
_visit_structure(v.stype, args, f'{prefix}->{k}')
elif isinstance(v, data.Data):
args[f'{prefix}.{k}'] = v
args[f'{prefix}->{k}'] = v

# Keeps track of generated connectors, so we know how to access them in nested scopes
arglist = dict(self._frame.arglist)
Expand Down Expand Up @@ -221,8 +221,8 @@ def allocate_view(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.A
if isinstance(v, data.Data):
ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype
defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer
self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef)
self._dispatcher.defined_vars.add(f"{name}.{k}", defined_type, ctypedef)
self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef)
self._dispatcher.defined_vars.add(f"{name}->{k}", defined_type, ctypedef)
# TODO: Find a better way to do this (the issue is with pointers of pointers)
if atype.endswith('*'):
atype = atype[:-1]
Expand Down Expand Up @@ -299,9 +299,6 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d
name = node.data
alloc_name = cpp.ptr(name, nodedesc, sdfg, self._frame)
name = alloc_name
# NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and
# NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows.
alloc_name = alloc_name.replace('.', '->')

if nodedesc.transient is False:
return
Expand Down Expand Up @@ -331,7 +328,7 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d
if isinstance(v, data.Data):
ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype
defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer
self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef)
self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef)
self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream,
declaration_stream, allocation_stream)
return
Expand Down Expand Up @@ -1184,9 +1181,6 @@ def memlet_definition(self,
if not types:
types = self._dispatcher.defined_vars.get(ptr, is_global=True)
var_type, ctypedef = types
# NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and
# NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows.
ptr = ptr.replace('.', '->')

if fpga.is_fpga_array(desc):
decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces")
Expand Down
3 changes: 3 additions & 0 deletions dace/codegen/tools/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ def _infer_dtype(t: Union[ast.Name, ast.Attribute]):

def _Attribute(t, symbols, inferred_symbols):
inferred_type = _dispatch(t.value, symbols, inferred_symbols)
if (isinstance(inferred_type, dtypes.pointer) and isinstance(inferred_type.base_type, dtypes.struct) and
t.attr in inferred_type.base_type.fields):
return inferred_type.base_type.fields[t.attr]
return inferred_type


Expand Down
112 changes: 67 additions & 45 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ def eval(cls, x, y):
def _eval_is_boolean(self):
return True


class IfExpr(sympy.Function):

@classmethod
Expand Down Expand Up @@ -724,6 +725,19 @@ class IsNot(sympy.Function):
pass


class Attr(sympy.Function):
"""
Represents a get-attribute call on a function, equivalent to ``a.b`` in Python.
"""

@property
def free_symbols(self):
return {sympy.Symbol(str(self))}

def __str__(self):
return f'{self.args[0]}.{self.args[1]}'


def sympy_intdiv_fix(expr):
""" Fix for SymPy printing out reciprocal values when they should be
integral in "ceiling/floor" sympy functions.
Expand Down Expand Up @@ -927,10 +941,9 @@ def _process_is(elem: Union[Is, IsNot]):
return expr


class SympyBooleanConverter(ast.NodeTransformer):
class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces boolean operations with the appropriate SymPy functions to avoid
non-symbolic evaluation.
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
"""
_ast_to_sympy_comparators = {
ast.Eq: 'Eq',
Expand All @@ -946,12 +959,37 @@ class SympyBooleanConverter(ast.NodeTransformer):
ast.NotIn: 'NotIn',
}

_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Not):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return node
elif isinstance(node.op, ast.Invert):
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in self._ast_to_sympy_functions:
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BoolOp(self, node):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
Expand All @@ -971,8 +1009,7 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]
func_node = ast.copy_location(
ast.Name(id=SympyBooleanConverter._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)

Expand All @@ -985,41 +1022,28 @@ def visit_NameConstant(self, node):
return self.visit_Constant(node)

def visit_IfExp(self, node):
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[])
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load),
args=[self.visit(node.test),
self.visit(node.body),
self.visit(node.orelse)],
keywords=[])
return ast.copy_location(new_node, node)

class BitwiseOpConverter(ast.NodeTransformer):
"""
Replaces C/C++ bitwise operations with functions to avoid sympification to boolean operations.
"""
_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Invert):
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in BitwiseOpConverter._ast_to_sympy_functions:
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],

def visit_Subscript(self, node):
if isinstance(node.value, ast.Attribute):
attr = ast.Subscript(value=ast.Name(id=node.value.attr, ctx=ast.Load()), slice=node.slice, ctx=ast.Load())
new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load),
args=[self.visit(node.value.value), self.visit(attr)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_Attribute(self, node):
new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load),
args=[self.visit(node.value), ast.Name(id=node.attr, ctx=ast.Load)],
keywords=[])
return ast.copy_location(new_node, node)


@lru_cache(maxsize=16384)
def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
Expand Down Expand Up @@ -1071,21 +1095,17 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
'int_ceil': int_ceil,
'IfExpr': IfExpr,
'Mod': sympy.Mod,
'Attr': Attr,
}
# _clash1 enables all one-letter variables like N as symbols
# _clash also allows pi, beta, zeta and other common greek letters
locals.update(_sympy_clash)

if isinstance(expr, str):
# Sympy processes "not/and/or" as direct evaluation. Replace with
# And/Or(x, y), Not(x)
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr):
expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

# NOTE: If the expression contains bitwise operations, replace them with user-functions.
# NOTE: Sympy does not support bitwise operations and converts them to boolean operations.
if re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr):
expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0]))
# Sympy processes "not/and/or" as direct evaluation. Replace with And/Or(x, y), Not(x)
# Also replaces bitwise operations with user-functions since SymPy does not support bitwise operations.
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b|[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]|[\.]', expr):
expr = unparse(PythonOpToSympyConverter().visit(ast.parse(expr).body[0]))

# TODO: support SymExpr over-approximated expressions
try:
Expand Down Expand Up @@ -1126,6 +1146,8 @@ def _print_Function(self, expr):
return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))'
if str(expr.func) == 'OR':
return f'(({self._print(expr.args[0])}) or ({self._print(expr.args[1])}))'
if str(expr.func) == 'Attr':
return f'{self._print(expr.args[0])}.{self._print(expr.args[1])}'
return super()._print_Function(expr)

def _print_Mod(self, expr):
Expand Down
47 changes: 47 additions & 0 deletions tests/sdfg/data/structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,52 @@ def test_direct_read_structure():
assert np.allclose(B, ref)


def test_direct_read_structure_loops():

M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
name='CSRMatrix')

sdfg = dace.SDFG('csr_to_dense_direct_loops')

sdfg.add_datadesc('A', csr_obj)
sdfg.add_array('B', [M, N], dace.float32)

state = sdfg.add_state()

indices = state.add_access('A.indices')
data = state.add_access('A.data')
B = state.add_access('B')

t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val')
state.add_edge(indices, None, t, 'j', dace.Memlet(data='A.indices', subset='idx'))
state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx'))
state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1))

idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1')
i_before, i_guard, i_after = sdfg.add_loop(None, idx_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=idx_after)

func = sdfg.compile()

rng = np.random.default_rng(42)
A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng)
B = np.zeros((20, 20), dtype=np.float32)

inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0],
indices=A.indices.__array_interface__['data'][0],
data=A.data.__array_interface__['data'][0],
rows=A.shape[0],
cols=A.shape[1],
M=A.shape[0],
N=A.shape[1],
nnz=A.nnz)

func(A=inpA, B=B, M=20, N=20, nnz=A.nnz)
ref = A.toarray()

assert np.allclose(B, ref)


def test_direct_read_nested_structure():
M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz'))
csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]),
Expand Down Expand Up @@ -505,3 +551,4 @@ def test_direct_read_nested_structure():
test_write_nested_structure()
test_direct_read_structure()
test_direct_read_nested_structure()
test_direct_read_structure_loops()

0 comments on commit 3e73304

Please sign in to comment.