diff --git a/river/metrics/__init__.py b/river/metrics/__init__.py index 5a1e102756..fa62ee9a7e 100644 --- a/river/metrics/__init__.py +++ b/river/metrics/__init__.py @@ -60,6 +60,7 @@ from .recall import MacroRecall, MicroRecall, Recall, WeightedRecall from .report import ClassificationReport from .roc_auc import ROCAUC +from .rolling_pr_auc import RollingPRAUC from .rolling_roc_auc import RollingROCAUC from .silhouette import Silhouette from .smape import SMAPE @@ -108,6 +109,7 @@ "FowlkesMallows", "RMSLE", "ROCAUC", + "RollingPRAUC", "RollingROCAUC", "R2", "Precision", diff --git a/river/metrics/efficient_rollingprauc/__init__.py b/river/metrics/efficient_rollingprauc/__init__.py new file mode 100644 index 0000000000..cac3a79140 --- /dev/null +++ b/river/metrics/efficient_rollingprauc/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .efficient_rollingprauc import EfficientRollingPRAUC + +__all__ = ["EfficientRollingPRAUC"] diff --git a/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.cpp b/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.cpp new file mode 100644 index 0000000000..568d0b4bc1 --- /dev/null +++ b/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.cpp @@ -0,0 +1,150 @@ +#include "RollingPRAUC.hpp" + +#include +#include + +namespace rollingprauc { + +RollingPRAUC::RollingPRAUC(): positiveLabel{1}, windowSize{1000}, positives{0} { +} + +RollingPRAUC::RollingPRAUC(int positiveLabel, long unsigned windowSize): + positiveLabel{positiveLabel}, windowSize{windowSize}, positives{0} { +} + +void RollingPRAUC::update(int label, double score) { + if (this->window.size() == this->windowSize) + this->removeLast(); + + this->insert(label, score); + + return; +} + +void RollingPRAUC::revert(int label, double score) { + int normalizedLabel = 0; + if (label == this->positiveLabel) + normalizedLabel = 1; + + std::deque>::const_iterator it{this->window.cbegin()}; + for (; it != this->window.cend(); ++it) + if (std::get<0>(*it) == score && std::get<1>(*it) == normalizedLabel) + break; + + if (it == this->window.cend()) + return; + + if (normalizedLabel) + this->positives--; + + this->window.erase(it); + + std::multiset>::const_iterator itr{ + this->orderedWindow.find(std::make_tuple(score, label)) + }; + this->orderedWindow.erase(itr); + + return; +} + +double RollingPRAUC::get() const { + unsigned long windowSize{this->window.size()}; + + // If there is only one class in the window, it will lead to a + // division by zero. So, zero is returned. + if (!this->positives || !(windowSize - this->positives)) + return 0; + + unsigned long fp{windowSize - this->positives}; + unsigned long tp{this->positives}, tpPrev{tp}; + + double auc{0}, scorePrev{std::numeric_limits::max()}; + + double prec{tp / (double) (tp + fp)}, precPrev{prec}; + + std::multiset>::const_iterator it{this->orderedWindow.begin()}; + double score; + int label; + + for (; it != this->orderedWindow.end(); ++it) { + score = std::get<0>(*it); + label = std::get<1>(*it); + + if (score != scorePrev) { + prec = tp / (double) (tp + fp); + + if (precPrev > prec) + prec = precPrev; // Monotonic. decreasing + + auc += this->trapzArea(tp, tpPrev, prec, precPrev); + + scorePrev = score; + tpPrev = tp; + precPrev = prec; + } + + if (label) tp--; + else fp--; + } + + auc += this->trapzArea(tp, tpPrev, 1.0, precPrev); + + return auc / this->positives; // Scale the x axis +} + +void RollingPRAUC::insert(int label, double score) { + // Normalize label to 0 (negative) or 1 (positive) + int l = 0; + if (label == this->positiveLabel) { + l = 1; + this->positives++; + } + + this->window.emplace_back(score, l); + this->orderedWindow.emplace(score, l); + + return; +} + +void RollingPRAUC::removeLast() { + std::tuple last{this->window.front()}; + + if (std::get<1>(last)) + this->positives--; + + this->window.pop_front(); + + // Erase using a iterator to avoid multiple erases with equivalent instances + std::multiset>::iterator it{ + this->orderedWindow.find(last) + }; + this->orderedWindow.erase(it); + + return; +} + +std::vector RollingPRAUC::getTrueLabels() const { + std::vector trueLabels; + + std::deque>::const_iterator it{this->window.begin()}; + for (; it != this->window.end(); ++it) + trueLabels.push_back(std::get<1>(*it)); + + return trueLabels; +} + +std::vector RollingPRAUC::getScores() const { + std::vector scores; + + std::deque>::const_iterator it{this->window.begin()}; + for (; it != this->window.end(); ++it) + scores.push_back(std::get<0>(*it)); + + return scores; +} + +double RollingPRAUC::trapzArea(double x1, double x2, double y1, double y2) const { + return abs(x1 - x2) * (y1 + y2) / 2; +} + +} // namespace rollingprauc diff --git a/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.hpp b/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.hpp new file mode 100644 index 0000000000..e47e423079 --- /dev/null +++ b/river/metrics/efficient_rollingprauc/cpp/RollingPRAUC.hpp @@ -0,0 +1,59 @@ +#ifndef ROLLINGPRAUC_HPP +#define ROLLINGPRAUC_HPP + +#include +#include +#include +#include + +namespace rollingprauc { + +class RollingPRAUC { + public: + RollingPRAUC(); + RollingPRAUC(const int positiveLabel, const long unsigned windowSize); + + virtual ~RollingPRAUC() = default; + + // Calls insert() and removeLast() if needed + virtual void update(const int label, const double score); + + // Erase the most recent instance with content equal to params + virtual void revert(const int label, const double score); + + // Calculates the PRAUC and returns it + virtual double get() const; + + // Returns y_true as a vector + virtual std::vector getTrueLabels() const; + + // Returns y_score as a vector + virtual std::vector getScores() const; + + private: + // Insert instance based on params + virtual void insert(const int label, const double score); + + // Remove oldest instance + virtual void removeLast(); + + // Calculates the trapezoid area + double trapzArea(double x1, double x2, double y1, double y2) const; + + int positiveLabel; + + std::size_t windowSize; + std::size_t positives; + + // window maintains a queue of the instances to store the temporal + // aspect of the stream. Using deque to allow revert() + std::deque> window; + + // orderedWindow maintains a multiset (implemented as a tree) + // to store the ordered instances + std::multiset> orderedWindow; +}; + +} // namespace rollingprauc + +#endif diff --git a/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pxd b/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pxd new file mode 100644 index 0000000000..c9a0d8f5c0 --- /dev/null +++ b/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pxd @@ -0,0 +1,13 @@ +from libcpp.vector cimport vector + +cdef extern from "cpp/RollingPRAUC.cpp": + pass + +cdef extern from "cpp/RollingPRAUC.hpp" namespace "rollingprauc": + cdef cppclass RollingPRAUC: + RollingPRAUC(int positiveLabel, int windowSize) except + + void update(int label, double score) + void revert(int label, double score) + double get() + vector[int] getTrueLabels() + vector[double] getScores() diff --git a/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pyx b/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pyx new file mode 100644 index 0000000000..1c8cb4123e --- /dev/null +++ b/river/metrics/efficient_rollingprauc/efficient_rollingprauc.pyx @@ -0,0 +1,55 @@ +# distutils: language = c++ +# distutils: extra_compile_args = "-std=c++11" + +import cython + +from .efficient_rollingprauc cimport RollingPRAUC as CppRollingPRAUC + +cdef class EfficientRollingPRAUC: + cdef cython.int positiveLabel + cdef cython.ulong windowSize + cdef CppRollingPRAUC* rollingprauc + + def __cinit__(self, cython.int positiveLabel, cython.ulong windowSize): + self.positiveLabel = positiveLabel + self.windowSize = windowSize + self.rollingprauc = new CppRollingPRAUC(positiveLabel, windowSize) + + def __dealloc__(self): + if not self.rollingprauc == NULL: + del self.rollingprauc + + def update(self, label, score): + self.rollingprauc.update(label, score) + + def revert(self, label, score): + self.rollingprauc.revert(label, score) + + def get(self): + return self.rollingprauc.get() + + def __getnewargs_ex__(self): + # Pickle will use this function to pass the arguments to __new__ + return (self.positiveLabel, self.windowSize),{} + + def __getstate__(self): + """ + On pickling, the true labels and scores of the instances in the + window will be dumped + """ + return (self.rollingprauc.getTrueLabels(), self.rollingprauc.getScores()) + + def __setstate__(self, state): + """ + On unpickling, the state parameter will have the true labels + and scores, this function updates the rollingprauc with them + """ + + # Labels returned by __getstate__ are normalized (0 or 1) + labels, scores = state + + for label, score in zip(labels, scores): + # If label is 1, update with the positive label defined by the constructor + # Else, update with a negative label + l = self.positiveLabel if label else int(not self.positiveLabel) + self.update(l, score) diff --git a/river/metrics/rolling_pr_auc.py b/river/metrics/rolling_pr_auc.py new file mode 100644 index 0000000000..d893eb843d --- /dev/null +++ b/river/metrics/rolling_pr_auc.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from river import metrics, utils + +from .efficient_rollingprauc import EfficientRollingPRAUC + +__all__ = ["RollingPRAUC"] + + +class RollingPRAUC(metrics.base.BinaryMetric): + """Rolling version of the Area Under the Precision-Recall Area Under Curve + metric. + + The RollingPRAUC calculates the AUC-PR using the instances in its window + of size S. It keeps a queue of the instances, when an instance is added + and the queue length is equal to S, the last instance is removed. + + The AUC-PR is suitable for evaluating models under unbalanced environments. + For now, the implementation can deal only with binary scenarios. + + Internally, this class maintains a self-balancing binary search tree to + efficiently and precisely (i.e., the result is not an approximation) + compute the AUC-PR considering the current window. + + This implementation is based on the paper "Efficient Prequential AUC-PR + Computation" (Gomes, Grégio, Alves, and Almeida, 2023): + https://doi.org/10.1109/ICMLA58977.2023.00335. + + + Parameters + ---------- + window_size + The max length of the window. + pos_val + Value to treat as "positive". + + Examples + -------- + + >>> from river import metrics + + >>> y_true = [ 0, 1, 0, 1, 0, 1, 0, 0, 1, 1] + >>> y_pred = [.3, .5, .5, .7, .1, .3, .1, .4, .35, .8] + + >>> metric = metrics.RollingPRAUC(window_size=4) + + >>> for yt, yp in zip(y_true, y_pred): + ... metric.update(yt, yp) + + >>> metric + RollingPRAUC: 83.33% + + """ + + def __init__(self, window_size=1000, pos_val=True): + self.window_size = window_size + self.pos_val = pos_val + self._metric = EfficientRollingPRAUC(pos_val, window_size) + + def works_with(self, model) -> bool: + return ( + super().works_with(model) + or utils.inspect.isanomalydetector(model) + or utils.inspect.isanomalyfilter(model) + ) + + def update(self, y_true, y_pred): + p_true = y_pred.get(True, 0.0) if isinstance(y_pred, dict) else y_pred + self._metric.update(y_true, p_true) + + def revert(self, y_true, y_pred): + p_true = y_pred.get(True, 0.0) if isinstance(y_pred, dict) else y_pred + self._metric.revert(y_true, p_true) + + @property + def requires_labels(self) -> bool: + return False + + @property + def works_with_weights(self) -> bool: + return False + + def get(self): + return self._metric.get() diff --git a/river/metrics/test_metrics.py b/river/metrics/test_metrics.py index fa66a9fc6b..9f55d3a615 100644 --- a/river/metrics/test_metrics.py +++ b/river/metrics/test_metrics.py @@ -116,6 +116,24 @@ def roc_auc_score(y_true, y_score): return sk_metrics.roc_auc_score(y_true, scores) +def pr_auc_score(y_true, y_score): + """ + This function is a wrapper to the scikit-learn precision_recall_curve and + auc functions. Returns 0 if y_true has only one class. + """ + nonzero = np.count_nonzero(y_true) + if nonzero == 0 or nonzero == len(y_true): + return 0 + + scores = [s[True] for s in y_score] + precision, recall, _ = sk_metrics.precision_recall_curve(y_true, scores) + + # Monotonic. decreasing + precision = np.maximum.accumulate(precision) + + return sk_metrics.auc(recall, precision) + + TEST_CASES = [ (metrics.Accuracy(), sk_metrics.accuracy_score), (metrics.Precision(), partial(sk_metrics.precision_score, zero_division=0)), @@ -210,6 +228,7 @@ def roc_auc_score(y_true, y_score): (metrics.MicroJaccard(), partial(sk_metrics.jaccard_score, average="micro")), (metrics.WeightedJaccard(), partial(sk_metrics.jaccard_score, average="weighted")), (metrics.RollingROCAUC(), roc_auc_score), + (metrics.RollingPRAUC(), pr_auc_score), ] # HACK: not sure why this is needed, see this CI run https://github.com/online-ml/river/runs/7992357532?check_suite_focus=true