Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental AST rewriter and JIT decorator #326

Open
wants to merge 27 commits into
base: experimental/abc-mangling
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9e3724c
Added numba overloaded functions to layout
hugohadfield Jun 5, 2020
03826c5
Added a GA specific ast transformer
hugohadfield Jun 5, 2020
ef61257
Added a jit_func decorator to ast transform and numba jit
hugohadfield Jun 5, 2020
c13bc94
Corrected jit_func, added a test
hugohadfield Jun 5, 2020
51f8a42
remove duplication in ast_transformer
hugohadfield Jun 5, 2020
8022092
convert to abstract numeric types in the numba jit overload
hugohadfield Jun 5, 2020
f14521b
Improved handling globals, added a TODO
hugohadfield Jun 5, 2020
5fdbb86
Added ast_pretty warning if not installed
hugohadfield Jun 5, 2020
d6c6e06
removed unnescary print
hugohadfield Jun 5, 2020
8094a61
Added reversion to AST rewriter and JIT
hugohadfield Jun 5, 2020
1767342
Added grade selection via the call syntax
hugohadfield Jun 5, 2020
81601ce
Set up pytest benchmark
hugohadfield Jun 6, 2020
d905393
Make node visitation recursive for Call
hugohadfield Jun 6, 2020
750ec85
Add ImportError type for astpretty
hugohadfield Jun 6, 2020
e0263f8
Improve warning whitespace
hugohadfield Jun 6, 2020
e878dbe
Make the Call rewrite exception an AttributeError
hugohadfield Jun 6, 2020
482b091
Moved the decorator removal to the AST level
hugohadfield Jun 6, 2020
ff9648d
Add scalar and multivector constants to decorator arguments
hugohadfield Jun 7, 2020
5d27874
Fix nested function call transformer
hugohadfield Jun 7, 2020
6c2cea6
Improve speed of linear_operator_to_matrix
hugohadfield Jun 7, 2020
307874f
Add testing for new jit decorator features
hugohadfield Jun 7, 2020
c5be87a
Added a nested jitted function test
hugohadfield Jun 8, 2020
8f02960
Fixed flake8 complaints
hugohadfield Jun 8, 2020
8e96d81
Apply suggestions from Eric code review
hugohadfield Jun 9, 2020
2315f3f
Fix up review comments
hugohadfield Jun 9, 2020
87a41b9
Moved jit_impls into jit_func
hugohadfield Jun 9, 2020
ccf5551
Moved jit_func into an experimental directory
hugohadfield Jun 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def linear_operator_as_matrix(func, input_blades, output_blades):
ndimout = len(output_blades)
mat = np.zeros((ndimout, ndimin))
for i, b in enumerate(input_blades):
mat[:, i] = np.array([func(b)[j] for j in output_blades])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice find

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably makes sense to spin out a quick PR with just this fix against master.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I was thinking that

b_result = func(b)
mat[:, i] = np.array([b_result[j] for j in output_blades])
return mat


Expand Down
60 changes: 60 additions & 0 deletions clifford/_ast_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

import ast


class DecoratorRemover(ast.NodeTransformer):
""" Strip decorators from top-level FunctionDefs"""
def visit_FunctionDef(self, node):
node.decorator_list = []
return node


class GATransformer(ast.NodeTransformer):
"""
This is an AST transformer that converts operations into
JITable counterparts that work on MultiVector value arrays.
We crawl the AST and convert BinOps and UnaryOps into numba
overloaded functions.
"""
def visit_BinOp(self, node):
ops = {
ast.Mult: 'ga_mul',
ast.BitXor: 'ga_xor',
ast.BitOr: 'ga_or',
ast.Add: 'ga_add',
ast.Sub: 'ga_sub',
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.left), self.visit(node.right)],
keywords=[]
)

def visit_UnaryOp(self, node):
ops = {
ast.Invert: 'ga_rev'
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.operand)],
keywords=[]
)

def visit_Call(self, node):
if len(node.args) == 1:
node = self.generic_visit(node)
node.args = [node.func] + node.args
node.func = ast.Name(id='ga_call', ctx=ast.Load())
return node
else:
return self.generic_visit(node)
1 change: 1 addition & 0 deletions clifford/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import sparse


# TODO: move some of these functions to this file if they're not useful anywhere
# else
import clifford as cf
Expand Down
Empty file.
Loading