Skip to content

Commit

Permalink
First attempt at ensemble classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
sidchaini committed Oct 21, 2024
1 parent 60a6391 commit f4ba561
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions distclassipy/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def find_best_metrics(
quantiles = pd.qcut(X_df[feature_name], q=n_quantiles)

X_train, X_test, y_train, y_test = train_test_split(
X_df, y_df, test_size=0.33, stratify=quantiles
X_df, y_df, test_size=0.25, stratify=quantiles
)

clf.fit(X_train, y_train.to_numpy().ravel())
Expand Down Expand Up @@ -507,5 +507,61 @@ def find_best_metrics(

# alt, but slower:
# loop through each quantile, and append pred
group_bins = []
for bins, group in grouped_test_data:
group_bins.append(bins)
return quantile_scores_df, best_metrics_per_quantile, group_bins

return quantile_scores_df, best_metrics_per_quantile

class EnsembleDistanceMetricClassifier(BaseEstimator, ClassifierMixin):
    """An ensemble classifier that uses different metrics for each quantile.

    Samples are split into quantile groups according to the value of a single
    feature column (``feat_idx``). A single ``DistanceMetricClassifier`` is
    fitted, and at prediction time each group is classified with the metric
    that ``find_best_metrics`` selected for that group.

    Parameters
    ----------
    feat_idx : int
        Column index of the feature used to assign samples to quantiles.
    scale : bool, default=True
        Forwarded unchanged to ``DistanceMetricClassifier``.
    central_stat : str, default="median"
        Forwarded unchanged to ``DistanceMetricClassifier``.
    dispersion_stat : str, default="std"
        Forwarded unchanged to ``DistanceMetricClassifier``.
    """

    def __init__(
        self,
        feat_idx: int,
        scale: bool = True,
        central_stat: str = "median",
        dispersion_stat: str = "std",
    ) -> None:
        """Initialize the classifier with specified parameters."""
        self.feat_idx = feat_idx
        self.scale = scale
        self.central_stat = central_stat
        self.dispersion_stat = dispersion_stat

    def fit(
        self, X: np.ndarray, y: np.ndarray, n_quantiles: int = 4
    ) -> "EnsembleDistanceMetricClassifier":
        """Fit the ensemble classifier using the best metrics for each quantile.

        Parameters
        ----------
        X : np.ndarray
            Training samples.
        y : np.ndarray
            Target labels for ``X``.
        n_quantiles : int, default=4
            Number of quantile groups passed to ``find_best_metrics``.

        Returns
        -------
        EnsembleDistanceMetricClassifier
            The fitted estimator (``self``).
        """
        self.clf_ = DistanceMetricClassifier(
            scale=self.scale,
            central_stat=self.central_stat,
            dispersion_stat=self.dispersion_stat,
        )
        # find_best_metrics is given the (not yet fully fitted) base
        # classifier and returns the per-quantile scores, the chosen metric
        # per quantile, and the bins that delimit the quantile groups.
        (
            self.quantile_scores_df_,
            self.best_metrics_per_quantile_,
            self.group_bins,
        ) = find_best_metrics(self.clf_, X, y, self.feat_idx, n_quantiles)
        self.group_labels = [f"Quantile {i+1}" for i in range(n_quantiles)]
        # Refit on the complete training set after metric selection.
        self.clf_.fit(X, y)
        self.is_fitted_ = True
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class labels using the best metric for each quantile.

        Parameters
        ----------
        X : np.ndarray
            Samples to classify.

        Returns
        -------
        np.ndarray
            Predicted labels, one per row of ``X``, in the order of ``X``.

        Raises
        ------
        ValueError
            If any sample's ``feat_idx`` value cannot be placed in one of the
            bins learned during ``fit``.
        """
        check_is_fitted(self, "is_fitted_")
        X = check_array(X)

        quantiles = pd.cut(
            X[:, self.feat_idx], bins=self.group_bins, labels=self.group_labels
        )

        # pd.cut yields NaN for values outside every bin; groupby would then
        # drop those rows silently and leave their predictions undefined.
        unplaced = np.asarray(pd.isna(quantiles))
        if unplaced.any():
            raise ValueError(
                f"{int(unplaced.sum())} sample(s) have a value in feature "
                f"{self.feat_idx} that falls outside the quantile bins "
                "learned during fit."
            )

        # pd.DataFrame(X) has a default RangeIndex, so each group's index
        # holds the positional row numbers of its samples in X.
        grouped_data = pd.DataFrame(X).groupby(quantiles, observed=False)

        row_positions = []
        group_predictions = []
        # observed=False emits every category in order, so position ``i``
        # lines up with ``self.group_labels[i]`` even when a group is empty.
        for i, (_, subdf) in enumerate(grouped_data):
            if subdf.empty:
                # Nothing to classify in this quantile.
                continue
            best_metric = self.best_metrics_per_quantile_.loc[self.group_labels[i]]
            preds = self.clf_.predict(subdf.to_numpy(), metric=best_metric)
            row_positions.append(subdf.index.to_numpy())
            group_predictions.append(np.asarray(preds))

        # Scatter the per-group results back into the original row order; the
        # output dtype follows the base classifier's predictions so that
        # non-integer class labels are preserved.
        ordered_positions = np.concatenate(row_positions)
        ordered_predictions = np.concatenate(group_predictions)
        predictions = np.empty(X.shape[0], dtype=ordered_predictions.dtype)
        predictions[ordered_positions] = ordered_predictions

        return predictions

0 comments on commit f4ba561

Please sign in to comment.