Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/onnx/sklearn-onnx into i108…
Browse files Browse the repository at this point in the history
…9cus
  • Loading branch information
xadupre committed May 23, 2024
2 parents 8d62195 + 9511028 commit f0d452b
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 46 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/linux-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python_version: ['3.12', '3.11', '3.10', '3.9']
sklearn_version: ['==1.5.0', '==1.4.2', '==1.3.2', '==1.2.2', '==1.1.3']
include:
- sklearn_version: '==1.5.0'
documentation: 0
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
python_version: '3.12'
- python_version: '3.12'
documentation: 0
numpy_version: '>=1.21.1'
Expand Down
24 changes: 21 additions & 3 deletions .github/workflows/windows-macos-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@ name: CI Win/Mac
on: [push, pull_request]
jobs:
run:
name: ${{ matrix.os }} py==${{ matrix.python_version }} - sklearn==${{ matrix.sklearn_version }} - ${{ matrix.onnxrt_version }}
name: ${{ matrix.os }} py==${{ matrix.python_version }} - sklearn${{ matrix.sklearn_version }} - ${{ matrix.onnxrt_version }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, macos-latest]
python_version: ['3.11', '3.10', '3.9']
sklearn_version: ['==1.5.0', '==1.4.2', '==1.3.2', '==1.2.2', '==1.1.3']
include:
- sklearn_version: '==1.5.0'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
- python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.16.0'
onnxrt_version: 'onnxruntime==1.17.3'
onnxrt_version: 'onnxruntime<1.18.0'
sklearn_version: '==1.3.2'
- python_version: '3.10'
numpy_version: '>=1.21.1'
Expand All @@ -27,6 +33,18 @@ jobs:
onnx_version: 'onnx<1.14'
onnxrt_version: 'onnxruntime<1.16.0'
sklearn_version: '==1.2.2'
- sklearn_version: '==1.4.2'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'
- sklearn_version: '==1.1.3'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'

steps:
- name: Checkout repository
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.17.0 (development)

