\subsection{Linear regression}
\label{sec:linreg}
We focus on performing linear regression between two distinct data owners, one providing the explanatory variables $X$ and the other providing the target (or response) variable $y$. For simplicity, we assume that $X$ is an $n \times p$ matrix with $n > p$ and that $y$ is a column vector of length $n$; we write $X_i$ and $y_i$ for their $i$-th rows, respectively.
Solving the classical linear regression problem equates to finding the minimizer of the mean squared error objective
\begin{align*}
\optCoeff = \operatorname*{arg\,min}_{\coeff}\instanceAvg (X_i \cdot \coeff - y_i)^2
\end{align*}
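Assuming $X^TX$ is invertible, this problem has a closed-form solution, which coincides with the maximum likelihood estimate under Gaussian noise~\cite{mml}. We can therefore reformulate the optimization problem as the product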
\begin{align*}
\optCoeff &= (X^TX)^{-1}X^Ty
\\
&= Z \cdot y
\end{align*}
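For completeness, the closed form can be recovered by setting the gradient of the objective to zero. The short derivation below is only a sketch: it assumes $X$ has full column rank (so that $X^TX$ is invertible) and writes the objective in matrix form using the Euclidean norm $\lVert\cdot\rVert_2$, a notation introduced here purely for illustration.
\begin{align*}
\nabla_{\coeff}\, \frac{1}{n}\lVert X\coeff - y\rVert_2^2 &= \frac{2}{n}\, X^T(X\coeff - y) = 0
\\
X^TX\,\coeff &= X^Ty
\\
\optCoeff &= (X^TX)^{-1}X^Ty
\end{align*}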
Note that evaluating $Z = (X^TX)^{-1}X^T$ can be performed as local preprocessing on plaintext data by the provider of $X$ (since this term does not involve the target variable $y$). Thus, solving this regression problem securely reduces to a single secret-shared matrix-vector product $\coeff = Z \cdot y$.
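As a toy illustration of this split (the numbers are chosen here for exposition and do not come from any dataset), take $n = 3$ and $p = 1$, with $X = (1, 2, 3)^T$ held by one provider and $y = (2, 4, 6)^T$ held by the other. The provider of $X$ can compute $Z$ locally, and only the final product involves $y$:
\begin{align*}
X^TX &= 1^2 + 2^2 + 3^2 = 14
\\
Z &= (X^TX)^{-1}X^T = \tfrac{1}{14}\,(1, 2, 3)
\\
\coeff &= Z \cdot y = \tfrac{1}{14}\,(2 + 8 + 18) = 2
\end{align*}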
Additionally, we may want to compute certain metrics to assess the regression
outcome. These metrics must also be evaluated on secret-shared data. Concretely,
we present cryptographic protocols to support the mean squared error
($\mathsf{MSE}$), mean absolute percentage error
($\mathsf{MAPE}$), and R-squared ($R^2$) metrics. Denote the predicted targets by
$\ypred = X \cdot \coeff$ and the mean target value by $\bar{y} = \instanceAvg y_i$. An
argument similar to the above reduces the secure computation of $R^2$ to
securely computing the residual sum of squares ($\mathsf{RSS}$):
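\begin{align*}
R^2 &= 1 - \frac{\sum_{i=1}^{n}(y_i - \ypred_i)^2}{\sum_{i=1}^{n}(y_i - \bar{y})^2}
\\
&= 1 - \frac{\mathsf{RSS}(\ypred, y)}{\mathsf{SS}(y)}
\end{align*}
since only the $\mathsf{RSS}$ term depends on data from both providers (i.e., both $X$ and $y$); the total sum of squares $\mathsf{SS}(y)$ can be computed locally by the provider of $y$.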
In summary, we are ultimately concerned with providing protocols for securely evaluating the following values:
\begin{align*}
\coeff &= Z \cdot y
\\
\ypred &= X \cdot \coeff
\\
\mathsf{MSE}(\ypred, y) &= \instanceAvg (\ypred_i - y_i)^2
\\
\mathsf{RSS}(\ypred, y) &= \sum_{i=1}^{n} (\ypred_i - y_i)^2
\\
\mathsf{MAPE}(\ypred, y) &= \frac{1}{n} \sum_{i=1}^{n} \Big| \frac{\ypred_i - y_i}{y_i} \Big|
\end{align*}
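To make these definitions concrete, the following worked example evaluates each metric in the clear on hypothetical values $y = (1, 2, 4)^T$ and $\ypred = (2, 2, 3)^T$ with $n = 3$. These numbers are chosen for illustration and are not the output of an actual regression. Note that $\mathsf{MSE}$ equals $\mathsf{RSS}$ divided by $n$, and that $\mathsf{MAPE}$ is only defined when every $y_i$ is nonzero.
\begin{align*}
\mathsf{RSS}(\ypred, y) &= (2-1)^2 + (2-2)^2 + (3-4)^2 = 2
\\
\mathsf{MSE}(\ypred, y) &= \tfrac{1}{3} \cdot 2 = \tfrac{2}{3}
\\
\mathsf{MAPE}(\ypred, y) &= \tfrac{1}{3}\Big( \tfrac{1}{1} + \tfrac{0}{2} + \tfrac{1}{4} \Big) = \tfrac{5}{12}
\\
\bar{y} &= \tfrac{7}{3}, \qquad \mathsf{SS}(y) = \tfrac{16}{9} + \tfrac{1}{9} + \tfrac{25}{9} = \tfrac{14}{3}
\\
R^2 &= 1 - \frac{2}{14/3} = \tfrac{4}{7}
\end{align*}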
%In order to compute the MAPE (mean absolute percentage error) metric which boils down to compute the absolute value of a secret $\share{|\vx|}$.
% \definecolor{codegreen}{rgb}{0,0.6,0}
% \definecolor{codegray}{rgb}{0.5,0.5,0.5}
% \definecolor{codepurple}{rgb}{0.58,0,0.82}
% \definecolor{backcolour}{rgb}{0.95,0.95,0.92}
% \lstdefinestyle{mystyle}{
% backgroundcolor=\color{backcolour},
% commentstyle=\color{codegreen},
% keywordstyle=\color{magenta},
% numberstyle=\tiny\color{codegray},
% stringstyle=\color{codepurple},
% basicstyle=\ttfamily\footnotesize,
% breakatwhitespace=false,
% breaklines=true,
% captionpos=b,
% keepspaces=true,
% numbers=left,
% numbersep=5pt,
% showspaces=false,
% tabsize=2
% }
% \lstset{style=mystyle}
% \begin{lstlisting}[language=Python]
% @edsl.computation
% def lin_reg():
% with x_owner:
% X = edsl.load("X")
% bias = edsl.ones(edsl.slice(edsl.shape(X), begin=0, end=1))
% reshaped_bias = edsl.expand_dims(bias, 1)
% X_b = edsl.concatenate([reshaped_bias, X], axis=1)
% A = edsl.inverse(edsl.dot(edsl.transpose(X_b), X_b))
% B = edsl.dot(A, edsl.transpose(X_b))
% with y_owner:
% y_true = edsl.load("y")
% totals_ss = ss_tot(y_true)
% with replicated:
% w = edsl.dot(B, y_true)
% y_pred = edsl.dot(X_b, w)
% mse_result = mse(y_pred, y_true)
% residuals_ss = ss_res(y_pred, y_true)
% with x_owner:
% rsquared_result = r_squared(residuals_ss, totals_ss)
% w = edsl.identity(w)
% mse_result = edsl.identity(mse_result)
% residuals_ss = edsl.identity(residuals_ss)
% \end{lstlisting}
For compatibility with our protocols, all operations will be implemented for fixed-point numbers as described in Section~\ref{sec:fixed-point}.