Skip to content

Commit

Permalink
write intro
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Dec 24, 2024
1 parent c6c33bd commit ef19b5b
Showing 1 changed file with 44 additions and 12 deletions.
56 changes: 44 additions & 12 deletions Python/hierarchical_hsgp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Notes on Hierarchical Hilbert Space Gaussian Processes"
"# Notes on Hierarchical Hilbert Space Gaussian Processes\n",
"\n",
"In this notebook, we want to explore some ideas on hierarchical Hilbert Space Gaussian Processes following the fantastic exposition of the PyMC example notebook: [\"Gaussian Processes: HSGP Advanced Usage\"](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html) by [Bill Engels](https://github.com/bwengals), [Alexandre Andorra](https://github.com/AlexAndorra) and [Maxim Kochurov](https://github.com/ferrine). I can only recommend to read the notebook and the references therein!\n",
"For an introduction to Hilbert Space Gaussian Processes, please see my introductory blog post: [\"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods\"](https://juanitorduz.github.io/hsgp_intro/).\n",
"\n",
"**Motivation**: In many applications, one is interested in understanding the dynamic effect of one variable on another. For example, in marketing, one is interested in efficiency across many channels on sales (or conversions) as a function of time. For such a purpose, one typically uses a time-varying regression model with time-varying coefficients (for a simple example, see [\"Time-Varying Regression Coefficients via Hilbert Space Gaussian Process Approximation\"](https://juanitorduz.github.io/bikes_gp/)). The main challenge is that, if not done carefully, the model will easily overfit the data. One can use a hierarchical model with a global component and a group-specific component to overcome this. The global component will capture the overall trend, while the group-specific component will capture the idiosyncratic behavior of each group. As a side effect, we will gain some regularization effect that will help to avoid overfitting. For the marketing example above, [PyMC Labs](https://www.pymc-labs.com/blog-posts/) has successfully applied this approach in the context of media mix modeling: [\"Bayesian Media Mix Models: Modelling changes in marketing effectiveness over time\"](https://www.pymc-labs.com/blog-posts/modelling-changes-marketing-effectiveness-over-time/). This hierarchical approach was motivated by the work [\"Hierarchical Bayesian modeling of gene expression time series across irregularly sampled replicates and clusters\"](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-14-252).\n",
"\n",
"To my knowledge (please correct me if I am wrong), the PyMC example notebook [\"Gaussian Processes: HSGP Advanced Usage\"](https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html) is the first reference with complete code to fit a hierarchical Hilbert Space Gaussian Process (well, maybe the [initial gist](https://gist.github.com/bwengals/90e79d64fe9e2d496a048f0b2c6d33da)). Here, we use very similar techniques with NumPyro using the (relatively) new [Hilbert Space Gaussian Process module](https://num.pyro.ai/en/latest/contrib.html#hilbert-space-gaussian-processes-approximation). For both the PyMC and NumPyro implementations, one of the key ingredients is vectorizing the spectral density expressions (as we need to fit many groups). While in the PyMC notebook, the vectorization is done by hand, in the NumPyro implementation, the vectorization is done using the [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) function. Nevertheless, both approaches are equivalent. Hence, we can focus on the concepts and worry less about the implementation details.\n",
"\n",
"With these new explicit code references, we expect to demystify the implementation of the hierarchical Hilbert Space Gaussian Process so that the community can benefit from this approach for real applications. That being said, as we will see below, this approach has its challenges. As in many applications with Gaussian processes, specifying sensitive priors and ensuring parameter identifiability is not a trivial task. Here is where the application should drive the model specification (where the thinking happens!).\n"
]
},
{
Expand Down Expand Up @@ -66,24 +75,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate Synthetic Data"
"## Generate Synthetic Data\n",
"\n",
"We start by generating some synthetic data. We do not do a full parameter recovery as in the PyMC notebook. Instead, we simple generate synthetic data, fit and asses the quality of the model. We will use a single group-specific latent function for each group. The mean function is a sum of two sine waves."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"n = 300\n",
"x = jnp.linspace(0.1, 1, n)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def generate_single_group_data(\n",
" rng_key: UInt32[Array, \"2\"], x: Float32[Array, \" n\"]\n",
Expand Down Expand Up @@ -114,24 +115,55 @@
"def dgg(\n",
" rng_key: UInt32[Array, \"2\"], x: Float32[Array, \" n\"]\n",
") -> tuple[Float32[Array, \" n\"], Float32[Array, \" n\"], Float32[Array, \" n\"]]:\n",
" \"\"\"Data generation function.\n",
"\n",
" We generate data for a single group by adding noise to the mean latent function.\n",
"\n",
" Parameters\n",
" ----------\n",
" rng_key : UInt32[Array, \"2\"]\n",
" JAX random key.\n",
" x : Float32[Array, \" n\"]\n",
" Input domain data.\n",
"\n",
" Returns\n",
" -------\n",
" tuple[Float32[Array, \" n\"], Float32[Array, \" n\"], Float32[Array, \" n\"]]\n",
" Group-specific latent function, mean function and observed data.\n",
" \"\"\"\n",
" f_g = generate_single_group_data(rng_key, x)\n",
" f = jnp.sin((4 * jnp.pi) * x) + jnp.sin((1 * jnp.pi) * x)\n",
" noise = random.normal(rng_key, shape=(n,)) * 0.5\n",
" y = f + f_g + noise\n",
" return f_g, f, y\n",
"\n",
"\n",
"# Number of observations and groups\n",
"n = 300\n",
"n_groups = 5\n",
"\n",
"# Input domain data\n",
"x = jnp.linspace(0.1, 1, n)\n",
"\n",
"# Generate data\n",
"rng_key, rng_subkey = random.split(rng_key)\n",
"\n",
"# Generate data for each group\n",
"f_g, f, y = vmap(dgg)(random.split(rng_subkey, n_groups), jnp.tile(x, (n_groups, 1)))\n",
"\n",
"# Check shapes\n",
"assert f_g.shape == (n_groups, n)\n",
"assert f.shape == (n_groups, n)\n",
"assert y.shape == (n_groups, n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize the raw data."
]
},
{
"cell_type": "code",
"execution_count": 4,
Expand Down

0 comments on commit ef19b5b

Please sign in to comment.