diff --git a/skl2onnx/operator_converters/one_hot_encoder.py b/skl2onnx/operator_converters/one_hot_encoder.py
index f12bbd814..95eb8b7b6 100644
--- a/skl2onnx/operator_converters/one_hot_encoder.py
+++ b/skl2onnx/operator_converters/one_hot_encoder.py
@@ -45,14 +45,28 @@ def convert_sklearn_one_hot_encoder(
         enum_cats = []
         index_inputs = 0
 
+        to_drop = ohe_op._drop_idx_after_grouping
+        if to_drop is not None:
+            # raise NotImplementedError(
+            #     f"The converter is not implemented when "
+            #     f"_drop_idx_after_grouping is not None: {to_drop}."
+            # )
+            pass
         for index, cats in enumerate(ohe_op.categories_):
+            filtered_cats = ohe_op._compute_transformed_categories(index)
+            n_cats = None
+            if ohe_op._infrequent_enabled and "infrequent_sklearn" in filtered_cats:
+                n_cats = len(filtered_cats) - 1
+                cats = np.hstack(
+                    [filtered_cats[:-1], [c for c in cats if c not in filtered_cats]]
+                )
             while sum(all_shapes[: index_inputs + 1]) <= index:
                 index_inputs += 1
             index_in_input = index - sum(all_shapes[:index_inputs])
             inp = operator.inputs[index_inputs]
-            if not isinstance(
+            assert isinstance(
                 inp.type,
                 (
                     Int64TensorType,
@@ -61,30 +75,50 @@
                     FloatTensorType,
                     DoubleTensorType,
                 ),
-            ):
-                raise NotImplementedError(
-                    "{} input datatype not yet supported. "
-                    "You may raise an issue at "
-                    "https://github.com/onnx/sklearn-onnx/issues"
-                    "".format(type(inp.type))
-                )
+            ), (
+                f"{type(inp.type)} input datatype not yet supported. "
+                f"You may raise an issue at "
+                f"https://github.com/onnx/sklearn-onnx/issues"
+            )
             if all_shapes[index_inputs] == 1:
                 assert index_in_input == 0
                 afeat = False
             else:
                 afeat = True
-            enum_cats.append((afeat, index_in_input, inp.full_name, cats, inp.type))
+            enum_cats.append(
+                (afeat, index_in_input, inp.full_name, cats, inp.type, n_cats)
+            )
     else:
         inp = operator.inputs[0]
-        enum_cats = [
-            (True, i, inp.full_name, cats, inp.type)
-            for i, cats in enumerate(ohe_op.categories_)
-        ]
+        assert isinstance(
+            inp.type,
+            (
+                Int64TensorType,
+                StringTensorType,
+                Int32TensorType,
+                FloatTensorType,
+                DoubleTensorType,
+            ),
+        ), (
+            f"{type(inp.type)} input datatype not yet supported. "
+            f"You may raise an issue at "
+            f"https://github.com/onnx/sklearn-onnx/issues"
+        )
+
+        enum_cats = []
+        for index, cats in enumerate(ohe_op.categories_):
+            filtered_cats = ohe_op._compute_transformed_categories(index)
+            if ohe_op._infrequent_enabled and "infrequent_sklearn" in filtered_cats:
+                raise NotImplementedError(
+                    "Infrequent categories are not implemented in this case: "
+                    f"{filtered_cats} != {cats}."
+                )
+            enum_cats.append((True, index, inp.full_name, cats, inp.type, None))
 
     result, categories_len = [], 0
     for index, enum_c in enumerate(enum_cats):
-        afeat, index_in, name, categories, inp_type = enum_c
+        afeat, index_in, name, categories, inp_type, n_cats = enum_c
         container.debug(
             "[conv.OneHotEncoder] cat %r/%r name=%r type=%r",
             index + 1,
@@ -138,8 +172,8 @@ def convert_sklearn_one_hot_encoder(
             attrs["cats_int64s"] = categories.astype(np.int64)
         else:
             raise RuntimeError(
-                "Input type {} is not supported for OneHotEncoder. "
-                "Ideally, it should either be integer or strings.".format(inp_type)
+                f"Input type {inp_type} is not supported for OneHotEncoder. "
+                "Ideally, it should be either integers or strings."
             )
 
         ohe_output = scope.get_unique_variable_name(name + "out")
@@ -155,11 +189,15 @@ def convert_sklearn_one_hot_encoder(
         container.add_node(
             "OneHotEncoder", name, ohe_output, op_domain="ai.onnx.ml", **attrs
         )
+
         if (
             hasattr(ohe_op, "drop_idx_")
            and ohe_op.drop_idx_ is not None
            and ohe_op.drop_idx_[index] is not None
        ):
+            assert (
+                n_cats is None
+            ), "drop_idx_ is not implemented when there are infrequent categories"
             extracted_outputs_name = scope.get_unique_variable_name("extracted_outputs")
             indices_to_keep_name = scope.get_unique_variable_name("indices_to_keep")
             indices_to_keep = np.delete(
@@ -179,6 +217,30 @@ def convert_sklearn_one_hot_encoder(
                 name=scope.get_unique_operator_name("Gather"),
             )
             ohe_output, categories = extracted_outputs_name, indices_to_keep
+        elif n_cats is not None:
+            split_name = scope.get_unique_variable_name("split_name")
+            container.add_initializer(
+                split_name,
+                onnx_proto.TensorProto.INT64,
+                [2],
+                [n_cats, len(categories) - n_cats],
+            )
+
+            spl1 = scope.get_unique_variable_name("split1")
+            spl2 = scope.get_unique_variable_name("split2")
+
+            # sum every count after n_cats: the trailing columns correspond
+            # to the infrequent categories and are merged into a single column
+            container.add_node("Split", [ohe_output, split_name], [spl1, spl2], axis=-1)
+            axis_name = scope.get_unique_variable_name("axis_name")
+            container.add_initializer(
+                axis_name, onnx_proto.TensorProto.INT64, [1], [-1]
+            )
+            red_name = scope.get_unique_variable_name("red_name")
+            container.add_node("ReduceSum", [spl2, axis_name], red_name, keepdims=1)
+            conc_name = scope.get_unique_variable_name("conc_name")
+            container.add_node("Concat", [spl1, red_name], conc_name, axis=-1)
+            ohe_output = conc_name
+            categories = categories[: n_cats + 1]
 
         result.append(ohe_output)
         categories_len += len(categories)
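# --- Reviewer note, not part of the patch ---
# A minimal numpy sketch of what the added Split / ReduceSum / Concat nodes
# compute: the OneHotEncoder node receives all raw categories (frequent ones
# first, in the order built from filtered_cats), then the trailing columns,
# which correspond to the infrequent categories, are summed into one column.
# The helper name merge_infrequent is hypothetical; it only mirrors the
# ONNX subgraph for 2D inputs.
import numpy as np


def merge_infrequent(one_hot: np.ndarray, n_cats: int) -> np.ndarray:
    # Split at n_cats along the last axis, then reduce the infrequent block.
    frequent, infrequent = one_hot[:, :n_cats], one_hot[:, n_cats:]
    return np.concatenate([frequent, infrequent.sum(axis=1, keepdims=True)], axis=1)


# 4 raw categories, the last 2 of them infrequent (n_cats = 2):
oh = np.eye(4, dtype=np.float32)
print(merge_infrequent(oh, 2))
# expected: the last two categories map to the same (third) output column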
diff --git a/tests/test_sklearn_one_hot_encoder_converter.py b/tests/test_sklearn_one_hot_encoder_converter.py
index 5073c4c1d..ec705fb9d 100644
--- a/tests/test_sklearn_one_hot_encoder_converter.py
+++ b/tests/test_sklearn_one_hot_encoder_converter.py
@@ -21,13 +21,14 @@
 from sklearn.compose import ColumnTransformer
 from sklearn.impute import SimpleImputer
 from sklearn.linear_model import LinearRegression
-from skl2onnx import convert_sklearn
+from skl2onnx import convert_sklearn, to_onnx
 from skl2onnx.common.data_types import (
     Int32TensorType,
     Int64TensorType,
     StringTensorType,
     FloatTensorType,
 )
+from skl2onnx.algebra.type_helper import guess_initial_types
 
 try:
     # scikit-learn >= 0.22
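# --- Reviewer note, not part of the patch ---
# The new test relies on guess_initial_types to build one (name, type) pair
# per DataFrame column with shape [None, 1]. A small hedged sketch of that
# usage; the types shown in the trailing comment are an expectation based on
# the column dtypes, not output copied from skl2onnx.
import numpy
import pandas
from skl2onnx.algebra.type_helper import guess_initial_types

df = pandas.DataFrame(
    dict(
        CAT1=["aa", "ab"],
        num1=numpy.array([0.5, 0.4], dtype=numpy.float32),
    )
)
init = guess_initial_types(df, None)
print(init)
# expected: [('CAT1', StringTensorType(shape=[None, 1])),
#            ('num1', FloatTensorType(shape=[None, 1]))]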
@@ -471,6 +472,71 @@ def test_shape_inference_onnx(self):
     def test_shape_inference_onnxruntime(self):
         self._shape_inference("onnxruntime")
 
+    def test_min_frequency(self):
+        data = pandas.DataFrame(
+            [
+                dict(CAT1="aa", CAT2="ba", num1=0.5, num2=0.6, y=0),
+                dict(CAT1="ab", CAT2="bb", num1=0.4, num2=0.8, y=1),
+                dict(CAT1="ac", CAT2="bb", num1=0.4, num2=0.8, y=1),
+                dict(CAT1="ab", CAT2="bc", num1=0.5, num2=0.56, y=0),
+                dict(CAT1="ab", CAT2="bd", num1=0.55, num2=0.56, y=1),
+                dict(CAT1="ab", CAT2="bd", num1=0.35, num2=0.86, y=0),
+                dict(CAT1="ab", CAT2="bd", num1=0.5, num2=0.68, y=1),
+            ]
+        )
+        cat_cols = ["CAT1", "CAT2"]
+        train_data = data.drop("y", axis=1)
+        for c in train_data.columns:
+            if c not in cat_cols:
+                train_data[c] = train_data[c].astype(numpy.float32)
+
+        pipe = Pipeline(
+            [
+                (
+                    "preprocess",
+                    ColumnTransformer(
+                        transformers=[
+                            (
+                                "cat",
+                                Pipeline(
+                                    [
+                                        (
+                                            "onehot",
+                                            OneHotEncoder(
+                                                min_frequency=2,
+                                                sparse_output=False,
+                                                handle_unknown="ignore",
+                                            ),
+                                        )
+                                    ]
+                                ),
+                                cat_cols,
+                            )
+                        ],
+                        remainder="passthrough",
+                    ),
+                ),
+            ]
+        )
+        pipe.fit(train_data, data["y"])
+
+        init = guess_initial_types(train_data, None)
+        self.assertEqual([i[0] for i in init], "CAT1 CAT2 num1 num2".split())
+        for t in init:
+            self.assertEqual(t[1].shape, [None, 1])
+        onx2 = to_onnx(pipe, initial_types=init)
+        sess2 = InferenceSession(
+            onx2.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+
+        inputs = {c: train_data[c].values.reshape((-1, 1)) for c in train_data.columns}
+        got2 = sess2.run(None, inputs)
+
+        expected = pipe.transform(train_data)
+        assert_almost_equal(expected, got2[0])
+
     @unittest.skipIf(
         not one_hot_encoder_supports_drop(),
         reason="OneHotEncoder does not support drop in scikit versions < 0.21",
     )
@@ -510,5 +576,6 @@ def test_one_hot_encoder_drop_if_binary(self):
     for name in ["skl2onnx"]:
         log = logging.getLogger(name)
         log.setLevel(logging.ERROR)
-    TestSklearnOneHotEncoderConverter().test_one_hot_encoder_drop_if_binary()
+    # TestSklearnOneHotEncoderConverter().test_min_frequency()
+    # TestSklearnOneHotEncoderConverter().test_one_hot_encoder_drop_if_binary()
     unittest.main(verbosity=2)
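# --- Reviewer note, not part of the patch ---
# A hedged sketch of the scikit-learn behaviour the converter has to mirror:
# with min_frequency set, OneHotEncoder groups rare categories into a single
# "infrequent" output column (exposed through infrequent_categories_), which
# is the column the new Split/ReduceSum/Concat nodes reproduce on the ONNX
# side. Values in the trailing comments are expectations, not captured output.
import numpy as np
from sklearn.preprocessing import OneHotEncoder

X = np.array([["aa"], ["ab"], ["ab"], ["ab"], ["ac"], ["ad"]])
enc = OneHotEncoder(min_frequency=2, sparse_output=False, handle_unknown="ignore")
enc.fit(X)
print(enc.categories_)             # all observed categories
print(enc.infrequent_categories_)  # expected: [array(['aa', 'ac', 'ad'], ...)]
print(enc.transform(X))            # expected: one column for 'ab' + one merged column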