diff --git a/quinn/__init__.py b/quinn/__init__.py
index feca32c..2d442c7 100644
--- a/quinn/__init__.py
+++ b/quinn/__init__.py
@@ -52,6 +52,10 @@
     week_end_date,
     week_start_date,
 )
+from quinn.math import (
+    rand_laplace,
+    rand_range,
+)
 from quinn.schema_helpers import print_schema_as_code
 from quinn.split_columns import split_col
 from quinn.transformations import (
diff --git a/quinn/math.py b/quinn/math.py
index 3ba9e54..61f2b63 100644
--- a/quinn/math.py
+++ b/quinn/math.py
@@ -47,6 +47,29 @@ def rand_laplace(
     )
 
 
+def rand_range(
+    minimum: Union[int, Column],
+    maximum: Union[int, Column],
+    seed: Optional[int] = None,
+) -> Column:
+    """Generate random numbers uniformly distributed in [`minimum`, `maximum`).
+
+    :param minimum: inclusive lower bound of the random numbers (literal or Column)
+    :param maximum: exclusive upper bound of the random numbers (literal or Column)
+    :param seed: random seed value (optional, default None)
+    :returns: column with random numbers
+    """
+    if not isinstance(minimum, Column):
+        minimum = F.lit(minimum)
+
+    if not isinstance(maximum, Column):
+        maximum = F.lit(maximum)
+
+    u = F.rand(seed)
+
+    return minimum + (maximum - minimum) * u
+
+
 def div_or_else(
     cola: Column,
     colb: Column,
diff --git a/tests/test_math.py b/tests/test_math.py
new file mode 100644
index 0000000..87c814a
--- /dev/null
+++ b/tests/test_math.py
@@ -0,0 +1,42 @@
+import pyspark.sql.functions as F
+
+import quinn
+import math
+from .spark import spark
+
+
+def test_rand_laplace():
+    """Check that rand_laplace(0.0, 1.0) samples have mean near 0 and stddev near sqrt(2)."""
+    stats = (
+        spark.range(100000)
+        .select(quinn.rand_laplace(0.0, 1.0, 42))
+        .agg(
+            F.mean("laplace_random").alias("mean"),
+            F.stddev("laplace_random").alias("std_dev"),
+        )
+        .first()
+    )
+
+    laplace_mean = stats["mean"]
+    laplace_stddev = stats["std_dev"]
+
+    # Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0)
+    assert abs(laplace_mean) <= 0.1
+    assert abs(laplace_stddev - math.sqrt(2.0)) < 0.5
+
+
+def test_rand_range():
+    """Check that every rand_range sample lies between the requested bounds."""
+    lower_bound = 5
+    upper_bound = 10
+    stats = (
+        spark.range(1000)
+        .select(quinn.rand_range(lower_bound, upper_bound).alias("rand_uniform"))
+        .agg(F.min("rand_uniform").alias("min"), F.max("rand_uniform").alias("max"))
+        .first()
+    )
+
+    uniform_min = stats["min"]
+    uniform_max = stats["max"]
+
+    assert lower_bound <= uniform_min <= uniform_max <= upper_bound