From 7705a4bf62d2affa7d7d4de03f4ac9b33d98dbfe Mon Sep 17 00:00:00 2001 From: Liberty Askew Date: Tue, 27 Feb 2024 10:34:22 +0000 Subject: [PATCH 1/2] Fix for gaussian process --- .../operator_converters/gaussian_process.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/skl2onnx/operator_converters/gaussian_process.py b/skl2onnx/operator_converters/gaussian_process.py index 886eac706..ba3279c3f 100644 --- a/skl2onnx/operator_converters/gaussian_process.py +++ b/skl2onnx/operator_converters/gaussian_process.py @@ -126,12 +126,12 @@ def convert_gaussian_process_regressor( if len(mean_y.shape) == 1: mean_y = mean_y.reshape(mean_y.shape + (1,)) - if not hasattr(op, "_y_train_std") or op._y_train_std == 1: + if not hasattr(op, "_y_train_std") or np.all(op._y_train_std == 1): if isinstance(y_mean_b, (np.float32, np.float64)): y_mean_b = np.array([y_mean_b]) if isinstance(mean_y, (np.float32, np.float64)): mean_y = np.array([mean_y]) - y_mean = OnnxAdd(y_mean_b, mean_y, op_version=opv) + y_mean = OnnxAdd(y_mean_b, mean_y.T, op_version=opv) else: # A bug was fixed in 0.23 and it changed # the predictions when return_std is True. @@ -145,13 +145,13 @@ def convert_gaussian_process_regressor( if isinstance(mean_y, (np.float32, np.float64)): mean_y = np.array([mean_y]) y_mean = OnnxAdd( - OnnxMul(y_mean_b, var_y, op_version=opv), mean_y, op_version=opv + OnnxMul(y_mean_b, var_y.T, op_version=opv), mean_y.T, op_version=opv ) y_mean.set_onnx_name_prefix("gpr") y_mean_reshaped = OnnxReshapeApi13( y_mean, - np.array([-1, 1], dtype=np.int64), + np.array([-1, mean_y.shape[0]], dtype=np.int64), op_version=opv, output_names=out[:1], ) @@ -192,12 +192,13 @@ def convert_gaussian_process_regressor( # y_var[y_var_negative] = 0.0 ys0_var = OnnxMax(ys_var, np.array([0], dtype=dtype), op_version=opv) - if hasattr(op, "_y_train_std") and op._y_train_std != 1: + if hasattr(op, "_y_train_std"): # y_var = y_var * self._y_train_std**2 - ys0_var = OnnxMul(ys0_var, var_y**2, op_version=opv) + ys0_var = OnnxMul(var_y**2, ys0_var, op_version=opv) # var = np.sqrt(ys0_var) - var = OnnxSqrt(ys0_var, output_names=out[1:], op_version=opv) + var = OnnxSqrt(ys0_var, op_version=opv) + var = OnnxTranspose(var, output_names=out[1:], op_version=opv) var.set_onnx_name_prefix("gprv") outputs.append(var) @@ -413,4 +414,4 @@ def convert_gaussian_process_classifier( "output_class_labels": [False, True], "zipmap": [False, True], }, - ) + ) \ No newline at end of file From 2e82da5673a1890f485d5ff54cf5a878ebc6cd46 Mon Sep 17 00:00:00 2001 From: dreivmeister Date: Sat, 18 May 2024 13:17:59 +0200 Subject: [PATCH 2/2] add test for issue 1073 --- .../test_sklearn_gaussian_process_regressor.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_sklearn_gaussian_process_regressor.py b/tests/test_sklearn_gaussian_process_regressor.py index 1c1165d32..bc715ef55 100644 --- a/tests/test_sklearn_gaussian_process_regressor.py +++ b/tests/test_sklearn_gaussian_process_regressor.py @@ -1495,6 +1495,23 @@ def test_kernel_white_kernel(self): m1 = res m2 = ker(x, x) assert_almost_equal(m2, m1, decimal=5) + + def test_issue_1073(self): + # multioutput gpr + n_samples, n_features, n_targets = 1000, 8, 3 + X, y = make_regression(n_samples, n_features, n_targets=n_targets) + tx1, vx1, ty1, vy1 = train_test_split(X, y) + model = GaussianProcessRegressor() + model.fit(tx1, ty1) + initial_type = [("data_in", DoubleTensorType([None, X.shape[1]]))] + onx = to_onnx(model, initial_types=initial_type, target_opset=_TARGET_OPSET_) + sess = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + pred = sess.run(None, {"data_in": vx1.astype(np.float64)}) + assert_almost_equal( + model.predict(vx1.astype(np.float64)).ravel(), pred[0].ravel() + ) if __name__ == "__main__": @@ -1503,4 +1520,5 @@ def test_kernel_white_kernel(self): # log.setLevel(logging.DEBUG) # logging.basicConfig(level=logging.DEBUG) # TestSklearnGaussianProcessRegressor().test_kernel_white_kernel() + #TestSklearnGaussianProcessRegressor().test_issue_1073() unittest.main(verbosity=2)