Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the converters for scikit-learn==1.5.0 #1095

Merged
merged 44 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7ae5720
Extend CI to test with onnxruntime==1.18.0
xadupre May 20, 2024
25b2199
update doc
xadupre May 20, 2024
9e39c80
simplify pipelines
xadupre May 20, 2024
d28747a
rename master into main
xadupre May 20, 2024
e418fbc
action
xadupre May 21, 2024
802b5e2
ci
xadupre May 21, 2024
b3ac78e
ci
xadupre May 21, 2024
47b68e1
update CI
xadupre May 21, 2024
a747472
update CI
xadupre May 21, 2024
c7f5423
update CI
xadupre May 21, 2024
c86a45f
update CI
xadupre May 21, 2024
1f10b03
update CI
xadupre May 21, 2024
f4a5420
update CI
xadupre May 21, 2024
d995936
update CI
xadupre May 21, 2024
5b35462
update CI
xadupre May 21, 2024
b5fd4ab
update CI
xadupre May 21, 2024
e63d138
ci
xadupre May 21, 2024
664c228
fix ci
xadupre May 21, 2024
bbd5a9a
example
xadupre May 21, 2024
357a015
remove benchmark
xadupre May 21, 2024
1841af5
doc
xadupre May 21, 2024
d3237b7
ci
xadupre May 21, 2024
06165f1
ci
xadupre May 21, 2024
6b4c7e3
ci
xadupre May 21, 2024
1dfdd94
ci
xadupre May 21, 2024
3975ac4
fix ci
xadupre May 21, 2024
aaa2873
fix unittest
xadupre May 22, 2024
92f965c
fix ci
xadupre May 22, 2024
218a1e7
fix ci
xadupre May 22, 2024
be3e3da
fix title
xadupre May 22, 2024
293391d
Check scikit-learn==1.5.0
xadupre May 22, 2024
03db70c
fix disc
xadupre May 22, 2024
b396d8a
better ci
xadupre May 22, 2024
6e0c7a9
ci
xadupre May 22, 2024
4fbec73
linear
xadupre May 22, 2024
5c7b59a
ci
xadupre May 22, 2024
f454f12
fix two unit tests
xadupre May 22, 2024
905bb04
Merge branch 'main' of https://github.com/onnx/sklearn-onnx into skl150
xadupre May 22, 2024
6b355b4
fix PLSRegression
xadupre May 22, 2024
bc1f054
fix version
xadupre May 22, 2024
27f3885
ci
xadupre May 22, 2024
a441fff
ci
xadupre May 22, 2024
9538acf
precision
xadupre May 22, 2024
93387ad
fix unit test
xadupre May 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/linux-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python_version: ['3.12', '3.11', '3.10', '3.9']
sklearn_version: ['==1.5.0', '==1.4.2', '==1.3.2', '==1.2.2', '==1.1.3']
include:
- sklearn_version: '==1.5.0'
documentation: 0
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
python_version: '3.12'
- python_version: '3.12'
documentation: 0
numpy_version: '>=1.21.1'
Expand Down
24 changes: 21 additions & 3 deletions .github/workflows/windows-macos-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@ name: CI Win/Mac
on: [push, pull_request]
jobs:
run:
name: ${{ matrix.os }} py==${{ matrix.python_version }} - sklearn==${{ matrix.sklearn_version }} - ${{ matrix.onnxrt_version }}
name: ${{ matrix.os }} py==${{ matrix.python_version }} - sklearn${{ matrix.sklearn_version }} - ${{ matrix.onnxrt_version }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, macos-latest]
python_version: ['3.11', '3.10', '3.9']
sklearn_version: ['==1.5.0', '==1.4.2', '==1.3.2', '==1.2.2', '==1.1.3']
include:
- sklearn_version: '==1.5.0'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx==1.16.0'
onnxrt_version: 'onnxruntime==1.18.0'
- python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx<1.16.0'
onnxrt_version: 'onnxruntime==1.17.3'
onnxrt_version: 'onnxruntime<1.18.0'
sklearn_version: '==1.3.2'
- python_version: '3.10'
numpy_version: '>=1.21.1'
Expand All @@ -27,6 +33,18 @@ jobs:
onnx_version: 'onnx<1.14'
onnxrt_version: 'onnxruntime<1.16.0'
sklearn_version: '==1.2.2'
- sklearn_version: '==1.4.2'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'
- sklearn_version: '==1.1.3'
python_version: '3.11'
numpy_version: '>=1.21.1'
scipy_version: '>=1.7.0'
onnx_version: 'onnx>=1.16.0'
onnxrt_version: 'onnxruntime>=1.18.0'

steps:
- name: Checkout repository
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.17.0 (development)

