From ebd20e91fc400f0bf84fb6ba43557968316c3e3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 19 Jun 2024 10:44:44 +0200 Subject: [PATCH] Add converter for TunedThresholdClassifierCV (#1107) * Add converter for TunedThresholdClassifierCV Signed-off-by: Xavier Dupre * upgrade version Signed-off-by: Xavier Dupre * documentation Signed-off-by: Xavier Dupre * update numpy version Signed-off-by: Xavier Dupre * do not use numpy 2 Signed-off-by: Xavier Dupre * delay import Signed-off-by: Xavier Dupre --------- Signed-off-by: Xavier Dupre --- .github/workflows/linux-ci.yml | 10 ++-- .github/workflows/windows-macos-ci.yml | 12 ++-- docs/index.rst | 1 + skl2onnx/__init__.py | 2 +- skl2onnx/_supported_operators.py | 6 ++ skl2onnx/operator_converters/__init__.py | 2 + .../tuned_threshold_classifier.py | 42 +++++++++++++ skl2onnx/shape_calculators/__init__.py | 2 + .../tuned_threshold_classifier.py | 16 +++++ ...test_sklearn_tuned_threshold_classifier.py | 59 +++++++++++++++++++ 10 files changed, 140 insertions(+), 12 deletions(-) create mode 100644 skl2onnx/operator_converters/tuned_threshold_classifier.py create mode 100644 skl2onnx/shape_calculators/tuned_threshold_classifier.py create mode 100644 tests/test_sklearn_tuned_threshold_classifier.py diff --git a/.github/workflows/linux-ci.yml b/.github/workflows/linux-ci.yml index 10b06f10e..56fa27060 100644 --- a/.github/workflows/linux-ci.yml +++ b/.github/workflows/linux-ci.yml @@ -11,35 +11,35 @@ jobs: include: - sklearn_version: '==1.5.0' documentation: 0 - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx==1.16.0' onnxrt_version: 'onnxruntime==1.18.0' python_version: '3.12' - python_version: '3.12' documentation: 0 - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx==1.16.0' onnxrt_version: 'onnxruntime==1.18.0' sklearn_version: '==1.4.2' - python_version: '3.11' documentation: 1 - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx<1.16.0' onnxrt_version: 'onnxruntime==1.17.3' sklearn_version: '==1.3.2' - python_version: '3.10' documentation: 0 - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx==1.14.1' onnxrt_version: 'onnxruntime==1.16.0' sklearn_version: '==1.2.2' - python_version: '3.9' documentation: 0 - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx==1.13.0' onnxrt_version: 'onnxruntime==1.14.0' diff --git a/.github/workflows/windows-macos-ci.yml b/.github/workflows/windows-macos-ci.yml index 88c285e03..78690e62a 100644 --- a/.github/workflows/windows-macos-ci.yml +++ b/.github/workflows/windows-macos-ci.yml @@ -11,37 +11,37 @@ jobs: include: - sklearn_version: '==1.5.0' python_version: '3.11' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx==1.16.0' onnxrt_version: 'onnxruntime==1.18.0' - python_version: '3.11' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx<1.16.0' onnxrt_version: 'onnxruntime<1.18.0' sklearn_version: '==1.3.2' - python_version: '3.10' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx<1.15' onnxrt_version: 'onnxruntime<1.17.0' sklearn_version: '==1.2.2' - python_version: '3.9' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx<1.14' onnxrt_version: 'onnxruntime<1.16.0' sklearn_version: '==1.2.2' - sklearn_version: '==1.4.2' python_version: '3.11' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx>=1.16.0' onnxrt_version: 'onnxruntime>=1.18.0' - sklearn_version: '==1.1.3' python_version: '3.11' - numpy_version: '>=1.21.1' + numpy_version: '>=1.21.1,<2.0' scipy_version: '>=1.7.0' onnx_version: 'onnx>=1.16.0' onnxrt_version: 'onnxruntime>=1.18.0' diff --git a/docs/index.rst b/docs/index.rst index 4fa04e447..226bd9bb0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -129,4 +129,5 @@ It is licensed with `Apache License v2.0 <../LICENSE>`_. **Older versions** +* `1.17.0 `_ * `1.16.0 `_ diff --git a/skl2onnx/__init__.py b/skl2onnx/__init__.py index 3db90b766..431380c70 100644 --- a/skl2onnx/__init__.py +++ b/skl2onnx/__init__.py @@ -3,7 +3,7 @@ """ Main entry point to the converter from the *scikit-learn* to *onnx*. """ -__version__ = "1.17.0" +__version__ = "1.18.0" __author__ = "Microsoft" __producer__ = "skl2onnx" __producer_version__ = __version__ diff --git a/skl2onnx/_supported_operators.py b/skl2onnx/_supported_operators.py index b4a5a7278..34f017cc3 100644 --- a/skl2onnx/_supported_operators.py +++ b/skl2onnx/_supported_operators.py @@ -131,6 +131,11 @@ # GridSearchCV from sklearn.model_selection import GridSearchCV +try: + from sklearn.model_selection import TunedThresholdClassifierCV +except ImportError: + TunedThresholdClassifierCV = None + # MultiOutput from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor @@ -462,6 +467,7 @@ def build_sklearn_operator_name_map(): TfidfVectorizer, TfidfTransformer, TruncatedSVD, + TunedThresholdClassifierCV, TweedieRegressor, VarianceThreshold, VotingClassifier, diff --git a/skl2onnx/operator_converters/__init__.py b/skl2onnx/operator_converters/__init__.py index bc3b04b89..0a9dadfbd 100644 --- a/skl2onnx/operator_converters/__init__.py +++ b/skl2onnx/operator_converters/__init__.py @@ -64,6 +64,7 @@ from . import text_vectoriser from . import tfidf_transformer from . import tfidf_vectoriser +from . import tuned_threshold_classifier from . import voting_classifier from . import voting_regressor from . import zip_map @@ -130,6 +131,7 @@ text_vectoriser, tfidf_transformer, tfidf_vectoriser, + tuned_threshold_classifier, voting_classifier, voting_regressor, zip_map, diff --git a/skl2onnx/operator_converters/tuned_threshold_classifier.py b/skl2onnx/operator_converters/tuned_threshold_classifier.py new file mode 100644 index 000000000..b813105f7 --- /dev/null +++ b/skl2onnx/operator_converters/tuned_threshold_classifier.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +from ..common._registration import register_converter +from ..common._topology import Scope, Operator +from ..common._container import ModelComponentContainer +from ..common.data_types import Int64TensorType +from .._supported_operators import sklearn_operator_name_map + + +def convert_sklearn_tuned_threshold_classifier( + scope: Scope, operator: Operator, container: ModelComponentContainer +): + estimator = operator.raw_operator.estimator_ + op_type = sklearn_operator_name_map[type(estimator)] + + this_operator = scope.declare_local_operator(op_type, estimator) + this_operator.inputs = operator.inputs + + label_name = scope.declare_local_variable("label_tuned", Int64TensorType()) + prob_name = scope.declare_local_variable( + "proba_tuned", operator.outputs[1].type.__class__() + ) + this_operator.outputs.append(label_name) + this_operator.outputs.append(prob_name) + + container.add_node( + "Identity", [label_name.onnx_name], [operator.outputs[0].full_name] + ) + container.add_node( + "Identity", [prob_name.onnx_name], [operator.outputs[1].full_name] + ) + + +register_converter( + "SklearnTunedThresholdClassifierCV", + convert_sklearn_tuned_threshold_classifier, + options={ + "zipmap": [True, False, "columns"], + "output_class_labels": [False, True], + "nocl": [True, False], + }, +) diff --git a/skl2onnx/shape_calculators/__init__.py b/skl2onnx/shape_calculators/__init__.py index 6a4fc36be..ab5556b1e 100644 --- a/skl2onnx/shape_calculators/__init__.py +++ b/skl2onnx/shape_calculators/__init__.py @@ -48,6 +48,7 @@ from . import svd from . import support_vector_machines from . import text_vectorizer +from . import tuned_threshold_classifier from . import tfidf_transformer from . import voting_classifier from . import voting_regressor @@ -100,6 +101,7 @@ support_vector_machines, text_vectorizer, tfidf_transformer, + tuned_threshold_classifier, voting_classifier, voting_regressor, zip_map, diff --git a/skl2onnx/shape_calculators/tuned_threshold_classifier.py b/skl2onnx/shape_calculators/tuned_threshold_classifier.py new file mode 100644 index 000000000..2d166391d --- /dev/null +++ b/skl2onnx/shape_calculators/tuned_threshold_classifier.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from ..common._registration import register_shape_calculator +from ..common.utils import check_input_and_output_numbers +from ..common.shape_calculator import _infer_linear_classifier_output_types + + +def tuned_threshold_classifier_shape_calculator(operator): + check_input_and_output_numbers(operator, output_count_range=2) + + _infer_linear_classifier_output_types(operator) + + +register_shape_calculator( + "SklearnTunedThresholdClassifierCV", tuned_threshold_classifier_shape_calculator +) diff --git a/tests/test_sklearn_tuned_threshold_classifier.py b/tests/test_sklearn_tuned_threshold_classifier.py new file mode 100644 index 000000000..9df3e09a0 --- /dev/null +++ b/tests/test_sklearn_tuned_threshold_classifier.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 + +import unittest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.utils._testing import ignore_warnings +from skl2onnx import to_onnx +from skl2onnx.common.data_types import FloatTensorType +from test_utils import dump_data_and_model, TARGET_OPSET + + +def has_tuned_theshold_classifier(): + try: + from sklearn.model_selection import TunedThresholdClassifierCV # noqa: F401 + except ImportError: + return False + return True + + +class TestSklearnTunedThresholdClassifierConverter(unittest.TestCase): + @unittest.skipIf( + not has_tuned_theshold_classifier(), + reason="TunedThresholdClassifierCV not available", + ) + @ignore_warnings(category=FutureWarning) + def test_tuned_threshold_classifier(self): + from sklearn.model_selection import TunedThresholdClassifierCV + + X, y = make_classification( + n_samples=1_000, weights=[0.9, 0.1], class_sep=0.8, random_state=42 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=42 + ) + classifier = RandomForestClassifier(random_state=0) + + classifier_tuned = TunedThresholdClassifierCV( + classifier, scoring="balanced_accuracy" + ).fit(X_train, y_train) + + model_onnx = to_onnx( + classifier_tuned, + initial_types=[("X", FloatTensorType([None, X_train.shape[1]]))], + target_opset=TARGET_OPSET - 1, + options={"zipmap": False}, + ) + self.assertTrue(model_onnx is not None) + dump_data_and_model( + X_test[:10].astype(np.float32), + classifier_tuned, + model_onnx, + basename="SklearnTunedThresholdClassifier", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2)