* Minor fixes to support scikit-learn==1.5.0
[#1095](https://github.com/onnx/sklearn-onnx/pull/1095)
* Fix the conversion of pipeline including pipelines,
issue [#1069](https://github.com/onnx/sklearn-onnx/pull/1069),
[#1072](https://github.com/onnx/sklearn-onnx/pull/1072)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ including models or transformers coming from external libraries.
## Documentation
Full documentation including tutorials is available at [https://onnx.ai/sklearn-onnx/](https://onnx.ai/sklearn-onnx/).
[Supported scikit-learn Models](https://onnx.ai/sklearn-onnx/supported.html)
Last supported opset is 15.
Last supported opset is 19.

You may also find answers in [existing issues](https://github.com/onnx/sklearn-onnx/issues?utf8=%E2%9C%93&q=is%3Aissue)
or submit a new one.
Expand Down
38 changes: 27 additions & 11 deletions skl2onnx/operator_converters/cross_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from ..algebra.onnx_ops import OnnxAdd, OnnxCast, OnnxDiv, OnnxMatMul, OnnxSub


def _skl150() -> bool:
import sklearn
import packaging.version as pv

return pv.Version(sklearn.__version__) >= pv.Version("1.5.0")


def convert_pls_regression(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
Expand All @@ -27,19 +34,28 @@ def convert_pls_regression(

coefs = op.x_mean_ if hasattr(op, "x_mean_") else op._x_mean
std = op.x_std_ if hasattr(op, "x_std_") else op._x_std
ym = op.y_mean_ if hasattr(op, "x_mean_") else op._y_mean

norm_x = OnnxDiv(
OnnxSub(X, coefs.astype(dtype), op_version=opv),
std.astype(dtype),
op_version=opv,
)
if hasattr(op, "set_predict_request"):
# new in 1.3
if hasattr(op, "intercept_") and _skl150():
# scikit-learn==1.5.0
# https://github.com/scikit-learn/scikit-learn/pull/28612
ym = op.intercept_
centered_x = OnnxSub(X, coefs.astype(dtype), op_version=opv)
coefs = op.coef_.T.astype(dtype)
dot = OnnxMatMul(centered_x, coefs, op_version=opv)
else:
coefs = op.coef_.astype(dtype)
dot = OnnxMatMul(norm_x, coefs, op_version=opv)
ym = op.y_mean_ if hasattr(op, "x_mean_") else op._y_mean

norm_x = OnnxDiv(
OnnxSub(X, coefs.astype(dtype), op_version=opv),
std.astype(dtype),
op_version=opv,
)
if hasattr(op, "set_predict_request"):
# new in 1.3
coefs = op.coef_.T.astype(dtype)
else:
coefs = op.coef_.astype(dtype)
dot = OnnxMatMul(norm_x, coefs, op_version=opv)

pred = OnnxAdd(dot, ym.astype(dtype), op_version=opv, output_names=operator.outputs)
pred.add_to(scope, container)

Expand Down
43 changes: 38 additions & 5 deletions skl2onnx/operator_converters/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,26 @@ def convert_sklearn_linear_classifier(
intercepts = list(map(lambda x: -1 * x, intercepts)) + intercepts

multi_class = 0
use_ovr = False
if hasattr(op, "multi_class"):
if op.multi_class == "ovr":
multi_class = 1
else:
elif number_of_classes > 2:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
multi_class = 2
use_ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
)
else:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
if number_of_classes > 2:
multi_class = 2

classifier_type = "LinearClassifier"
Expand All @@ -77,11 +93,28 @@ def convert_sklearn_linear_classifier(
):
classifier_attrs["post_transform"] = "NONE"
elif isinstance(op, LogisticRegression):
ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
# See https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_logistic.py#L1423
classifier_attrs["post_transform"] = (
"LOGISTIC"
if (
use_ovr
or op.solver == "liblinear"
or (
hasattr(op, "multi_class")
and op.multi_class in ("auto", "deprecated")
and (
op.classes_.size <= 2
or op.solver in ("liblinear", "newton-cholesky")
)
)
or (
getattr(op, "multi_class", "auto") in ("auto", "deprecated")
and multi_class == 0
and (len(op.coef_.shape) <= 1 or min(op.coef_.shape) == 1)
)
)
else "SOFTMAX"
)
classifier_attrs["post_transform"] = "LOGISTIC" if ovr else "SOFTMAX"
else:
classifier_attrs["post_transform"] = (
"LOGISTIC" if multi_class > 2 else "SOFTMAX"
Expand Down
10 changes: 8 additions & 2 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class TestInvestigate(unittest.TestCase):
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1053(self):
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_issue_1053(self):
pv.Version(ort_version) < pv.Version("1.16.0"),
reason="opset 19 not implemented",
)
@ignore_warnings(category=(ConvergenceWarning,))
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1055(self):
import numpy as np
from numpy.testing import assert_almost_equal
Expand Down Expand Up @@ -118,6 +119,7 @@ def test_issue_1055(self):
pv.Version(ort_version) < pv.Version("1.17.3"),
reason="opset 19 not implemented",
)
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1069(self):
import math
from typing import Any
Expand Down Expand Up @@ -246,7 +248,11 @@ def Classifier(features: list[str]) -> base.BaseEstimator:
return classifier

model = Classifier(list(X_train.columns))
model.fit(X_train, y_train)
try:
model.fit(X_train, y_train)
except ValueError as e:
# If this fails, no need to go beyond.
raise unittest.SkipTest(str(e))

sample = X_train[:1].astype(numpy.float32)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_feature_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_feature_union_transformer_weights_2(self):
X_test,
model,
model_onnx,
basename="SklearnFeatureUnionTransformerWeights2-Dec4",
basename="SklearnFeatureUnionTransformerWeights2-Dec3",
)


Expand Down
37 changes: 21 additions & 16 deletions tests/test_sklearn_pipeline_concat_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy
from numpy.testing import assert_almost_equal
from onnxruntime import InferenceSession
from onnxruntime.capi.onnxruntime_pybind11_state import Fail
import pandas
from sklearn import __version__ as sklearn_version

Expand Down Expand Up @@ -35,6 +34,12 @@ def skl12():
return pv.Version(vers) >= pv.Version("1.2")


def skl15():
# pv.Version does not work with development versions
vers = ".".join(sklearn_version.split(".")[:2])
return pv.Version(vers) >= pv.Version("1.5")


class TestSklearnPipelineConcatTfIdf(unittest.TestCase):
words = [
"ability",
Expand Down Expand Up @@ -319,7 +324,7 @@ def get_pipeline(N=10000):

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
@unittest.skipIf(not skl12(), reason="sparse_output")
@unittest.skipIf(not skl15(), reason="no working")
@unittest.skipIf(TARGET_OPSET < 18, reason="too long")
def test_issue_712_svc_binary(self):
pipe, dfx_test = TestSklearnPipelineConcatTfIdf.get_pipeline()
Expand All @@ -341,20 +346,20 @@ def test_issue_712_svc_binary(self):
got = sess.run(None, row_inputs)
assert_almost_equal(expected_dense[i], got[0])

with self.assertRaises(Fail):
# StringNormlizer removes empty strings after normalizer.
# This case happens when a string contains only stopwords.
# Then rows are missing and the output of the StringNormalizer
# and the OneHotEncoder output cannot be merged anymore with
# an error message like the following:
# onnxruntime.capi.onnxruntime_pybind11_state.Fail:
# [ONNXRuntimeError] : 1 : FAIL : Non-zero status code
# returned while running Concat node. Name:'Concat1'
# Status Message: concat.cc:159 onnxruntime::ConcatBase::
# PrepareForCompute Non concat axis dimensions must match:
# Axis 0 has mismatched dimensions of 2106 and 2500.
got = sess.run(None, inputs)
# assert_almost_equal(expected.todense(), got[0])
# It is fixed with scikit-learn==1.5.0.
# StringNormalizer removes empty strings after normalizer.
# This case happens when a string contains only stopwords.
# Then rows are missing and the output of the StringNormalizer
# and the OneHotEncoder output cannot be merged anymore with
# an error message like the following:
# onnxruntime.capi.onnxruntime_pybind11_state.Fail:
# [ONNXRuntimeError] : 1 : FAIL : Non-zero status code
# returned while running Concat node. Name:'Concat1'
# Status Message: concat.cc:159 onnxruntime::ConcatBase::
# PrepareForCompute Non concat axis dimensions must match:
# Axis 0 has mismatched dimensions of 2106 and 2500.
got = sess.run(None, inputs)
assert_almost_equal(expected.todense(), got[0])

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sklearn_pls_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_model_pls_regression(self):
model_onnx,
methods=["predict"],
basename="SklearnPLSRegression",
verbose=10,
verbose=0,
)

def test_model_pls_regression64(self):
Expand Down Expand Up @@ -91,4 +91,4 @@ def test_model_pls_regressionInt64(self):


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
15 changes: 11 additions & 4 deletions tests/test_sklearn_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,12 @@ def test_count_vectorizer_english2(self):
raise AssertionError(
f"{mod1.token_pattern!r} != {mod2.token_pattern!r}"
)
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(f"{mod1.stop_words_} != {mod2.stop_words_}")

if hasattr(mod1, "stop_words_"):
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(
f"{mod1.stop_words_} != {mod2.stop_words_}"
)
if len(mod1.vocabulary_) != len(mod2.vocabulary_):
raise AssertionError(
f"skl_version={skl_version!r}, "
Expand Down Expand Up @@ -228,8 +232,11 @@ def test_tfidf_vectorizer_english2(self):
raise AssertionError(
f"{mod1.token_pattern!r} != {mod2.token_pattern!r}"
)
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(f"{mod1.stop_words_} != {mod2.stop_words_}")
if hasattr(mod1, "stop_words_"):
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(
f"{mod1.stop_words_} != {mod2.stop_words_}"
)
if len(mod1.vocabulary_) != len(mod2.vocabulary_):
raise AssertionError(
f"skl_version={skl_version!r}, "
Expand Down

0 comments on commit f0d452b

Please sign in to comment.