Skip to content

Commit

Permalink
Merge pull request #1527 from brian-team/numpy2_compat
Browse files Browse the repository at this point in the history
Numpy 2 compatibility
  • Loading branch information
mstimberg authored Apr 22, 2024
2 parents 7d31328 + 66e636b commit 70ec299
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 17 deletions.
1 change: 0 additions & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ jobs:
- name: Build source tarball
run: |
python -m pip install --upgrade pip build
python -m pip install "cython>=0.29" oldest-supported-numpy "setuptools>=61" "setuptools_scm[toml]>=6.2"
python -m build --sdist --config-setting=--formats=gztar --config-setting=--with-cython --config-setting=--fail-on-error
if: ${{ matrix.arch == 'auto64' }}
- name: Set up QEMU
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Try out Brian on the [mybinder](https://mybinder.org/) service:
## Dependencies
The following packages need to be installed to use Brian 2 (cf. [`pyproject.toml`](pyproject.toml)):

* Python >= 3.9
* NumPy >=1.21
* Python >= 3.10
* NumPy >=1.23
* SymPy >= 1.2
* Cython >= 0.29.21
* PyParsing
Expand Down
8 changes: 5 additions & 3 deletions brian2/codegen/runtime/numpy_rt/numpy_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ def __iter__(self):
return iter(self.indices)

# Allow conversion to a proper array with np.array(...)
def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
if copy is False:
raise ValueError("LazyArange does not support copy=False")
if self.indices is None:
return np.arange(self.start, self.stop)
return np.arange(self.start, self.stop, dtype=dtype)
else:
return self.indices + self.start
return (self.indices + self.start).astype(dtype)

# Allow basic arithmetics (used when shifting stuff for subgroups)
def __add__(self, other):
Expand Down
14 changes: 11 additions & 3 deletions brian2/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,11 +1337,13 @@ def set_with_index_array(self, item, value, check_units):
variable.get_value()[indices] = value

# Allow some basic calculations directly on the ArrayView object
def __array__(self, dtype=None):
def __array__(self, dtype=None, copy=None):
try:
# This will fail for subexpressions that refer to external
# parameters
self[:]
values = self[:]
# Never hand over copy = None
return np.array(values, dtype=dtype, copy=copy is not False, subok=True)
except ValueError:
var = self.variable.name
raise ValueError(
Expand All @@ -1350,7 +1352,6 @@ def __array__(self, dtype=None):
f"variables, use 'group.{var}[:]' instead of "
f"'group.{var}'"
)
return np.asanyarray(self[:], dtype=dtype)

def __array__ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == "__call__":
Expand Down Expand Up @@ -1434,6 +1435,13 @@ def __iadd__(self, other):
self[:] = rhs
return self

# Support matrix multiplication with @
def __matmul__(self, other):
return self.get_item(slice(None), level=1) @ np.asanyarray(other)

def __rmatmul__(self, other):
return np.asanyarray(other) @ self.get_item(slice(None), level=1)

def __isub__(self, other):
if isinstance(other, str):
raise TypeError(
Expand Down
6 changes: 5 additions & 1 deletion brian2/parsing/rendering.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import numbers

import numpy as np
import sympy

from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS
Expand Down Expand Up @@ -87,6 +88,9 @@ def render_Name(self, node):
return node.id

def render_Constant(self, node):
if isinstance(node.value, np.number):
# repr prints the dtype in numpy 2.0
return repr(node.value.item())
return repr(node.value)

def render_Call(self, node):
Expand Down Expand Up @@ -344,7 +348,7 @@ def render_Constant(self, node):
elif node.value is False:
return "false"
else:
return repr(node.value)
return super().render_Constant(node)

def render_Name(self, node):
if node.id == "inf":
Expand Down
21 changes: 21 additions & 0 deletions brian2/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from brian2 import _cache_dirs_and_extensions, clear_cache, prefs
from brian2.codegen.codeobject import CodeObject
from brian2.codegen.cpp_prefs import compiler_supports_c99, get_compiler_and_args
from brian2.codegen.generators.cython_generator import CythonNodeRenderer
from brian2.codegen.optimisation import optimise_statements
from brian2.codegen.runtime.cython_rt import CythonCodeObject
from brian2.codegen.statements import Statement
Expand All @@ -22,6 +23,7 @@
from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS, Function
from brian2.core.variables import ArrayVariable, Constant, Subexpression, Variable
from brian2.devices.device import auto_target, device
from brian2.parsing.rendering import CPPNodeRenderer, NodeRenderer, NumpyNodeRenderer
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.units import ms, second
from brian2.units.fundamentalunits import Unit
Expand Down Expand Up @@ -618,6 +620,25 @@ def test_msvc_flags():
assert len(previously_stored_flags[hostname])


@pytest.mark.codegen_independent
@pytest.mark.parametrize(
"renderer",
[
NodeRenderer(),
NumpyNodeRenderer(),
CythonNodeRenderer(),
CPPNodeRenderer(),
],
)
def test_number_rendering(renderer):
import ast

for number in [0.5, np.float32(0.5), np.float64(0.5)]:
# In numpy 2.0, repr(np.float64(0.5)) is 'np.float64(0.5)'
node = ast.Constant(value=number)
assert renderer.render_node(node) == "0.5"


if __name__ == "__main__":
test_auto_target()
test_analyse_identifiers()
Expand Down
4 changes: 2 additions & 2 deletions examples/advanced/stochastic_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def exact_solution(t, dt):
methods = ['milstein', 'heun']
dts = [1*ms, 0.5*ms, 0.2*ms, 0.1*ms, 0.05*ms, 0.025*ms, 0.01*ms, 0.005*ms]

rows = floor(sqrt(len(dts)))
cols = ceil(1.0 * len(dts) / rows)
rows = int(sqrt(len(dts)))
cols = int(ceil(1.0 * len(dts) / rows))
errors = dict([(method, zeros(len(dts))) for method in methods])
for dt_idx, dt in enumerate(dts):
print('dt: %s' % dt)
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ authors = [
{name = 'Dan Goodman'},
{name ='Romain Brette'}
]
requires-python = '>=3.9'
requires-python = '>=3.10'
dependencies = [
'numpy>=1.21',
'numpy>=1.23.5',
'cython>=0.29.21',
'sympy>=1.2',
'pyparsing',
Expand Down Expand Up @@ -60,10 +60,11 @@ fallback_version = 'unknown'
[build-system]
requires = [
"setuptools>=61",
"numpy>=1.10",
# By building against numpy 2.0, we make sure that the wheel is compatible with
# both numpy 2.0 and numpy>=1.23
"numpy>=2.0.0rc1",
"wheel",
"Cython",
"oldest-supported-numpy",
"setuptools_scm[toml]>=6.2"
]
build-backend = "setuptools.build_meta"
Expand Down
1 change: 0 additions & 1 deletion versions.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ This file contains a list of other files where versions of packages are specifie
In the future it would be advantageous to implement an automated way of keep versions synchronised across files e.g. https://github.com/pre-commit/pre-commit/issues/945#issuecomment-527603460 or preferably parsing `.pre-commit-config.yaml` and using it to `pip install` requirements (see discussion here: https://github.com/brian-team/brian2/pull/1449#issuecomment-1372476018). Until then, the files are listed below for manual checking and updating.

* [`README.md`](https://github.com/brian-team/brian2/blob/master/README.md)
* [`setup.py`](https://github.com/brian-team/brian2/blob/master/setup.py)
* [`rtd-requirements.txt`](https://github.com/brian-team/brian2/blob/master/rtd-requirements.txt)
* [`pyproject.toml`](https://github.com/brian-team/brian2/blob/master/pyproject.toml)
* [`.pre-commit-config.yaml`](https://github.com/brian-team/brian2/blob/master/.pre-commit-config.yaml)
Expand Down

0 comments on commit 70ec299

Please sign in to comment.