-
Notifications
You must be signed in to change notification settings - Fork 2
/
gaussproc.py
319 lines (262 loc) · 11.2 KB
/
gaussproc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import numpy as np
from functools import partial
from typing import Union, Dict, Callable, Optional, Tuple
import jax
import jaxopt
import jax.numpy as jnp
import jax.scipy as jsc
from jax import vmap, jit
jax.config.update("jax_enable_x64", True)
import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
MCMC,
NUTS,
init_to_feasible,
init_to_median,
init_to_sample,
init_to_uniform,
init_to_value,
)
from numpyro.handlers import seed, trace
numpyro.util.enable_x64()
def _jax_version_tuple(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric components of a version string, e.g. '0.4.26.dev1' -> (0, 4, 26).

    Comparing tuples of ints avoids the lexicographic pitfall of comparing
    version strings directly ('0.10.0' < '0.2.26' is True as strings).
    """
    import re
    return tuple(int(part) for part in re.findall(r"\d+", version)[:3])


def _select_clear_cache() -> Callable[[], None]:
    """Return a zero-argument callable that clears JAX's compilation cache.

    Tries the private per-version cache first (original behaviour), then the
    public ``jax.clear_caches`` available in newer JAX releases, and finally a
    no-op so that importing this module never fails just because the private
    JAX internals moved.
    """
    try:
        if _jax_version_tuple(jax.__version__) < (0, 2, 26):
            return jax.interpreters.xla._xla_callable.cache_clear
        return jax._src.dispatch._xla_callable.cache_clear
    except AttributeError:
        # Private `_xla_callable` was removed in later JAX versions.
        pass
    if hasattr(jax, "clear_caches"):
        return jax.clear_caches
    # Best effort: cache clearing is only hygiene, not required for correctness.
    return lambda: None


clear_cache = _select_clear_cache()
@jit
def _sqrt(x, eps=1e-12):
    """Square root of ``x`` shifted by a small ``eps``.

    The shift keeps the value and its gradient finite at x == 0, where
    d/dx sqrt(x) would otherwise be infinite.
    """
    shifted = x + eps
    return jnp.sqrt(shifted)
@jit
def square_scaled_distance(X, Z, lengthscale=1.):
    r"""
    Computes the squared scaled distance :math:`\|\frac{x_i - z_j}{l}\|^2`
    for every pair of rows :math:`x_i` of X and :math:`z_j` of Z.

    Args:
        X: 2D array with :math:`n \times num\_features` dimensions
        Z: 2D array with :math:`m \times num\_features` dimensions
        lengthscale: scalar, or an array broadcastable against the rows of
            X and Z (e.g. one lengthscale per feature)

    Returns:
        :math:`n \times m` matrix of squared scaled distances
    """
    scaled_X = X / lengthscale
    scaled_Z = Z / lengthscale
    X2 = (scaled_X ** 2).sum(1, keepdims=True)
    Z2 = (scaled_Z ** 2).sum(1, keepdims=True)
    XZ = jnp.matmul(scaled_X, scaled_Z.T)
    # ||x - z||^2 = ||x||^2 - 2 x.z + ||z||^2
    r2 = X2 - 2 * XZ + Z2.T
    # Cancellation in the expansion above can produce tiny negative values.
    return r2.clip(0)
@jit
def kernel_RBF(X: jnp.ndarray,
               Z: jnp.ndarray,
               params: Dict[str, jnp.ndarray],
               noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Radial basis function (squared-exponential) kernel.

    k(x, z) = k_scale * exp(-0.5 * ||(x - z) / k_length||^2)

    Args:
        X, Z: 2D inputs with n x num_features and m x num_features dimensions
        params: dict with entries "k_length" and "k_scale"
        noise: value added to the diagonal (observation noise)
        jitter: small value added to the diagonal for numerical stability

    Note: the diagonal term (noise + jitter) is added whenever X and Z have
    the same shape, even if they hold different points.
    """
    scaled_sq_dist = square_scaled_distance(X, Z, params["k_length"])
    gram = params["k_scale"] * jnp.exp(-0.5 * scaled_sq_dist)
    same_shape = X.shape == Z.shape
    if same_shape:
        gram = gram + (noise + jitter) * jnp.eye(X.shape[0])
    return gram
def kernel_Matern12(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=1/2 kernel (exponential decay).

    k(x, z) = k_scale * exp(-r),  r = ||(x - z) / k_length||

    Args:
        X, Z: 2D inputs with n x num_features and m x num_features dimensions
        params: dict with entries "k_length" and "k_scale"
        noise: value added to the diagonal (observation noise)
        jitter: small value added to the diagonal for numerical stability

    Note: the diagonal term (noise + jitter) is added whenever X and Z have
    the same shape.
    """
    r_scaled = _sqrt(square_scaled_distance(X, Z, params["k_length"]))
    gram = params["k_scale"] * jnp.exp(-r_scaled)
    if X.shape != Z.shape:
        return gram
    return gram + (noise + jitter) * jnp.eye(X.shape[0])
########### Kernels #############
def kernel_Matern32(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=3/2 kernel.

    k(x, z) = k_scale * (1 + sqrt(3) r) * exp(-sqrt(3) r),  r = ||(x - z) / k_length||

    Args:
        X, Z: 2D inputs with n x num_features and m x num_features dimensions
        params: dict with entries "k_length" and "k_scale"
        noise: value added to the diagonal (observation noise)
        jitter: small value added to the diagonal for numerical stability

    Note: the diagonal term (noise + jitter) is added whenever X and Z have
    the same shape.
    """
    r_scaled = _sqrt(square_scaled_distance(X, Z, params["k_length"]))
    arg = 3**0.5 * r_scaled
    gram = params["k_scale"] * (1.0 + arg) * jnp.exp(-arg)
    if X.shape != Z.shape:
        return gram
    return gram + (noise + jitter) * jnp.eye(X.shape[0])
def kernel_Matern52(X: jnp.ndarray,
                    Z: jnp.ndarray,
                    params: Dict[str, jnp.ndarray],
                    noise: float = 0.0, jitter: float = 1.0e-6) -> jnp.ndarray:
    """
    Matern nu=5/2 kernel.

    k(x, z) = k_scale * (1 + sqrt(5) r + 5 r^2 / 3) * exp(-sqrt(5) r),
    with r = ||(x - z) / k_length||

    Args:
        X, Z: 2D inputs with n x num_features and m x num_features dimensions
        params: dict with entries "k_length" and "k_scale"
        noise: value added to the diagonal (observation noise)
        jitter: small value added to the diagonal for numerical stability

    Note: the diagonal term (noise + jitter) is added whenever X and Z have
    the same shape.
    """
    r_scaled = _sqrt(square_scaled_distance(X, Z, params["k_length"]))
    arg = 5**0.5 * r_scaled
    polynomial = 1.0 + arg + arg**2 / 3.0
    gram = params["k_scale"] * polynomial * jnp.exp(-arg)
    if X.shape != Z.shape:
        return gram
    return gram + (noise + jitter) * jnp.eye(X.shape[0])
########### Class #############
class GaussProc:
    """
    Gaussian process class

    Follows C. E. Rasmussen & C. K. I. Williams,
    Gaussian Processes for Machine Learning, the MIT Press, 2006, Alg. 2.1;
    hyperparameters are inferred with MCMC (NUTS via numpyro).

    Args:
        kernel: GP kernel, called as kernel(X, Z, params, noise[, jitter])
        mean_fn: optional deterministic mean function (use 'mean_fn_prior'
            to make it probabilistic)
        kernel_prior: optional custom priors over kernel hyperparameters; a
            callable returning the dict of parameters the kernel reads.
            When omitted, LogNormal(0, 1) priors are placed on "k_length"
            and "k_scale".
        mean_fn_prior: optional priors over mean function parameters
        noise_prior: optional custom prior for observation noise; a callable
            returning the noise value. predict() and get_marginal_logprob()
            read the posterior samples under the key "noise", so the sample
            site should carry that name. When omitted, a LogNormal(0, 1)
            prior named "noise" is used.
    """

    def __init__(self, kernel: Callable[[jnp.ndarray,
                                         jnp.ndarray,
                                         Dict[str, jnp.ndarray],
                                         float, float], jnp.ndarray],
                 mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 noise_prior: Optional[Callable[[], jnp.ndarray]] = None
                 ) -> None:
        # Clear JAX's compilation cache. NOTE(review): presumably to avoid
        # accumulating compiled code, since the jitted methods below treat
        # `self` as a static argument (one compilation per instance).
        clear_cache()
        self.kernel = kernel
        self.mean_fn = mean_fn
        self.kernel_prior = kernel_prior
        self.mean_fn_prior = mean_fn_prior
        self.noise_prior = noise_prior
        self.mcmc = None  # set by fit()

    @staticmethod
    def _sample_default_kernel_params() -> Dict[str, jnp.ndarray]:
        """Sample "k_length" and "k_scale" from LogNormal(0, 1) priors."""
        return {
            "k_length": numpyro.sample("k_length", dist.LogNormal(0.0, 1.0)),
            "k_scale": numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0)),
        }

    @staticmethod
    def _sample_default_noise() -> jnp.ndarray:
        """Sample the observation noise "noise" from a LogNormal(0, 1) prior."""
        return numpyro.sample("noise", dist.LogNormal(0.0, 1.0))

    def model(self, X: jnp.ndarray, y: jnp.ndarray) -> None:
        """GP probabilistic model.

        y ~ MultivariateNormal(mean_fn(X), kernel(X, X, kernel_params, noise)).
        The kernels in this module add (noise + jitter) on the diagonal of K(X, X).
        """
        # Mean function defaults to zero
        f_loc = jnp.zeros(X.shape[0])
        # Sample kernel parameters and noise (fall back to default priors)
        if self.kernel_prior is not None:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_default_kernel_params()
        if self.noise_prior is not None:
            noise = self.noise_prior()
        else:
            noise = self._sample_default_noise()
        # Add mean function (if any); here it receives the mean_fn_prior() output
        if self.mean_fn is not None:
            args = [X]
            if self.mean_fn_prior is not None:
                args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*args).squeeze()
        # compute GP K(X,X)
        K = self.kernel(X, X, kernel_params, noise)
        # sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=K),
            obs=y,
        )

    def fit(self, rng_key: jnp.ndarray, X_train: jnp.ndarray, y_train: jnp.ndarray,
            num_warmup: int = 1_000,
            num_samples: int = 1_000,
            num_chains: int = 1,
            chain_method: str = 'vectorized',
            dense_mass: bool = True,
            progress_bar: bool = False,
            print_summary: bool = True
            ) -> None:
        """
        Fit GP kernel parameters using MCMC NUTS; stores the sampler in self.mcmc.

        Args:
            rng_key: random number generator key
            X_train: 2D 'feature vector' with :math:`n x num_features` dimensions
            y_train: 1D 'target vector' with :math:`(n,)` dimensions
            num_warmup: number of MCMC warmup states
            num_samples: number of MCMC samples
            num_chains: number of MCMC chains
            chain_method: 'sequential', 'parallel' or 'vectorized'
            dense_mass: diagonal HMC mass matrix or full dense (optimized during warmup)
            progress_bar: show progress bar
            print_summary: print summary at the end of sampling
        """
        init_strategy = init_to_median(num_samples=100)
        kernel_nuts = NUTS(self.model, init_strategy=init_strategy, dense_mass=dense_mass)
        self.mcmc = MCMC(
            kernel_nuts,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method=chain_method,
            progress_bar=progress_bar
        )
        self.mcmc.run(rng_key, X_train, y_train)
        if print_summary:
            self.mcmc.print_summary()

    def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]:
        """Get posterior samples (after running the MCMC chains).

        Args:
            chain_dim: if True, keep a leading chain dimension (group_by_chain)
        """
        return self.mcmc.get_samples(group_by_chain=chain_dim)

    def get_marginal_logprob(self, X_train: jnp.ndarray,
                             y_train: jnp.ndarray,
                             noise: Optional[float] = None):
        """
        Log marginal likelihood log p(y | X) evaluated at the posterior-mean
        hyperparameters (R&W Alg. 2.1):

            -1/2 y^T K^{-1} y - sum_i log L_ii - n/2 log(2 pi),  K = L L^T

        Args:
            X_train: 2D training inputs
            y_train: 1D training targets (not modified)
            noise: observation noise; defaults to the posterior mean of "noise"

        Returns:
            scalar log marginal likelihood
        """
        n_train = y_train.shape[0]
        samples = self.get_samples(chain_dim=False)
        # Average over the sample axis only so vector-valued parameters
        # (e.g. one lengthscale per feature) keep their shape.
        params = {name: np.mean(value, axis=0) for name, value in samples.items()}
        # Out-of-place subtraction: an in-place `-=` would mutate the
        # caller's y_train when it is a numpy array.
        y_residual = y_train
        if self.mean_fn is not None:
            args = [X_train, params] if self.mean_fn_prior else [X_train]
            y_residual = y_train - self.mean_fn(*args).squeeze()
        noise = noise if noise is not None else params["noise"]
        k_XX = self.kernel(X_train, X_train, params, noise)
        chol_XX = jsc.linalg.cholesky(k_XX, lower=True)
        v = jsc.linalg.solve_triangular(chol_XX, y_residual, lower=True)
        data_fit = 0.5 * jnp.dot(v, v)
        # log|K| = 2 * sum(log diag(L)), hence -1/2 log|K| = -sum(log diag(L))
        half_log_det = jnp.sum(jnp.log(jnp.diag(chol_XX)))
        return -data_fit - half_log_det - 0.5 * n_train * jnp.log(2.0 * jnp.pi)

    @partial(jit, static_argnums=(0,))
    def get_mvn_posterior_cholesky(self,
                                   rng_key: jnp.ndarray,
                                   X_train: jnp.ndarray, y_train: jnp.ndarray,
                                   X_new: jnp.ndarray,
                                   params: Dict[str, jnp.ndarray],
                                   noise: float = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Mean and standard deviation of the GP posterior at X_new for a single
        sample of GP hyperparameters, using a Cholesky decomposition.

        The predictive covariance is built from kernel(X_new, X_new) without
        noise, i.e. it describes the latent function. rng_key is currently unused.

        Returns:
            (mean, sigma), each of shape (X_new.shape[0],)
        """
        y_residual = y_train
        if self.mean_fn is not None:
            args = [X_train, params] if self.mean_fn_prior else [X_train]
            y_residual = y_train - self.mean_fn(*args).squeeze()
        # compute kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, jitter=0.0)
        k_pX = self.kernel(X_new, X_train, params, jitter=0.0)
        k_XX = self.kernel(X_train, X_train, params, noise)
        # alpha = K_XX^{-1} y via two triangular solves with the lower factor L
        chol_XX = jax.lax.linalg.cholesky(k_XX)
        kinv_XX_y = jax.lax.linalg.triangular_solve(
            chol_XX.T,
            jax.lax.linalg.triangular_solve(chol_XX, y_residual, lower=True, left_side=True),
            left_side=True)
        # nb. k_pX = K(X_new, X_train); some references use its transpose
        mean = jnp.matmul(k_pX, kinv_XX_y)
        if self.mean_fn is not None:
            args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*args).squeeze()
        # cov = K_pp - K_pX K_XX^{-1} K_Xp = K_pp - v^T v with v = L^{-1} K_Xp
        v = jax.lax.linalg.triangular_solve(chol_XX, k_pX.T, lower=True, left_side=True)
        cov = k_pp - jnp.dot(v.T, v)
        # Round-off can make tiny variances negative; floor at zero.
        sigma = jnp.sqrt(jnp.maximum(jnp.diag(cov), 0.0))
        return mean, sigma

    @partial(jit, static_argnums=(0,))
    def predict(self, rng_key: jnp.ndarray,
                X_train: jnp.ndarray, y_train: jnp.ndarray,
                X_new: jnp.ndarray,
                samples: Optional[Dict[str, jnp.ndarray]] = None,
                noise: Optional[float] = None
                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Posterior predictive mean and standard deviation for every
        hyperparameter sample.

        Args:
            rng_key: random key, split into one key per sample
            X_train, y_train: training data used in fit()
            X_new: new inputs; a 1D array is treated as a single feature column
            samples: hyperparameter samples (default: MCMC posterior samples)
            noise: fixed observation noise; default uses samples["noise"]

        Returns:
            (means, sigmas), each of shape (num_samples, X_new.shape[0])
        """
        X_new = X_new if X_new.ndim > 1 else X_new[:, None]
        if samples is None:
            samples = self.get_samples(chain_dim=False)
        num_samples = next(iter(samples.values())).shape[0]
        noises = jnp.full((num_samples,), noise) if noise is not None else samples["noise"]
        keys = jax.random.split(rng_key, num_samples)
        # predict() is already jitted, so a plain vmap suffices here
        means, sigmas = vmap(
            lambda key, sample_params, sample_noise: self.get_mvn_posterior_cholesky(
                key, X_train, y_train, X_new, sample_params, sample_noise)
        )(keys, samples, noises)
        return means, sigmas