* Minor fixes to support scikit-learn==1.5.0
[#1095](https://github.com/onnx/sklearn-onnx/pull/1095)
* Fix the conversion of pipeline including pipelines,
issue [#1069](https://github.com/onnx/sklearn-onnx/pull/1069),
[#1072](https://github.com/onnx/sklearn-onnx/pull/1072)
Expand Down
38 changes: 27 additions & 11 deletions skl2onnx/operator_converters/cross_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from ..algebra.onnx_ops import OnnxAdd, OnnxCast, OnnxDiv, OnnxMatMul, OnnxSub


def _skl150() -> bool:
import sklearn
import packaging.version as pv

return pv.Version(sklearn.__version__) >= pv.Version("1.5.0")


def convert_pls_regression(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
Expand All @@ -27,19 +34,28 @@ def convert_pls_regression(

coefs = op.x_mean_ if hasattr(op, "x_mean_") else op._x_mean
std = op.x_std_ if hasattr(op, "x_std_") else op._x_std
ym = op.y_mean_ if hasattr(op, "x_mean_") else op._y_mean

norm_x = OnnxDiv(
OnnxSub(X, coefs.astype(dtype), op_version=opv),
std.astype(dtype),
op_version=opv,
)
if hasattr(op, "set_predict_request"):
# new in 1.3
if hasattr(op, "intercept_") and _skl150():
# scikit-learn==1.5.0
# https://github.com/scikit-learn/scikit-learn/pull/28612
ym = op.intercept_
centered_x = OnnxSub(X, coefs.astype(dtype), op_version=opv)
coefs = op.coef_.T.astype(dtype)
dot = OnnxMatMul(centered_x, coefs, op_version=opv)
else:
coefs = op.coef_.astype(dtype)
dot = OnnxMatMul(norm_x, coefs, op_version=opv)
ym = op.y_mean_ if hasattr(op, "x_mean_") else op._y_mean

norm_x = OnnxDiv(
OnnxSub(X, coefs.astype(dtype), op_version=opv),
std.astype(dtype),
op_version=opv,
)
if hasattr(op, "set_predict_request"):
# new in 1.3
coefs = op.coef_.T.astype(dtype)
else:
coefs = op.coef_.astype(dtype)
dot = OnnxMatMul(norm_x, coefs, op_version=opv)

pred = OnnxAdd(dot, ym.astype(dtype), op_version=opv, output_names=operator.outputs)
pred.add_to(scope, container)

Expand Down
43 changes: 38 additions & 5 deletions skl2onnx/operator_converters/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,26 @@ def convert_sklearn_linear_classifier(
intercepts = list(map(lambda x: -1 * x, intercepts)) + intercepts

multi_class = 0
use_ovr = False
if hasattr(op, "multi_class"):
if op.multi_class == "ovr":
multi_class = 1
else:
elif number_of_classes > 2:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
multi_class = 2
use_ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
)
else:
# See https://scikit-learn.org/dev/modules/generated/
# sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression
# multi_class attribute is deprecated.
# OVR is not supported anymore.
if number_of_classes > 2:
multi_class = 2

classifier_type = "LinearClassifier"
Expand All @@ -77,11 +93,28 @@ def convert_sklearn_linear_classifier(
):
classifier_attrs["post_transform"] = "NONE"
elif isinstance(op, LogisticRegression):
ovr = op.multi_class in ["ovr", "warn"] or (
op.multi_class == "auto"
and (op.classes_.size <= 2 or op.solver == "liblinear")
# See https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_logistic.py#L1423
classifier_attrs["post_transform"] = (
"LOGISTIC"
if (
use_ovr
or op.solver == "liblinear"
or (
hasattr(op, "multi_class")
and op.multi_class in ("auto", "deprecated")
and (
op.classes_.size <= 2
or op.solver in ("liblinear", "newton-cholesky")
)
)
or (
getattr(op, "multi_class", "auto") in ("auto", "deprecated")
and multi_class == 0
and (len(op.coef_.shape) <= 1 or min(op.coef_.shape) == 1)
)
)
else "SOFTMAX"
)
classifier_attrs["post_transform"] = "LOGISTIC" if ovr else "SOFTMAX"
else:
classifier_attrs["post_transform"] = (
"LOGISTIC" if multi_class > 2 else "SOFTMAX"
Expand Down
10 changes: 8 additions & 2 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class TestInvestigate(unittest.TestCase):
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1053(self):
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_issue_1053(self):
pv.Version(ort_version) < pv.Version("1.16.0"),
reason="opset 19 not implemented",
)
@ignore_warnings(category=(ConvergenceWarning,))
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1055(self):
import numpy as np
from numpy.testing import assert_almost_equal
Expand Down Expand Up @@ -118,6 +119,7 @@ def test_issue_1055(self):
pv.Version(ort_version) < pv.Version("1.17.3"),
reason="opset 19 not implemented",
)
@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1069(self):
import math
from typing import Any
Expand Down Expand Up @@ -246,7 +248,11 @@ def Classifier(features: list[str]) -> base.BaseEstimator:
return classifier

model = Classifier(list(X_train.columns))
model.fit(X_train, y_train)
try:
model.fit(X_train, y_train)
except ValueError as e:
# If this fails, no need to go beyond.
raise unittest.SkipTest(str(e))

sample = X_train[:1].astype(numpy.float32)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_feature_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_feature_union_transformer_weights_2(self):
X_test,
model,
model_onnx,
basename="SklearnFeatureUnionTransformerWeights2-Dec4",
basename="SklearnFeatureUnionTransformerWeights2-Dec3",
)


Expand Down
37 changes: 21 additions & 16 deletions tests/test_sklearn_pipeline_concat_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy
from numpy.testing import assert_almost_equal
from onnxruntime import InferenceSession
from onnxruntime.capi.onnxruntime_pybind11_state import Fail
import pandas
from sklearn import __version__ as sklearn_version

Expand Down Expand Up @@ -35,6 +34,12 @@ def skl12():
return pv.Version(vers) >= pv.Version("1.2")


def skl15():
# pv.Version does not work with development versions
vers = ".".join(sklearn_version.split(".")[:2])
return pv.Version(vers) >= pv.Version("1.5")


class TestSklearnPipelineConcatTfIdf(unittest.TestCase):
words = [
"ability",
Expand Down Expand Up @@ -319,7 +324,7 @@ def get_pipeline(N=10000):

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
@unittest.skipIf(not skl12(), reason="sparse_output")
@unittest.skipIf(not skl15(), reason="no working")
@unittest.skipIf(TARGET_OPSET < 18, reason="too long")
def test_issue_712_svc_binary(self):
pipe, dfx_test = TestSklearnPipelineConcatTfIdf.get_pipeline()
Expand All @@ -341,20 +346,20 @@ def test_issue_712_svc_binary(self):
got = sess.run(None, row_inputs)
assert_almost_equal(expected_dense[i], got[0])

with self.assertRaises(Fail):
# StringNormlizer removes empty strings after normalizer.
# This case happens when a string contains only stopwords.
# Then rows are missing and the output of the StringNormalizer
# and the OneHotEncoder output cannot be merged anymore with
# an error message like the following:
# onnxruntime.capi.onnxruntime_pybind11_state.Fail:
# [ONNXRuntimeError] : 1 : FAIL : Non-zero status code
# returned while running Concat node. Name:'Concat1'
# Status Message: concat.cc:159 onnxruntime::ConcatBase::
# PrepareForCompute Non concat axis dimensions must match:
# Axis 0 has mismatched dimensions of 2106 and 2500.
got = sess.run(None, inputs)
# assert_almost_equal(expected.todense(), got[0])
# It is fixed with scikit-learn==1.5.0.
# StringNormalizer removes empty strings after normalizer.
# This case happens when a string contains only stopwords.
# Then rows are missing and the output of the StringNormalizer
# and the OneHotEncoder output cannot be merged anymore with
# an error message like the following:
# onnxruntime.capi.onnxruntime_pybind11_state.Fail:
# [ONNXRuntimeError] : 1 : FAIL : Non-zero status code
# returned while running Concat node. Name:'Concat1'
# Status Message: concat.cc:159 onnxruntime::ConcatBase::
# PrepareForCompute Non concat axis dimensions must match:
# Axis 0 has mismatched dimensions of 2106 and 2500.
got = sess.run(None, inputs)
assert_almost_equal(expected.todense(), got[0])

@unittest.skipIf(TARGET_OPSET < 11, reason="SequenceConstruct not available")
@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sklearn_pls_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_model_pls_regression(self):
model_onnx,
methods=["predict"],
basename="SklearnPLSRegression",
verbose=10,
verbose=0,
)

def test_model_pls_regression64(self):
Expand Down Expand Up @@ -91,4 +91,4 @@ def test_model_pls_regressionInt64(self):


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
15 changes: 11 additions & 4 deletions tests/test_sklearn_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,12 @@ def test_count_vectorizer_english2(self):
raise AssertionError(
f"{mod1.token_pattern!r} != {mod2.token_pattern!r}"
)
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(f"{mod1.stop_words_} != {mod2.stop_words_}")

if hasattr(mod1, "stop_words_"):
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(
f"{mod1.stop_words_} != {mod2.stop_words_}"
)
if len(mod1.vocabulary_) != len(mod2.vocabulary_):
raise AssertionError(
f"skl_version={skl_version!r}, "
Expand Down Expand Up @@ -228,8 +232,11 @@ def test_tfidf_vectorizer_english2(self):
raise AssertionError(
f"{mod1.token_pattern!r} != {mod2.token_pattern!r}"
)
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(f"{mod1.stop_words_} != {mod2.stop_words_}")
if hasattr(mod1, "stop_words_"):
if len(mod1.stop_words_) != len(mod2.stop_words_):
raise AssertionError(
f"{mod1.stop_words_} != {mod2.stop_words_}"
)
if len(mod1.vocabulary_) != len(mod2.vocabulary_):
raise AssertionError(
f"skl_version={skl_version!r}, "
Expand Down
Loading