Skip to content

Commit

Permalink
Trying out some error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
dmnapolitano committed Dec 12, 2023
1 parent 2f9c75b commit 833cb82
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/elexsolver/TransitionMatrixSolver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
import warnings

import cvxpy as cp
import numpy as np
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from elexsolver.logging import initialize_logging
from elexsolver.TransitionSolver import TransitionSolver, mean_absolute_error
Expand Down Expand Up @@ -33,7 +33,15 @@ def __solve(self, A, B, weights):
objective = cp.Minimize(loss_function)
constraint = TransitionMatrixSolver.__get_constraint(transition_matrix, self._strict)
problem = cp.Problem(objective, constraint)
problem.solve(solver=cp.CLARABEL)

with warnings.catch_warnings():
warnings.simplefilter("error")
try:
problem.solve(solver=cp.CLARABEL)
except (UserWarning, cp.error.SolverError) as e:
LOG.error(e)
return np.zeros((A.shape[1], B.shape[1]))

return transition_matrix.value

def fit_predict(self, X, Y, weights=None):
Expand Down Expand Up @@ -75,8 +83,13 @@ def fit_predict(self, X, Y, weights=None):

transition_matrix = self.__solve(X, Y, weights)
transitions = np.diag(X_expected_totals) @ transition_matrix
Y_pred_totals = np.sum(transitions, axis=0) / np.sum(transitions, axis=0).sum()
self._mae = mean_absolute_error(Y_expected_totals, Y_pred_totals)

if np.sum(transitions, axis=0).sum() != 0:
Y_pred_totals = np.sum(transitions, axis=0) / np.sum(transitions, axis=0).sum()
self._mae = mean_absolute_error(Y_expected_totals, Y_pred_totals)
else:
# would have logged an error above
self._mae = 1
if self._verbose:
LOG.info("MAE = %s", np.around(self._mae, 4))

Expand Down Expand Up @@ -108,24 +121,23 @@ def fit_predict(self, X, Y, weights=None):

from sklearn.utils import resample # to be replaced

with logging_redirect_tqdm(loggers=[LOG]):
for b in tqdm(range(0, self._B - 1), desc="Bootstrapping"):
X_resampled = []
Y_resampled = []
weights_resampled = []
for i in resample(range(0, len(X)), replace=True, random_state=b):
X_resampled.append(X[i])
Y_resampled.append(Y[i])
if weights is not None:
weights_resampled.append(weights[i])
if weights is None:
weights_resampled = None
else:
weights_resampled = np.array(weights_resampled)
predicted_transitions.append(
tm.fit_predict(np.array(X_resampled), np.array(Y_resampled), weights=weights_resampled)
)
maes.append(tm.MAE)
for b in tqdm(range(0, self._B - 1), desc="Bootstrapping"):
X_resampled = []
Y_resampled = []
weights_resampled = []
for i in resample(range(0, len(X)), replace=True, random_state=b):
X_resampled.append(X[i])
Y_resampled.append(Y[i])
if weights is not None:
weights_resampled.append(weights[i])
if weights is None:
weights_resampled = None
else:
weights_resampled = np.array(weights_resampled)
predicted_transitions.append(
tm.fit_predict(np.array(X_resampled), np.array(Y_resampled), weights=weights_resampled)
)
maes.append(tm.MAE)

self._mae = np.mean(maes)
LOG.info("MAE = %s", np.around(self._mae, 4))
Expand Down

0 comments on commit 833cb82

Please sign in to comment.