Skip to content

Commit

Permalink
Adding method for confidence interval to bootstrap solver
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Dec 15, 2023
1 parent a77409c commit 5559e22
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,12 @@ def __init__(self, B=1000, strict=True):
self._strict = strict
self._B = B

# class members that are instantiated during model-fit
self._predicted_transitions = None

def fit_predict(self, X, Y, weights=None):
maes = []
predicted_transitions = []
self._predicted_transitions = []

# assuming pandas.DataFrame
if not isinstance(X, np.ndarray):
Expand All @@ -113,7 +116,7 @@ def fit_predict(self, X, Y, weights=None):
Y = Y.to_numpy()

tm = TransitionMatrixSolver(strict=self._strict, verbose=False)
predicted_transitions.append(tm.fit_predict(X, Y, weights=weights))
self._predicted_transitions.append(tm.fit_predict(X, Y, weights=weights))
maes.append(tm.MAE)

for b in tqdm(range(0, self._B - 1), desc="Bootstrapping"):
Expand All @@ -123,9 +126,22 @@ def fit_predict(self, X, Y, weights=None):
)
indices = [np.where((X == x).all(axis=1))[0][0] for x in X_resampled]
Y_resampled = Y[indices]
predicted_transitions.append(tm.fit_predict(X_resampled, Y_resampled, weights=None))
self._predicted_transitions.append(tm.fit_predict(X_resampled, Y_resampled, weights=None))
maes.append(tm.MAE)

self._mae = np.mean(maes)
LOG.info("MAE = %s", np.around(self._mae, 4))
return np.mean(predicted_transitions, axis=0)
return np.mean(self._predicted_transitions, axis=0)

def get_confidence_interval(self, alpha):
if alpha > 1:
alpha = alpha / 100
if alpha < 0 or alpha >= 1:
raise ValueError(f"Invalid confidence interval {alpha}.")

p_lower = ((1.0 - alpha) / 2.0) * 100
p_upper = ((1.0 + alpha) / 2.0) * 100
return (
np.percentile(self._predicted_transitions, p_lower, axis=0),
np.percentile(self._predicted_transitions, p_upper, axis=0),
)

0 comments on commit 5559e22

Please sign in to comment.