Skip to content

Commit

Permalink
Do not eval exprel, expm1, log1p
Browse files Browse the repository at this point in the history
Fixes #1350
  • Loading branch information
mstimberg committed Jun 6, 2023
1 parent 20c8ff4 commit c670be6
Showing 1 changed file with 96 additions and 15 deletions.
111 changes: 96 additions & 15 deletions brian2/core/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,11 +707,102 @@ def __init__(self, name, sympy_obj, value):
################################################################################


def _log1p(x):
if x.is_zero:
return S.Zero
else:
return None


class log1p(sympy_Function):
"""
Represents ``log(1 + x)``.
The benefit of using ``log1p(x)`` over ``log(x + 1)``
is that the latter is imprecise when x is close to zero.
"""

nargs = 1

def fdiff(self, argindex=1):
"""
Returns the first derivative of this function.
"""
if argindex == 1:
return 1 / (1 + self.args[0])
else:
raise sympy.ArgumentIndexError(self, argindex)

def _eval_expand_func(self, **hints):
return _log1p(*self.args)

def _eval_rewrite_as_exp(self, arg, **kwargs):
return _log1p(arg)

_eval_rewrite_as_tractable = _eval_rewrite_as_exp

@classmethod
def eval(cls, arg):
return _log1p(arg)

def _eval_is_real(self):
return self.args[0].is_real

def _eval_is_finite(self):
return self.args[0].is_finite


def _expm1(x):
if x.is_zero:
return S.Zero
else:
return None


class expm1(sympy_Function):
"""
Represents ``exp(x) - 1``.
The benefit of using ``expm1(x)`` over ``exp(x) - 1``
is that the latter is prone to cancellation under finite precision
arithmetic when x is close to zero.
"""

nargs = 1

def fdiff(self, argindex=1):
"""
Returns the first derivative of this function.
"""
if argindex == 1:
return sympy.exp(*self.args)
else:
raise sympy.ArgumentIndexError(self, argindex)

def _eval_expand_func(self, **hints):
return _expm1(*self.args)

def _eval_rewrite_as_exp(self, arg, **kwargs):
return _expm1(arg)

_eval_rewrite_as_tractable = _eval_rewrite_as_exp

@classmethod
def eval(cls, arg):
return _expm1(arg)

def _eval_is_real(self):
return self.args[0].is_real

def _eval_is_finite(self):
return self.args[0].is_finite


def _exprel(x):
if x.is_zero:
return S.One
else:
return (sympy.exp(x) - S.One) / x
return None


class exprel(sympy_Function):
Expand Down Expand Up @@ -741,23 +832,13 @@ def _eval_expand_func(self, **hints):
return _exprel(*self.args)

def _eval_rewrite_as_exp(self, arg, **kwargs):
if arg.is_zero:
return S.One
else:
return (sympy.exp(arg) - S.One) / arg
return _exprel(arg)

_eval_rewrite_as_tractable = _eval_rewrite_as_exp

@classmethod
def eval(cls, arg):
if arg is None:
return None
if arg.is_zero:
return S.One

exp_arg = sympy.exp.eval(arg)
if exp_arg is not None:
return (exp_arg - S.One) / arg
return _exprel(arg)

def _eval_is_real(self):
return self.args[0].is_real
Expand Down Expand Up @@ -829,9 +910,9 @@ def timestep(t, dt):
unitsafe.log, sympy_func=sympy.functions.elementary.exponential.log
),
"log10": Function(unitsafe.log10, sympy_func=sympy_cfunctions.log10),
"expm1": Function(unitsafe.expm1, sympy_func=sympy_cfunctions.expm1),
"expm1": Function(unitsafe.expm1, sympy_func=expm1),
"exprel": Function(unitsafe.exprel, sympy_func=exprel),
"log1p": Function(unitsafe.log1p, sympy_func=sympy_cfunctions.log1p),
"log1p": Function(unitsafe.log1p, sympy_func=log1p),
"sqrt": Function(
np.sqrt,
sympy_func=sympy.functions.elementary.miscellaneous.sqrt,
Expand Down

0 comments on commit c670be6

Please sign in to comment.