Skip to content

Commit

Permalink
Use preliz (#53)
Browse files Browse the repository at this point in the history
* use preliz

* update req

* arviz req

* fix linter

* new version
  • Loading branch information
aloctavodia authored Mar 21, 2024
1 parent 7ad7c4a commit dae5202
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 215 deletions.
116 changes: 0 additions & 116 deletions kulprit/projection/likelihood.py

This file was deleted.

42 changes: 14 additions & 28 deletions kulprit/projection/solver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""optimization module."""

# pylint: disable=protected-access
from typing import List, Optional
import warnings

import arviz as az
import bambi as bmb
import numpy as np
import xarray as xr
import preliz as pz

from scipy.optimize import minimize

from kulprit.data.submodel import SubModel
from kulprit.projection.likelihood import LIKELIHOODS


class Solver:
Expand All @@ -36,11 +37,8 @@ def __init__(
self.num_chain = self.ref_idata.posterior.dims["chain"]
self.num_samples = self.num_chain * 100

try:
# define the negative log likelihood function of the submodel
self.neg_log_likelihood = LIKELIHOODS[self.ref_family]
except KeyError:
raise NotImplementedError from None
if self.ref_family not in ["gaussian", "poisson", "bernoulli", "binomial"]:
raise NotImplementedError(f"Family {self.ref_family} not supported")

@property
def pps(self):
Expand Down Expand Up @@ -115,28 +113,29 @@ def objective(
# Gaussian observation likelihood
if self.ref_family == "gaussian":
linear_predictor = _linear_predict(beta_x=params[:-1], X=X)
neg_llk = self.neg_log_likelihood(points=obs, mean=linear_predictor, sigma=params[-1])
neg_llk = pz.Normal(mu=linear_predictor, sigma=params[-1])._neg_logpdf(obs)

# Binomial observation likelihood
elif self.ref_family == "binomial":
trials = self.ref_model.response.data[:, 1]
linear_predictor = _linear_predict(beta_x=params, X=X)
probs = self.ref_model.family.link["p"].linkinv(linear_predictor)
neg_llk = self.neg_log_likelihood(points=obs, probs=probs, trials=trials)
neg_llk = pz.Binomial(n=trials, p=probs)._neg_logpdf(obs)

# Bernoulli observation likelihood
elif self.ref_family == "bernoulli":
linear_predictor = _linear_predict(beta_x=params, X=X)
probs = self.ref_model.family.link["p"].linkinv(linear_predictor)
neg_llk = self.neg_log_likelihood(points=obs, probs=probs)
neg_llk = pz.Binomial(p=probs)._neg_logpdf(obs)

# Poisson observation likelihood
elif self.ref_family == "poisson":
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="overflow encountered in exp")
linear_predictor = _linear_predict(beta_x=params, X=X)
lam = self.ref_model.family.link["mu"].linkinv(np.float64(linear_predictor))
neg_llk = self.neg_log_likelihood(points=obs, lam=lam)
neg_llk = pz.Poisson(mu=lam)._neg_logpdf(obs)

return neg_llk

def solve(self, term_names: List[str], X: np.ndarray, slices: dict) -> SubModel:
Expand All @@ -160,34 +159,21 @@ def solve(self, term_names: List[str], X: np.ndarray, slices: dict) -> SubModel:
SubModel: The projected submodel object
"""
# build the optimization parameter bounds
init = np.hstack(
[self.ref_idata.posterior.mean(["chain", "draw"])[term].values for term in term_names]
)
bounds = self._build_bounds(init)
term_values = az.extract(self.ref_idata.posterior, num_samples=self.pps.shape[0])
init = np.stack([term_values[key].values.flatten() for key in term_names]).T
bounds = self._build_bounds(init[0])

# perform mean-field variational projection predictive inference
# perform projection predictive inference
res_posterior = []
objectives = []

chains = np.random.randint(0, self.ref_idata.posterior.dims["chain"], self.pps.shape[0])
draws = np.random.randint(0, self.ref_idata.posterior.dims["draw"], self.pps.shape[0])

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Values in x were outside bounds")
for idx, obs in enumerate(self.pps):
# use samples from reference model posterior as initial guess
init = np.hstack(
[
self.ref_idata.posterior.sel({"chain": chains[idx], "draw": draws[idx]})[
term
].values
for term in term_names
]
)
opt = minimize(
self.objective,
args=(obs, X),
x0=init,
x0=init[idx],
tol=0.001,
bounds=bounds,
method="SLSQP",
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ classifiers = [
dynamic = ["version"]
description = "Kullback-Leibler projections for Bayesian model selection."
dependencies = [
"arviz>=0.17.1",
"bambi>=0.12.0",
"scikit-learn>=1.0.2",
"numba>=0.56.0",
"preliz>=0.4.1"
]

[tool.flit.module]
Expand Down
71 changes: 0 additions & 71 deletions tests/test_likelihood.py

This file was deleted.

0 comments on commit dae5202

Please sign in to comment.