Skip to content

Commit

Permalink
robustify optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Sep 27, 2023
1 parent 112933a commit bfb7c17
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions kulprit/projection/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,6 @@ def pps(self):
).values.T
return pps

def _init_optimization(self, term_names: List[str]) -> List[float]:
"""Initialise the optimization with the reference posterior means."""

return np.hstack(
[self.ref_idata.posterior.mean(["chain", "draw"])[term].values for term in term_names]
)

def _build_bounds(self, init: List[float]) -> list:
"""Build bounds for the parameters in the optimization.
Expand Down Expand Up @@ -139,9 +132,11 @@ def objective(

# Poisson observation likelihood
elif self.ref_family == "poisson":
linear_predictor = _linear_predict(beta_x=params, X=X)
lam = self.ref_model.family.link["mu"].linkinv(np.float64(linear_predictor))
neg_llk = self.neg_log_likelihood(points=obs, lam=lam)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="overflow encountered in exp")
linear_predictor = _linear_predict(beta_x=params, X=X)
lam = self.ref_model.family.link["mu"].linkinv(np.float64(linear_predictor))
neg_llk = self.neg_log_likelihood(points=obs, lam=lam)
return neg_llk

def solve(self, term_names: List[str], X: np.ndarray, slices: dict) -> SubModel:
Expand All @@ -164,26 +159,37 @@ def solve(self, term_names: List[str], X: np.ndarray, slices: dict) -> SubModel:
-------
SubModel: The projected submodel object
"""

# use reference model posterior as initial guess
init = self._init_optimization(term_names=term_names)

# build the optimization parameter bounds
init = np.hstack(
[self.ref_idata.posterior.mean(["chain", "draw"])[term].values for term in term_names]
)
bounds = self._build_bounds(init)

# perform mean-field variational projection predictive inference
res_posterior = []
objectives = []

chains = np.random.randint(0, self.ref_idata.posterior.dims["chain"], self.pps.shape[0])
draws = np.random.randint(0, self.ref_idata.posterior.dims["draw"], self.pps.shape[0])

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Values in x were outside bounds")
for obs in self.pps:
for idx, obs in enumerate(self.pps):
# use samples from reference model posterior as initial guess
init = np.hstack(
[
self.ref_idata.posterior.sel({"chain": chains[idx], "draw": draws[idx]})[
term
].values
for term in term_names
]
)
opt = minimize(
self.objective,
args=(obs, X),
x0=init,
tol=0.0001,
tol=0.001,
bounds=bounds,
# This is the fastest method and the projected posterior looks Gaussian-like
method="SLSQP",
)
res_posterior.append(opt.x)
Expand Down

0 comments on commit bfb7c17

Please sign in to comment.