diff --git a/fixedpointmath/fixed_point_math.py b/fixedpointmath/fixed_point_math.py index 58e65ed..c52fca4 100644 --- a/fixedpointmath/fixed_point_math.py +++ b/fixedpointmath/fixed_point_math.py @@ -66,32 +66,40 @@ def isclose(a: NUMERIC, b: NUMERIC, abs_tol: NUMERIC = FixedPoint("0.0")) -> boo return abs(a - b) <= abs_tol -def maximum(x: NUMERIC, y: NUMERIC) -> NUMERIC: - """Compare the two inputs and return the greater value. +def maximum(*args: NUMERIC) -> NUMERIC: + """Compare the inputs and return the greatest value. If the first argument equals the second, return the first. """ - if isinstance(x, FixedPoint) and x.is_nan(): - return x - if isinstance(y, FixedPoint) and y.is_nan(): - return y - if x >= y: - return x - return y - - -def minimum(x: NUMERIC, y: NUMERIC) -> NUMERIC: - """Compare the two inputs and return the lesser value. + # use builtin for generic types + if isinstance(args[0], (float, int)): + return type(args[0])(max(*args)) + # else, we're FixedPoint + current_max = FixedPoint("-inf") + for arg in args: + if isinstance(arg, FixedPoint) and arg.is_nan(): # any nan means minimum is nan + return arg + if arg >= current_max: # pylint: disable=consider-using-max-builtin + current_max = arg + return type(args[0])(current_max) + + +def minimum(*args: NUMERIC) -> NUMERIC: + """Compare the inputs and return the lowest value. If the first argument equals the second, return the first. """ - if isinstance(x, FixedPoint) and x.is_nan(): - return x - if isinstance(y, FixedPoint) and y.is_nan(): - return y - if x <= y: - return x - return y + # use builtin for generic types + if isinstance(args[0], (int, float)): + return type(args[0])(min(*args)) + # else, we're FixedPoint + current_min = FixedPoint("inf") + for arg in args: + if isinstance(arg, FixedPoint) and arg.is_nan(): # any nan means minimum is nan + return arg + if arg <= current_min: # pylint: disable=consider-using-min-builtin + current_min = arg + return type(args[0])(current_min) def sqrt(x: NUMERIC) -> NUMERIC: diff --git a/tests/test_fp_math.py b/tests/test_fp_math.py index 9503e9f..58efdce 100644 --- a/tests/test_fp_math.py +++ b/tests/test_fp_math.py @@ -54,12 +54,15 @@ def test_minimum(self): assert minimum(0, 1) == 0 assert minimum(-1, 1) == -1 assert minimum(-1, -3) == -3 + assert minimum(-1, 0, -3) == -3 assert minimum(-1.0, -3.0) == -3.0 assert minimum(1.0, 3.0) == 1.0 - assert minimum(FixedPoint(1.0), FixedPoint(3.0)) == FixedPoint(1.0) + assert minimum(1.0, 3.0, 0.5) == 0.5 + assert minimum(FixedPoint("1.0"), FixedPoint("3.0")) == FixedPoint("1.0") assert minimum(FixedPoint("3.0"), FixedPoint(scaled_value=int(3e18 - 1e-17))) == FixedPoint( scaled_value=int(3e18 - 1e-17) ) + assert minimum(FixedPoint("1.0"), FixedPoint("-100.0"), FixedPoint("3.0")) == FixedPoint("-100.0") def test_minimum_nonfinite(self): """Test minimum method.""" @@ -74,9 +77,11 @@ def test_maximum(self): assert maximum(0, 1) == 1 assert maximum(-1, 1) == 1 assert maximum(-1, -3) == -1 - assert maximum(-1.0, -3.0) == -1.0 + assert maximum(-1, 0, -3) == 0 + assert maximum(-1.0, 0.0, -3.0) == 0.0 assert maximum(1.0, 3.0) == 3.0 - assert maximum(FixedPoint(1.0), FixedPoint(3.0)) == FixedPoint(3.0) + assert maximum(FixedPoint("1.0"), FixedPoint("3.0")) == FixedPoint("3.0") + assert maximum(FixedPoint("1.0"), FixedPoint("100.0"), FixedPoint("3.0")) == FixedPoint("100") assert maximum(FixedPoint("3.0"), FixedPoint(scaled_value=int(3e18 - 1e-17))) == FixedPoint(3.0) def test_maximum_nonfinite(self):