Skip to content

Commit

Permalink
Add converter for TunedThresholdClassifierCV (#1107)
Browse files Browse the repository at this point in the history
* Add converter for TunedThresholdClassifierCV

Signed-off-by: Xavier Dupre <[email protected]>

* upgrade version

Signed-off-by: Xavier Dupre <[email protected]>

* documentation

Signed-off-by: Xavier Dupre <[email protected]>

* update numpy version

Signed-off-by: Xavier Dupre <[email protected]>

* do not use numpy 2

Signed-off-by: Xavier Dupre <[email protected]>

* delay import

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 19, 2024
1 parent 4dad29e commit ebd20e9
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 12 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/linux-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,35 @@ jobs:
include:
- sklearn_version: '==1.5.0'
documentation: 0
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
python_version: '3.12'
- python_version: '3.12'
documentation: 0
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
sklearn_version: '==1.4.2'
- python_version: '3.11'
documentation: 1
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.16.0'
onnxrt_version: 'onnxruntime==1.17.3'
sklearn_version: '==1.3.2'
- python_version: '3.10'
documentation: 0
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.14.1'
onnxrt_version: 'onnxruntime==1.16.0'
sklearn_version: '==1.2.2'
- python_version: '3.9'
documentation: 0
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.13.0'
onnxrt_version: 'onnxruntime==1.14.0'
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/windows-macos-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,37 @@ jobs:
include:
- sklearn_version: '==1.5.0'
python_version: '3.11'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
- python_version: '3.11'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.16.0'
onnxrt_version: 'onnxruntime<1.18.0'
sklearn_version: '==1.3.2'
- python_version: '3.10'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.15'
onnxrt_version: 'onnxruntime<1.17.0'
sklearn_version: '==1.2.2'
- python_version: '3.9'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.14'
onnxrt_version: 'onnxruntime<1.16.0'
sklearn_version: '==1.2.2'
- sklearn_version: '==1.4.2'
python_version: '3.11'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'
- sklearn_version: '==1.1.3'
python_version: '3.11'
numpy_version: '>=1.21.1'
numpy_version: '>=1.21.1,<2.0'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,5 @@ It is licensed with `Apache License v2.0 <../LICENSE>`_.

**Older versions**

* `1.17.0 <versions/v1.17.0/>`_
* `1.16.0 <versions/v1.16.0/>`_
2 changes: 1 addition & 1 deletion skl2onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
Main entry point to the converter from the *scikit-learn* to *onnx*.
"""
__version__ = "1.17.0"
__version__ = "1.18.0"
__author__ = "Microsoft"
__producer__ = "skl2onnx"
__producer_version__ = __version__
Expand Down
6 changes: 6 additions & 0 deletions skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@
# GridSearchCV
from sklearn.model_selection import GridSearchCV

try:
from sklearn.model_selection import TunedThresholdClassifierCV
except ImportError:
TunedThresholdClassifierCV = None

# MultiOutput
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

Expand Down Expand Up @@ -462,6 +467,7 @@ def build_sklearn_operator_name_map():
TfidfVectorizer,
TfidfTransformer,
TruncatedSVD,
TunedThresholdClassifierCV,
TweedieRegressor,
VarianceThreshold,
VotingClassifier,
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from . import text_vectoriser
from . import tfidf_transformer
from . import tfidf_vectoriser
from . import tuned_threshold_classifier
from . import voting_classifier
from . import voting_regressor
from . import zip_map
Expand Down Expand Up @@ -130,6 +131,7 @@
text_vectoriser,
tfidf_transformer,
tfidf_vectoriser,
tuned_threshold_classifier,
voting_classifier,
voting_regressor,
zip_map,
Expand Down
42 changes: 42 additions & 0 deletions skl2onnx/operator_converters/tuned_threshold_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common.data_types import Int64TensorType
from .._supported_operators import sklearn_operator_name_map


def convert_sklearn_tuned_threshold_classifier(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
estimator = operator.raw_operator.estimator_
op_type = sklearn_operator_name_map[type(estimator)]

this_operator = scope.declare_local_operator(op_type, estimator)
this_operator.inputs = operator.inputs

label_name = scope.declare_local_variable("label_tuned", Int64TensorType())
prob_name = scope.declare_local_variable(
"proba_tuned", operator.outputs[1].type.__class__()
)
this_operator.outputs.append(label_name)
this_operator.outputs.append(prob_name)

container.add_node(
"Identity", [label_name.onnx_name], [operator.outputs[0].full_name]
)
container.add_node(
"Identity", [prob_name.onnx_name], [operator.outputs[1].full_name]
)


register_converter(
"SklearnTunedThresholdClassifierCV",
convert_sklearn_tuned_threshold_classifier,
options={
"zipmap": [True, False, "columns"],
"output_class_labels": [False, True],
"nocl": [True, False],
},
)
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from . import svd
from . import support_vector_machines
from . import text_vectorizer
from . import tuned_threshold_classifier
from . import tfidf_transformer
from . import voting_classifier
from . import voting_regressor
Expand Down Expand Up @@ -100,6 +101,7 @@
support_vector_machines,
text_vectorizer,
tfidf_transformer,
tuned_threshold_classifier,
voting_classifier,
voting_regressor,
zip_map,
Expand Down
16 changes: 16 additions & 0 deletions skl2onnx/shape_calculators/tuned_threshold_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

from ..common._registration import register_shape_calculator
from ..common.utils import check_input_and_output_numbers
from ..common.shape_calculator import _infer_linear_classifier_output_types


def tuned_threshold_classifier_shape_calculator(operator):
check_input_and_output_numbers(operator, output_count_range=2)

_infer_linear_classifier_output_types(operator)


register_shape_calculator(
"SklearnTunedThresholdClassifierCV", tuned_threshold_classifier_shape_calculator
)
59 changes: 59 additions & 0 deletions tests/test_sklearn_tuned_threshold_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0

import unittest
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.utils._testing import ignore_warnings
from skl2onnx import to_onnx
from skl2onnx.common.data_types import FloatTensorType
from test_utils import dump_data_and_model, TARGET_OPSET


def has_tuned_theshold_classifier():
try:
from sklearn.model_selection import TunedThresholdClassifierCV # noqa: F401
except ImportError:
return False
return True


class TestSklearnTunedThresholdClassifierConverter(unittest.TestCase):
@unittest.skipIf(
not has_tuned_theshold_classifier(),
reason="TunedThresholdClassifierCV not available",
)
@ignore_warnings(category=FutureWarning)
def test_tuned_threshold_classifier(self):
from sklearn.model_selection import TunedThresholdClassifierCV

X, y = make_classification(
n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=42
)
classifier = RandomForestClassifier(random_state=0)

classifier_tuned = TunedThresholdClassifierCV(
classifier, scoring="balanced_accuracy"
).fit(X_train, y_train)

model_onnx = to_onnx(
classifier_tuned,
initial_types=[("X", FloatTensorType([None, X_train.shape[1]]))],
target_opset=TARGET_OPSET - 1,
options={"zipmap": False},
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X_test[:10].astype(np.float32),
classifier_tuned,
model_onnx,
basename="SklearnTunedThresholdClassifier",
)


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit ebd20e9

Please sign in to comment.