Skip to content

Commit

Permalink
feat: Added working version of EnsembleDistanceMetricClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
sidchaini committed Oct 21, 2024
1 parent cb918db commit a71a615
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions distclassipy/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def __init__(
self.scale = scale
self.central_stat = central_stat
self.dispersion_stat = dispersion_stat
self.metrics_to_consider = metrics_to_consider or _ALL_METRICS
self.metrics_to_consider = metrics_to_consider

def fit(
self, X: np.ndarray, y: np.ndarray, n_quantiles: int = 4
Expand Down Expand Up @@ -604,12 +604,26 @@ def predict(self, X: np.ndarray) -> np.ndarray:
quantiles = pd.cut(
X[:, self.feat_idx], bins=self.group_bins, labels=self.group_labels
)
grouped_data = pd.DataFrame(X).groupby(quantiles, observed=False)
# grouped_data = pd.DataFrame(X).groupby(quantiles, observed=False)
quantile_indices = quantiles.codes # Get integer codes for quantiles
predictions = np.empty(X.shape[0], dtype=int)
for i, (lim, subdf) in enumerate(grouped_data):
best_metric = self.best_metrics_per_quantile_.loc[self.group_labels[i]]
preds = self.clf_.predict(subdf.to_numpy(), metric=best_metric)
predictions[subdf.index] = preds
# for i, (lim, subdf) in enumerate(grouped_data):
# best_metric = self.best_metrics_per_quantile_.loc[self.group_labels[i]]
# preds = self.clf_.predict(subdf.to_numpy(), metric=best_metric)
# predictions[subdf.index] = preds
# Precompute predictions for each quantile
quantile_predictions = {}
for i, label in enumerate(self.group_labels):
best_metric = self.best_metrics_per_quantile_.loc[label]
quantile_data = X[quantile_indices == i]
if quantile_data.size > 0:
quantile_predictions[i] = self.clf_.predict(
quantile_data, metric=best_metric
)

# Assign predictions to the corresponding indices
for i, preds in quantile_predictions.items():
predictions[quantile_indices == i] = preds

return predictions

Expand Down

0 comments on commit a71a615

Please sign in to comment.