diff --git a/brian2/core/functions.py b/brian2/core/functions.py index 30c897a34..ec906f019 100644 --- a/brian2/core/functions.py +++ b/brian2/core/functions.py @@ -707,11 +707,102 @@ def __init__(self, name, sympy_obj, value): ################################################################################ +def _log1p(x): + if x.is_zero: + return S.Zero + else: + return None + + +class log1p(sympy_Function): + """ + Represents ``log(1 + x)``. + + The benefit of using ``log1p(x)`` over ``log(x + 1)`` + is that the latter is imprecise when x is close to zero. + """ + + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return 1 / (1 + self.args[0]) + else: + raise sympy.ArgumentIndexError(self, argindex) + + def _eval_expand_func(self, **hints): + return _log1p(*self.args) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return _log1p(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_exp + + @classmethod + def eval(cls, arg): + return _log1p(arg) + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_finite(self): + return self.args[0].is_finite + + +def _expm1(x): + if x.is_zero: + return S.Zero + else: + return None + + +class expm1(sympy_Function): + """ + Represents ``exp(x) - 1``. + + The benefit of using ``expm1(x)`` over ``exp(x) - 1`` + is that the latter is prone to cancellation under finite precision + arithmetic when x is close to zero. + """ + + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return sympy.exp(*self.args) + else: + raise sympy.ArgumentIndexError(self, argindex) + + def _eval_expand_func(self, **hints): + return _expm1(*self.args) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return _expm1(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_exp + + @classmethod + def eval(cls, arg): + return _expm1(arg) + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_finite(self): + return self.args[0].is_finite + + def _exprel(x): if x.is_zero: return S.One else: - return (sympy.exp(x) - S.One) / x + return None class exprel(sympy_Function): @@ -741,23 +832,13 @@ def _eval_expand_func(self, **hints): return _exprel(*self.args) def _eval_rewrite_as_exp(self, arg, **kwargs): - if arg.is_zero: - return S.One - else: - return (sympy.exp(arg) - S.One) / arg + return _exprel(arg) _eval_rewrite_as_tractable = _eval_rewrite_as_exp @classmethod def eval(cls, arg): - if arg is None: - return None - if arg.is_zero: - return S.One - - exp_arg = sympy.exp.eval(arg) - if exp_arg is not None: - return (exp_arg - S.One) / arg + return _exprel(arg) def _eval_is_real(self): return self.args[0].is_real @@ -829,9 +910,9 @@ def timestep(t, dt): unitsafe.log, sympy_func=sympy.functions.elementary.exponential.log ), "log10": Function(unitsafe.log10, sympy_func=sympy_cfunctions.log10), - "expm1": Function(unitsafe.expm1, sympy_func=sympy_cfunctions.expm1), + "expm1": Function(unitsafe.expm1, sympy_func=expm1), "exprel": Function(unitsafe.exprel, sympy_func=exprel), - "log1p": Function(unitsafe.log1p, sympy_func=sympy_cfunctions.log1p), + "log1p": Function(unitsafe.log1p, sympy_func=log1p), "sqrt": Function( np.sqrt, sympy_func=sympy.functions.elementary.miscellaneous.sqrt,