Skip to content

Commit

Permalink
add randn with mean and variance (#270)
Browse files Browse the repository at this point in the history
* add randn with mean and variance

* use correct random function
  • Loading branch information
zeotuan authored Oct 12, 2024
1 parent 8955005 commit d55ac0f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
5 changes: 1 addition & 4 deletions quinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@
week_end_date,
week_start_date,
)
from quinn.math import (
rand_laplace,
rand_range,
)
from quinn.math import rand_laplace, rand_range, randn
from quinn.schema_helpers import print_schema_as_code
from quinn.split_columns import split_col
from quinn.transformations import (
Expand Down
22 changes: 22 additions & 0 deletions quinn/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ def rand_range(
return minimum + (maximum - minimum) * u


def randn(
    mean: Union[float, Column],
    variance: Union[float, Column],
    seed: Optional[int] = None,
) -> Column:
    """Generate a column of i.i.d. samples from a normal distribution.

    Each value is a standard normal draw from ``F.randn`` that is scaled by the
    square root of ``variance`` and then shifted by ``mean``.

    :param mean: mean of the target normal distribution (literal or column)
    :param variance: variance of the target normal distribution (literal or
        column); its square root is taken, so it should be non-negative
    :param seed: random seed forwarded to ``F.randn`` (optional, default None)
    :returns: column with normally distributed random numbers
    """
    # Plain Python numbers are wrapped as literal columns; columns pass through.
    mean_col = mean if isinstance(mean, Column) else F.lit(mean)
    variance_col = variance if isinstance(variance, Column) else F.lit(variance)

    # Scale the unit-variance samples by the standard deviation, then shift.
    std_dev = F.sqrt(variance_col)
    return F.randn(seed) * std_dev + mean_col


def div_or_else(
cola: Column,
colb: Column,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,23 @@ def test_rand_range():
uniform_max = stats["max"]

assert lower_bound <= uniform_min <= uniform_max <= upper_bound


def test_randn():
    """Check that ``quinn.randn`` yields samples with the requested mean and variance.

    The sample statistics are compared against the targets in both directions,
    so values that are too low fail just like values that are too high.
    """
    mean = 1.0
    variance = 2.0
    tolerance = 0.1
    # With 100000 rows the standard error is roughly 0.0045 for the sample mean
    # (sqrt(variance / n)) and roughly 0.009 for the sample variance
    # (variance * sqrt(2 / (n - 1))), so the tolerance sits more than ten
    # standard errors out and the unseeded test is not flaky.
    num_samples = 100000
    stats = (
        spark.range(num_samples)
        .select(quinn.randn(mean, variance).alias("rand_normal"))
        .agg(
            F.mean("rand_normal").alias("agg_mean"),
            F.variance("rand_normal").alias("agg_variance"),
        )
        .first()
    )

    agg_mean = stats["agg_mean"]
    agg_variance = stats["agg_variance"]

    # Two-sided checks: a one-sided `x - target <= tol` would accept any
    # arbitrarily small value.
    assert abs(agg_mean - mean) <= tolerance
    assert abs(agg_variance - variance) <= tolerance

0 comments on commit d55ac0f

Please sign in to comment.