Skip to content

Commit

Permalink
Merge pull request #13 from s3alfisc/get_weights-port
Browse files Browse the repository at this point in the history
add get_weights port
  • Loading branch information
s3alfisc authored Oct 14, 2022
2 parents 3d91549 + dee43a1 commit 371a530
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/get_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np
from itertools import product

class WildDrawFunctionException(Exception):
pass

def rademacher(n: int) -> np.ndarray:
rng = np.random.default_rng()
return rng.choice([-1,1],size=n, replace=True)

def mammen(n: int) -> np.ndarray:
rng = np.random.default_rng()
return rng.choice(
a= np.array([-1, 1]) * (np.sqrt(5) + np.array([-1, 1])) / 2, #TODO: #10 Should this divide the whole expression by 2 or just the second part
size=n,
replace=True,
p = (np.sqrt(5) + np.array([1, -1])) / (2 * np.sqrt(5))
)

def norm(n):
rng = np.random.default_rng()
return rng.normal(size=n)

def webb(n):
rng = np.random.default_rng()
return rng.choice(
a = np.concatenate([-np.sqrt(np.array([3,2,1]) / 2), np.sqrt(np.array([1,2,3]) / 2)]),
replace=True,
size=n
)

wild_draw_fun_dict = {
'rademacher' : rademacher,
'mammen' : mammen,
'norm' : norm,
'webb' : webb
}


def get_weights(t : str, full_enumeration: bool, N_G_bootcluster: int, boot_iter: int) -> np.ndarray:
"""draw bootstrap weights
Args:
t (str): the type of the weights distribution. Either 'rademacher', 'mammen', 'norm' or 'webb'
full_enumeration (bool): should deterministic full enumeration be employed
N_G_bootcluster (int): the number of bootstrap clusters
boot_iter (int): the number of bootstrap iterations
Returns:
np.ndarray: a matrix of dimension N_G_bootcluster x (boot_iter + 1)
"""

#TODO: we can use the `case` feature in python, but that's only available in 3.10+ will do a 3.7 version for now
# Will take out this and make separate functions for readability

wild_draw_fun = wild_draw_fun_dict.get(t)

if wild_draw_fun is None:
raise WildDrawFunctionException("Function type specified is not supported or there is a typo.")

# do full enumeration for rademacher weights if bootstrap iterations
# B exceed number of possible permutations else random sampling

# full_enumeration only for rademacher weights (set earlier)
if full_enumeration:
t = 0 # what is this needed for?
# with N_G_bootcluster draws, get all combinations of [-1,1] WITH
# replacement, in matrix form
v0 = np.transpose(np.array(list(product([-1,1], repeat=N_G_bootcluster))))
else:
# else: just draw with replacement - by chance, some permutations
# might occur more than once
v0 = wild_draw_fun(n = N_G_bootcluster * boot_iter)
v0.reshape(N_G_bootcluster, boot_iter) # weights matrix

v = np.insert(v0, 0, 1,axis = 1)

return v


0 comments on commit 371a530

Please sign in to comment.