Skip to content

Commit

Permalink
Use the weights in the bootstrap to draw a weighted sample
Browse files: browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Dec 12, 2023
1 parent 833cb82 commit 211e5b5
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -111,32 +111,21 @@ def fit_predict(self, X, Y, weights=None):
X = X.to_numpy()
if not isinstance(Y, np.ndarray):
Y = Y.to_numpy()
# assuming pandas.Series
if weights is not None and not isinstance(weights, np.ndarray):
weights = weights.values

tm = TransitionMatrixSolver(strict=self._strict, verbose=False)
predicted_transitions.append(tm.fit_predict(X, Y, weights=weights))
maes.append(tm.MAE)

from sklearn.utils import resample # to be replaced

for b in tqdm(range(0, self._B - 1), desc="Bootstrapping"):
X_resampled = []
Y_resampled = []
weights_resampled = []
for i in resample(range(0, len(X)), replace=True, random_state=b):
X_resampled.append(X[i])
Y_resampled.append(Y[i])
if weights is not None:
weights_resampled.append(weights[i])
if weights is None:
weights_resampled = None
else:
weights_resampled = np.array(weights_resampled)
predicted_transitions.append(
tm.fit_predict(np.array(X_resampled), np.array(Y_resampled), weights=weights_resampled)
rng = np.random.default_rng(seed=b)
X_resampled = rng.choice(
X, len(X), replace=True, axis=0, p=(weights / weights.sum() if weights is not None else None)
)
Y_resampled = []
for x in X_resampled:
index = np.where((X == x).all(axis=1))[0][0]
Y_resampled.append(Y[index])
predicted_transitions.append(tm.fit_predict(X_resampled, np.array(Y_resampled), weights=None))
maes.append(tm.MAE)

self._mae = np.mean(maes)
Expand Down

0 comments on commit 211e5b5

Please sign in to comment.