Skip to content

Commit

Permalink
Avoid concatenation if not needed (#1110)
Browse files Browse the repository at this point in the history
* Avoid concatenation if not needed

Signed-off-by: Xavier Dupre <[email protected]>

* change

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Jun 19, 2024
1 parent 90e3d86 commit 99939ef
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.18.0

* Converter for OneHotEncoder does not add a concat operator if not needed,
[#1110](https://github.com/onnx/sklearn-onnx/pull/1110)
* Function ``to_onnx`` now forces the main opset to be equal to the
value speficied by the user (parameter ``target_opset``),
[#1109](https://github.com/onnx/sklearn-onnx/pull/1109)
Expand Down
7 changes: 5 additions & 2 deletions skl2onnx/operator_converters/one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,11 @@ def convert_sklearn_one_hot_encoder(
result.append(ohe_output)
categories_len += len(categories)

concat_result_name = scope.get_unique_variable_name("concat_result")
apply_concat(scope, result, concat_result_name, container, axis=-1)
if len(result) == 1:
concat_result_name = result[0]
else:
concat_result_name = scope.get_unique_variable_name("concat_result")
apply_concat(scope, result, concat_result_name, container, axis=-1)

reshape_input = concat_result_name
if np.issubdtype(ohe_op.dtype, np.signedinteger):
Expand Down

0 comments on commit 99939ef

Please sign in to comment.