Update TFP RTD Colabs.
PiperOrigin-RevId: 505710526
emilyfertig authored and copybara-github committed Jan 30, 2023
1 parent 2b857a4 commit 25e1c38
Showing 2 changed files with 182 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/advanced_topics/tfp/bijectors.ipynb
@@ -116,7 +116,7 @@
" - ...\n",
"- `Invert` wraps any bijector instance and swaps its forward and inverse methods, e.g. `inv_sigmoid = tfb.Invert(tfb.Sigmoid())`.\n",
"- `Chain` composes a series of bijectors. The function $f(x) = 3 + 2x$ can be expressed as `tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)])`. Note that the bijectors in the list are applied from right to left.\n",
"- `JointMap` applies a nested structure of bijectors to an identical nested structure of inputs. `build_constraining_bijector`, shown above, returns a `JointMap` which applies a nested structure of bijectors to an identical nested structure of inputs. Vizier GetConstraints function could be used to generate a `JointMap` based on the `Constraint`s of the `ModelParameter`s defined in the coroutine.\n",
"- `JointMap` applies a nested structure of bijectors to an identical nested structure of inputs. `build_constraining_bijector`, shown above, returns a `JointMap` which applies a nested structure of bijectors to an identical nested structure of inputs. Vizier `get_constraints` function could be used to generate a `JointMap` based on the `Constraint`s of the `ModelParameter`s defined in the coroutine.\n",
"- `Restructure` packs the elements of one nested structure (e.g. a list) into a different structure (e.g. a dict). `spm.build_restructure_bijector`, for example, is a `Chain` bijector that takes a vector of parameters, splits it into a list, and packs the elements of the list into a dictionary with the same structure as the Flax parameters dict."
]
},
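The composition patterns above are easiest to see in action. Below is a minimal, self-contained sketch (illustrative only, not part of this commit); all APIs are standard TFP-on-JAX bijectors:

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

# Invert swaps forward and inverse: this computes the logit function.
inv_sigmoid = tfb.Invert(tfb.Sigmoid())
assert inv_sigmoid(0.5) == 0.  # logit(0.5) == 0.

# Chain applies right-to-left: f(x) = 3 + 2x.
f = tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)])
assert f(1.) == 5.

# JointMap applies a structure of bijectors to a matching structure of inputs.
jm = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Softplus()})
y = jm.forward({'a': jnp.array(0.), 'b': jnp.array(0.)})
assert y['a'] == 1.  # exp(0) == 1.
```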
203 changes: 181 additions & 22 deletions docs/advanced_topics/tfp/gp.ipynb
@@ -47,21 +47,24 @@
},
"outputs": [],
"source": [
"import chex\n",
"import jax\n",
"from jax import numpy as jnp, random, tree_util\n",
"import numpy as np\n",
"import optax\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow_probability.substrates import jax as tfp\n",
"from typing import Any\n",
"\n",
"# Vizier models can freely access modules from vizier._src\n",
"from vizier._src.benchmarks.experimenters.synthetic import bbob\n",
"from vizier._src.jax.optimizers import optimizers\n",
"from vizier._src.jax import stochastic_process_model as spm\n",
"\n",
"tfd = tfp.distributions\n",
"tfb = tfp.bijectors\n",
"tfpk = tfp.math.psd_kernels\n",
"Array = Any"
"tfpk = tfp.math.psd_kernels"
]
},
{
@@ -89,7 +92,7 @@
},
"outputs": [],
"source": [
"def simple_gp_coroutine(inputs: Array=None):\n",
"def simple_gp_coroutine(inputs: chex.Array=None):\n",
" length_scale = yield spm.ModelParameter.from_prior(\n",
" tfd.Gamma(1., 1., name='length_scale'))\n",
" amplitude = 2. # Non-trainable parameters may be defined as constants.\n",
@@ -163,7 +166,7 @@
},
"outputs": [],
"source": [
"def vizier_gp_coroutine(inputs=None):\n",
"def vizier_gp_coroutine(inputs: chex.Array=None):\n",
" pass"
]
},
@@ -184,7 +187,9 @@
},
"outputs": [],
"source": [
"def vizier_gp_coroutine(inputs=None):\n",
"data_dimensionality = 2\n",
"\n",
"def vizier_gp_coroutine(inputs: chex.Array=None):\n",
" \"\"\"A coroutine that follows the `ModelCoroutine` protocol.\"\"\"\n",
" signal_variance = yield spm.ModelParameter(\n",
" init_fn=lambda x: tfb.Softplus()(random.normal(x)),\n",
@@ -193,11 +198,10 @@
" name='signal_variance')\n",
" length_scale = yield spm.ModelParameter.from_prior(\n",
" tfd.Sample(\n",
" tfd.LogNormal(loc=0, scale=1.),\n",
" sample_shape=[4],\n",
" tfd.LogNormal(loc=0., scale=1.),\n",
" sample_shape=[data_dimensionality],\n",
" name='length_scale'),\n",
" constraint=spm.Constraint(\n",
" bounds=(jnp.zeros([4]), 100.0 + jnp.zeros([4]))))\n",
" constraint=spm.Constraint(bounds=(0.0, None)))\n",
" kernel = tfpk.MaternFiveHalves(\n",
" amplitude=jnp.sqrt(signal_variance), validate_args=True)\n",
" kernel = tfpk.FeatureScaled(\n",
@@ -236,13 +240,12 @@
"model = spm.StochasticProcessModel(coroutine=vizier_gp_coroutine)\n",
"\n",
"# Sample some fake data.\n",
"# TODO: Use Branin or similar instead?\n",
"# Assume we have `num_points` observations, each with `dim` features.\n",
"num_points = 12\n",
"dim = 4\n",
"\n",
"# Sample a set of index points.\n",
"index_points = np.random.normal(size=[num_points, dim]).astype(np.float32)\n",
"index_points = np.random.normal(\n",
" size=[num_points, data_dimensionality]).astype(np.float32)\n",
"\n",
"# Sample function values observed at the index points\n",
"observations = np.random.normal(size=[num_points]).astype(np.float32)\n",
@@ -349,7 +352,7 @@
"def loss_fn(params):\n",
" gp, mutables = model.apply({'params': params},\n",
" index_points,\n",
" mutable=['losses', 'predictive'])\n",
" mutable=['losses'])\n",
" loss = (-gp.log_prob(observations) +\n",
" jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.\n",
" return loss, {}"
@@ -419,7 +422,7 @@
"\n",
"Below we use the Vizier `JaxoptLbfgsB` optimizer to run a constrained L-BFGS-B algorithm. Unconstrainted optimizers (e.g. Adam) use a bijector function to map between the unconstrained space where the search is performed, and the constrained space where the loss function is evaluated. On the contrary, constrained optimizers (e.g. L-BGFS-B) use the constraint bounds directly in the search process.\n",
"\n",
"To pass the constraints bounds to the `JaxoptLbfgsB` optimizer we use the `spm.GetConstraint` function that traverse the parameters defined in the module coroutine and extract their bounds."
"To pass the constraints bounds to the `JaxoptLbfgsB` optimizer we use the `spm.get_constraints` function that traverse the parameters defined in the module coroutine and extract their bounds."
]
},
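For the unconstrained path mentioned above, here is a hedged sketch of how a constraining bijector could pair with Adam (via the `optax` import in this notebook). The `.bijector` attribute on the object returned by `spm.get_constraints` is an assumption here, not confirmed by the diff:

```python
constraints = spm.get_constraints(model.coroutine)
to_constrained = constraints.bijector  # Assumed attribute: unconstrained -> constrained.

def unconstrained_loss(z):
  return loss_fn(to_constrained(z))[0]

# Search in unconstrained space; map back to constrained space at the end.
z = to_constrained.inverse(model.init(random.PRNGKey(0), index_points)['params'])
opt = optax.adam(1e-2)
opt_state = opt.init(z)
for _ in range(200):
  updates, opt_state = opt.update(jax.grad(unconstrained_loss)(z), opt_state)
  z = optax.apply_updates(z, updates)
adam_params = to_constrained(z)
```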
{
@@ -432,11 +435,11 @@
"source": [
"setup = lambda rng: model.init(rng, index_points)['params']\n",
"model_optimizer = optimizers.JaxoptLbfgsB(\n",
" random_restarts=8, best_n=None\n",
" random_restarts=20, best_n=None\n",
" )\n",
"constraints = spm.GetConstraints(model.coroutine)()\n",
"optimal_params = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n",
" constraints=constraints)"
"constraints = spm.get_constraints(model.coroutine)\n",
"optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n",
" constraints=constraints)"
]
},
{
@@ -467,18 +470,166 @@
" method=model.precompute_predictive)\n",
"\n",
"# Predict on new index points.\n",
"predictive_index_points = np.random.normal(size=[5, dim]).astype(np.float32)\n",
"pp_dist, _ = model.apply(\n",
"predictive_index_points = np.random.normal(\n",
" size=[5, data_dimensionality]).astype(np.float32)\n",
"pp_dist = model.apply(\n",
" {'params': optimal_params, **pp_state},\n",
" predictive_index_points,\n",
" mutable=['predictive'],\n",
" method=model.predict)\n",
"\n",
"# `predict` returns a TFP distribution, whose mean, variance, and samples we can\n",
"# use to compute an acquisition function.\n",
"assert pp_dist.mean().shape == (5,)\n"
]
},
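As a usage example (previewing the Bayesian optimization loop below), the predictive mean and standard deviation combine directly into an Upper Confidence Bound acquisition:

```python
ucb = pp_dist.mean() + 2. * pp_dist.stddev()  # Shape [5], one value per point.
best_index = jnp.argmax(ucb)                  # Candidate with the highest UCB.
```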
{
"cell_type": "markdown",
"metadata": {
"id": "jBrxUYVjjDaL"
},
"source": [
"## Optimize a black-box function\n",
"\n",
"For an end-to-end example of Bayesian optimization, we'll use the GP surrogate model defined above along with an Upper Confidence Bound acquisition function to try to find the maximum of the Weierstrass function. First, visualize the function surface."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9_SF1D7Vu1sp"
},
"outputs": [],
"source": [
"# Use the Weierstrass function from Vizier's Black-Box Optimization Benchmarking\n",
"# (BBOB) library.\n",
"bb_fun = bbob.Weierstrass\n",
"\n",
"# Sample a set of index points in a 2D space.\n",
"num_points = 6\n",
"max_x = np.array(2.).astype(np.float32)\n",
"index_points = random.uniform(\n",
" random.PRNGKey(3), \n",
" shape=[num_points, data_dimensionality], dtype=jnp.float32) * max_x\n",
"\n",
"# Compute function values observed at the index points.\n",
"observations = np.apply_along_axis(\n",
" bb_fun, axis=1, arr=index_points).astype(np.float32)\n",
"\n",
"# Define a grid of points in the function domain for plotting.\n",
"n_grid = 100\n",
"x = y = np.linspace(0, max_x, n_grid, dtype=np.float32)\n",
"X, Y = np.meshgrid(x, y)\n",
"x_grid = np.vstack([X.ravel(), Y.ravel()]).T\n",
"y_grid = np.apply_along_axis(bb_fun, axis=1, arr=x_grid)\n",
"Z = y_grid.reshape(X.shape)\n",
"\n",
"# Plot the black-box function values.\n",
"fig = plt.figure(figsize=(8, 8))\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.plot_surface(X, Y, Z, alpha=0.5)\n",
"ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r', \n",
" label='Initial observed data')\n",
"plt.title('Black-box (Weierstrass) function values and observed data')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VjZA9C41oCBm"
},
"source": [
"Next, run a few iterations of Bayesian optimization to maximize the black-box function given the observed data. A single iteration consists of the following steps:\n",
"1. Optimize the GP hyperparameters.\n",
"2. Find a suggestion that maximizes an Upper Confidence Bound acquisition function. In this example, we use grid search for the optimization.\n",
"3. Evaluate the black-box function on the suggestion and append it to the set of observed data.\n",
"\n",
"(Note that this simple Bayesopt algorithm is for educational purposes and that we'd expect Vizier's GP bandit algorithm to give better results.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1VqoyPzWOLN7"
},
"outputs": [],
"source": [
"num_bayesopt_iter = 5\n",
"\n",
"# At each iteration, redefine the loss function given the current observed data.\n",
"def build_loss_fn(index_points, observations):\n",
" def loss_fn(params):\n",
" gp, mutables = model.apply({'params': params},\n",
" index_points,\n",
" mutable=['losses'])\n",
" loss = (-gp.log_prob(observations) +\n",
" jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.\n",
" return loss, {}\n",
" return loss_fn\n",
"\n",
"for i in range(num_bayesopt_iter):\n",
" # Update the loss function to condition on all observed data.\n",
" loss_fn = build_loss_fn(index_points, observations)\n",
"\n",
" # Optimize the GP hyperparameters.\n",
" optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n",
" constraints=constraints)\n",
"\n",
" # Compute the posterior predictive distribution over a grid of points in the\n",
" # function domain (x_grid).\n",
" _, pp_state = model.apply(\n",
" {'params': optimal_params},\n",
" index_points,\n",
" observations,\n",
" mutable=['predictive'],\n",
" method=model.precompute_predictive)\n",
" pp_dist = model.apply(\n",
" {'params': optimal_params, **pp_state},\n",
" x_grid,\n",
" method=model.predict)\n",
"\n",
" # Compute the acquisition function value at each point in the grid.\n",
" pred_mean = pp_dist.mean()\n",
" ucb_vec = pred_mean + 2. * pp_dist.stddev()\n",
"\n",
" # Find the grid point with the highest acquisition function value.\n",
" ind = np.argmax(ucb_vec)\n",
"\n",
" # Evaluate the black box function at the selected point. \n",
" f_val = bb_fun(x_grid[ind])\n",
"\n",
" # Visualize the surrogate model mean and acquisition function surface at this\n",
" # iteration.\n",
" fig = plt.figure(figsize=(10, 10))\n",
" ax = fig.add_subplot(121, projection='3d')\n",
" W = pred_mean.reshape(X.shape)\n",
" ax.plot_surface(X, Y, W, alpha=0.5)\n",
" ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r',\n",
" label='Observed data')\n",
" ax.set_title('Observed data and posterior predictive GP mean')\n",
" ax.legend()\n",
" \n",
" ax = fig.add_subplot(122, projection='3d')\n",
" ucb = ucb_vec.reshape(X.shape)\n",
" ax.plot_surface(X, Y, ucb, alpha=0.5)\n",
" ax.scatter(*x_grid[ind], ucb_vec[ind], color='r', label='New suggestion')\n",
" ax.set_title('Acquisition function')\n",
" ax.legend()\n",
" plt.show()\n",
"\n",
" # Append the new suggestion and function value to the set of observations.\n",
" index_points = np.concatenate([index_points, x_grid[ind][np.newaxis]])\n",
" observations = np.concatenate(\n",
" [observations, np.array(f_val).astype(np.float32)[np.newaxis]])\n",
" \n",
" print(f'Iteration: {i}')\n",
" print(f'Acquisition function value at suggestion: {ucb_vec[ind]}')\n",
" print(f'Black-box function value at suggestion: {f_val}')"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -518,6 +669,14 @@
"outputs": [],
"source": [
"# Build a kernel function (see \"PSD kernels\" section below) and GP.\n",
"num_points = 6\n",
"index_points = random.uniform(\n",
" random.PRNGKey(3), \n",
" shape=[num_points, data_dimensionality], dtype=jnp.float32) \n",
"observations = random.uniform(\n",
" random.PRNGKey(4), \n",
" shape=[num_points], dtype=jnp.float32) \n",
"\n",
"kernel = tfpk.MaternFiveHalves(\n",
" amplitude=2.,\n",
" length_scale=0.3,\n",
@@ -534,7 +693,7 @@
"\n",
"# Take 4 samples from the GP at the index points.\n",
"s = gp.sample(4, seed=random.PRNGKey(0))\n",
"assert s.shape == (4, 12)\n",
"assert s.shape == (4, num_points)\n",
"\n",
"# Compute the log likelihood of the sampled values.\n",
"lp = gp.log_prob(s)\n",
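To round out the exact-GP cell above, a sketch (not part of the commit) of the closed-form posterior using TFP's `GaussianProcessRegressionModel`, conditioning on the same `kernel`, `index_points`, and `observations`; the noise level is an assumption:

```python
predictive_points = random.uniform(
    random.PRNGKey(5), shape=[3, data_dimensionality], dtype=jnp.float32)
gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=predictive_points,
    observation_index_points=index_points,
    observations=observations,
    observation_noise_variance=1e-4)  # Assumed noise level for the sketch.
posterior_mean = gprm.mean()      # Shape [3].
posterior_stddev = gprm.stddev()  # Shape [3].
```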
