Skip to content

Commit

Permalink
Use C-order arrays where possible in numba extension types (#364)
Browse files Browse the repository at this point in the history
Previously only A-order arrays were used.
F-order arrays are not relevant, because coefficient arrays are 1d, and therefore either contiguous and C-order, or not-contiguous and therefore not F-order.

Closes #363, allowing multivector constants to be used in cached functions; only C-order arrays are considered candidates for caching, so previously this did not work.
  • Loading branch information
hugohadfield authored Mar 4, 2021
1 parent 599abf8 commit 31651d0
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 16 deletions.
10 changes: 7 additions & 3 deletions clifford/numba/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
class LayoutType(types.Dummy):
def __init__(self, layout):
self.obj = layout
# cache of multivector types for this layout
self._cache = {}
# Caches of multivector types for this layout, in numba C and A order.
# Having two caches is faster than a cache keyed by a tuple of `(order, dt)`,
# and every millisecond counts in `MultiVector._numba_type_`.
self._c_cache = {}
self._a_cache = {}
layout_name = layout_short_name(layout)
if layout_name is not None:
name = "LayoutType({})".format(layout_name)
Expand All @@ -36,7 +39,8 @@ def __getstate__(self):
# the cache causes issues with numba's pickle modifications, as it
# contains a self-reference.
d = self.__dict__.copy()
d['_cache'] = {}
d['_c_cache'] = {}
d['_a_cache'] = {}
return d


Expand Down
32 changes: 19 additions & 13 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,14 @@
class MultiVectorType(types.Type):
def __init__(self, layout: LayoutType, dtype: types.DType):
self.layout_type = layout
self._scalar_type = dtype
self.value_type = dtype
super().__init__(name='MultiVector({!r}, {!r})'.format(
self.layout_type, self._scalar_type
self.layout_type, self.value_type
))

@property
def key(self):
return self.layout_type, self._scalar_type

@property
def value_type(self):
return self._scalar_type[:]
return self.layout_type, self.value_type


# The docs say we should use register a function to determine the numba type
Expand All @@ -53,19 +49,29 @@ def value_type(self):

@property
def _numba_type_(self):
# If the array is not 1D we can't do anything with it
if self.value.ndim != 1:
return None

layout_type = self.layout._numba_type_

cache = layout_type._cache
dt = self.value.dtype
if self.value.flags.c_contiguous:
relevant_cache = layout_type._c_cache
else:
relevant_cache = layout_type._a_cache

# now use the dtype to key that cache.
try:
return cache[dt]
return relevant_cache[dt]
except KeyError:
# Computing and hashing `dtype_type` is slow, so we do not use it as a
# Computing and hashing `value_type` is slow, so we do not use it as a
# hash key. The raw numpy dtype is much faster to use as a key.
dtype_type = _numpy_support.from_dtype(dt)
ret = cache[dt] = MultiVectorType(layout_type, dtype_type)
if self.value.flags.c_contiguous:
value_type = _numpy_support.from_dtype(dt)[::1]
else:
value_type = _numpy_support.from_dtype(dt)[:]
ret = relevant_cache[dt] = MultiVectorType(layout_type, value_type)
return ret

MultiVector._numba_type_ = _numba_type_
Expand All @@ -85,7 +91,7 @@ def __init__(self, dmm, fe_type):
def type_MultiVector(context):
def typer(layout, value):
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
return MultiVectorType(layout, value.dtype)
return MultiVectorType(layout, value)
return typer


Expand Down
19 changes: 19 additions & 0 deletions clifford/test/test_function_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
from clifford._numba_utils import generated_jit
import pytest


@generated_jit(cache=True)
def foo(x):
from clifford.g3 import e3

def impl(x):
return (x * e3).value
return impl


# Make the test fail on a failed cache warning
@pytest.mark.filterwarnings("error")
def test_function_cache():
from clifford.g3 import e3
np.testing.assert_array_equal((1.0*e3).value, foo(1.0))
21 changes: 21 additions & 0 deletions clifford/test/test_numba_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,24 @@ def test_pickling():
e1._numba_type_ # int
(e1 * 1.0)._numba_type_ # float
assert pickle.loads(pickle.dumps(lt)) is lt


def test_A_order():
import numpy as np

@numba.njit
def mul_mv(mv):
return mv*e3

# A-order
mva = layout.MultiVector(np.ones(layout.gaDims))
mva.value = mva.value[::-1]
assert not mva.value.flags.c_contiguous
res_mva = mul_mv(mva)

# C-order
mvc = layout.MultiVector(np.ones(layout.gaDims))
assert mvc.value.flags.c_contiguous
res_mvc = mul_mv(mvc)

assert res_mva == res_mvc

0 comments on commit 31651d0

Please sign in to comment.