Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support fill_value for SimpleImputer with string data #1123

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions skl2onnx/operator_converters/imputer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ def convert_sklearn_imputer(
op_type = "Imputer"
attrs = {"name": scope.get_unique_operator_name(op_type)}
op = operator.raw_operator
if (
hasattr(op, "fill_value")
and isinstance(op.fill_value, str)
and op.fill_value.lower() != "nan"
):
raise RuntimeError(
"Imputer cannot fill missing values with a string '%s'." % op.fill_value
)
if not hasattr(op, "statistics_"):
raise RuntimeError("Member statistics_ is not present, was the model fitted?")

Expand Down Expand Up @@ -86,6 +78,14 @@ def convert_sklearn_imputer(

apply_concat(scope, names, operator.outputs[0].full_name, container, axis=1)
else:
if (
hasattr(op, "fill_value")
and isinstance(op.fill_value, str)
and op.fill_value.lower() != "nan"
):
raise RuntimeError(
"Imputer cannot fill missing values with a string '%s'." % op.fill_value
)
if isinstance(operator.inputs[0].type, Int64TensorType):
attrs["imputed_value_int64s"] = op.statistics_.astype(np.int64)
use_int = True
Expand Down
143 changes: 143 additions & 0 deletions tests/test_sklearn_imputer_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def _check_outputs_ints(self, model, model_onnx, data):
exp = model.transform(data)
assert_almost_equal(res, exp)

def _check_outputs_floats(self, model, model_onnx, data):
    """Compare the ONNX model's first output with ``model.transform``.

    The data is cast to float32 and fed under the input name ``"input"``
    to an inference session running on the CPU execution provider; the
    first returned output must match the scikit-learn transform of the
    original (uncast) data.
    """
    serialized = model_onnx.SerializeToString()
    session = InferenceSession(serialized, providers=["CPUExecutionProvider"])
    feeds = {"input": np.array(data).astype(np.float32)}
    actual = session.run(None, feeds)[0]
    expected = model.transform(data)
    assert_almost_equal(actual, expected)

def _check_outputs_strings(self, model, model_onnx, data, verbose=0):
idata = {"input": np.array(data).astype(np.str_)}
sess = InferenceSession(
Expand Down Expand Up @@ -206,6 +215,140 @@ def test_simple_imputer_string_inputs_int_mostf_default(self):
self.assertEqual(len(model_onnx.graph.output), 1)
self._check_outputs_strings(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
def test_simple_imputer_float_constant_default_fill_value(self):
    # Constant strategy on float input, leaving fill_value unset so that
    # scikit-learn picks its own default.
    data = [[1, 2], [np.nan, 3], [7, 6]]
    model = SimpleImputer(strategy="constant")
    model.fit(data)

    initial_types = [("input", FloatTensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)

    # the converted graph should consist of exactly one node
    self.assertEqual(len(graph.node), 1)

    # ... and expose exactly one output
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_floats(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
def test_simple_imputer_float_constant_provided_fill_value(self):
    # Constant strategy on float input with an explicit numeric fill_value.
    data = [[1, 2], [np.nan, 3], [7, 6]]
    model = SimpleImputer(strategy="constant", fill_value=99.0)
    model.fit(data)

    initial_types = [("input", FloatTensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)

    # the converted graph should consist of exactly one node
    self.assertEqual(len(graph.node), 1)

    # ... and expose exactly one output
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_floats(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
def test_simple_imputer_int_constant_default_fill_value(self):
    # Constant strategy with the ONNX input declared as int64, leaving
    # fill_value unset so that scikit-learn picks its own default.
    # NOTE(review): the data contains NaN while the declared input type is
    # Int64TensorType; how _check_outputs_ints casts it is not visible
    # here -- confirm the NaN entries are handled as intended.
    data = [[1, 2], [np.nan, 3], [7, 6], [8, np.nan]]
    model = SimpleImputer(strategy="constant")
    model.fit(data)

    initial_types = [("input", Int64TensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)

    # the converted graph should consist of exactly one node
    self.assertEqual(len(graph.node), 1)

    # ... and expose exactly one output
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_ints(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
def test_simple_imputer_int_constant_provided_fill_value(self):
    # Constant strategy with the ONNX input declared as int64 and an
    # explicit integer fill_value.
    # NOTE(review): the data contains NaN while the declared input type is
    # Int64TensorType; how _check_outputs_ints casts it is not visible
    # here -- confirm the NaN entries are handled as intended.
    data = [[1, 2], [np.nan, 3], [7, 6], [8, np.nan]]
    model = SimpleImputer(strategy="constant", fill_value=99)
    model.fit(data)

    initial_types = [("input", Int64TensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)

    # the converted graph should consist of exactly one node
    self.assertEqual(len(graph.node), 1)

    # ... and expose exactly one output
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_ints(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
@unittest.skipIf(
    pv.Version(skl_ver) < pv.Version("0.24"),
    reason="SimpleImputer does not support strings",
)
def test_simple_imputer_string_inputs_constant_provided_fill_value(self):
    # String input where the empty string marks a missing value and an
    # explicit string fill_value replaces it.
    rows = [["s1", "s2"], ["s1", "s2"], ["", "s3"], ["s7", "s6"], ["s8", ""]]
    data = pd.DataFrame(rows)
    model = SimpleImputer(
        strategy="constant", missing_values="", fill_value="missing"
    )
    model.fit(data)

    initial_types = [("input", StringTensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )

    # the textual dump of the model must mention the ai.onnx.ml domain
    self.assertIn("ai.onnx.ml", str(model_onnx))
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_strings(model, model_onnx, data)

@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
@unittest.skipIf(
    pv.Version(skl_ver) < pv.Version("0.24"),
    reason="SimpleImputer does not support strings",
)
def test_simple_imputer_string_inputs_constant_default_fill_value(self):
    # String input where the empty string marks a missing value, leaving
    # fill_value unset so that scikit-learn picks its own default.
    rows = [["s1", "s2"], ["s1", "s2"], ["", "s3"], ["s7", "s6"], ["s8", ""]]
    data = pd.DataFrame(rows)
    model = SimpleImputer(strategy="constant", missing_values="")
    model.fit(data)

    initial_types = [("input", StringTensorType([None, 2]))]
    model_onnx = convert_sklearn(
        model,
        name="scikit-learn simple imputer",
        initial_types=initial_types,
        target_opset=TARGET_OPSET,
    )

    # the textual dump of the model must mention the ai.onnx.ml domain
    self.assertIn("ai.onnx.ml", str(model_onnx))
    graph = model_onnx.graph
    self.assertIsNotNone(graph.node)
    self.assertEqual(len(graph.output), 1)
    self._check_outputs_strings(model, model_onnx, data)


if __name__ == "__main__":
unittest.main()
Loading