diff --git a/src/elexsolver/TransitionMatrixSolver.py b/src/elexsolver/TransitionMatrixSolver.py index 885c7e13..0ad9bb8b 100644 --- a/src/elexsolver/TransitionMatrixSolver.py +++ b/src/elexsolver/TransitionMatrixSolver.py @@ -102,9 +102,12 @@ def __init__(self, B=1000, strict=True): self._strict = strict self._B = B + # class members that are instantiated during model-fit + self._predicted_transitions = None + def fit_predict(self, X, Y, weights=None): maes = [] - predicted_transitions = [] + self._predicted_transitions = [] # assuming pandas.DataFrame if not isinstance(X, np.ndarray): @@ -113,7 +116,7 @@ def fit_predict(self, X, Y, weights=None): Y = Y.to_numpy() tm = TransitionMatrixSolver(strict=self._strict, verbose=False) - predicted_transitions.append(tm.fit_predict(X, Y, weights=weights)) + self._predicted_transitions.append(tm.fit_predict(X, Y, weights=weights)) maes.append(tm.MAE) for b in tqdm(range(0, self._B - 1), desc="Bootstrapping"): @@ -123,9 +126,22 @@ def fit_predict(self, X, Y, weights=None): ) indices = [np.where((X == x).all(axis=1))[0][0] for x in X_resampled] Y_resampled = Y[indices] - predicted_transitions.append(tm.fit_predict(X_resampled, Y_resampled, weights=None)) + self._predicted_transitions.append(tm.fit_predict(X_resampled, Y_resampled, weights=None)) maes.append(tm.MAE) self._mae = np.mean(maes) LOG.info("MAE = %s", np.around(self._mae, 4)) - return np.mean(predicted_transitions, axis=0) + return np.mean(self._predicted_transitions, axis=0) + + def get_confidence_interval(self, alpha): + if alpha > 1: + alpha = alpha / 100 + if alpha < 0 or alpha >= 1: + raise ValueError(f"Invalid confidence interval {alpha}.") + + p_lower = ((1.0 - alpha) / 2.0) * 100 + p_upper = ((1.0 + alpha) / 2.0) * 100 + return ( + np.percentile(self._predicted_transitions, p_lower, axis=0), + np.percentile(self._predicted_transitions, p_upper, axis=0), + )