From 25e1c382ac5340a8281f43f9ba63c952988ec420 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 30 Jan 2023 09:06:51 -0800 Subject: [PATCH] Update TFP RTD Colabs. PiperOrigin-RevId: 505710526 --- docs/advanced_topics/tfp/bijectors.ipynb | 2 +- docs/advanced_topics/tfp/gp.ipynb | 203 ++++++++++++++++++++--- 2 files changed, 182 insertions(+), 23 deletions(-) diff --git a/docs/advanced_topics/tfp/bijectors.ipynb b/docs/advanced_topics/tfp/bijectors.ipynb index 7436b74b7..e306b28d4 100644 --- a/docs/advanced_topics/tfp/bijectors.ipynb +++ b/docs/advanced_topics/tfp/bijectors.ipynb @@ -116,7 +116,7 @@ " - ...\n", "- `Invert` wraps any bijector instance and swaps its forward and inverse methods, e.g. `inv_sigmoid = tfb.Invert(tfb.Sigmoid())`.\n", "- `Chain` composes a series of bijectors. The function $f(x) = 3 + 2x$ can be expressed as `tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)])`. Note that the bijectors in the list are applied from right to left.\n", - "- `JointMap` applies a nested structure of bijectors to an identical nested structure of inputs. `build_constraining_bijector`, shown above, returns a `JointMap` which applies a nested structure of bijectors to an identical nested structure of inputs. Vizier GetConstraints function could be used to generate a `JointMap` based on the `Constraint`s of the `ModelParameter`s defined in the coroutine.\n", + "- `JointMap` applies a nested structure of bijectors to an identical nested structure of inputs. `build_constraining_bijector`, shown above, returns a `JointMap` which applies a nested structure of bijectors to an identical nested structure of inputs. Vizier `get_constraints` function could be used to generate a `JointMap` based on the `Constraint`s of the `ModelParameter`s defined in the coroutine.\n", "- `Restructure` packs the elements of one nested structure (e.g. a list) into a different structure (e.g. a dict). `spm.build_restructure_bijector`, for example, is a `Chain` bijector that takes a vector of parameters, splits it into a list, and packs the elements of the list into a dictionary with the same structure as the Flax parameters dict." ] }, diff --git a/docs/advanced_topics/tfp/gp.ipynb b/docs/advanced_topics/tfp/gp.ipynb index a05a19675..0f157a276 100644 --- a/docs/advanced_topics/tfp/gp.ipynb +++ b/docs/advanced_topics/tfp/gp.ipynb @@ -47,21 +47,24 @@ }, "outputs": [], "source": [ + "import chex\n", "import jax\n", "from jax import numpy as jnp, random, tree_util\n", "import numpy as np\n", "import optax\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "import matplotlib.pyplot as plt\n", "from tensorflow_probability.substrates import jax as tfp\n", "from typing import Any\n", "\n", "# Vizier models can freely access modules from vizier._src\n", + "from vizier._src.benchmarks.experimenters.synthetic import bbob\n", "from vizier._src.jax.optimizers import optimizers\n", "from vizier._src.jax import stochastic_process_model as spm\n", "\n", "tfd = tfp.distributions\n", "tfb = tfp.bijectors\n", - "tfpk = tfp.math.psd_kernels\n", - "Array = Any" + "tfpk = tfp.math.psd_kernels" ] }, { @@ -89,7 +92,7 @@ }, "outputs": [], "source": [ - "def simple_gp_coroutine(inputs: Array=None):\n", + "def simple_gp_coroutine(inputs: chex.Array=None):\n", " length_scale = yield spm.ModelParameter.from_prior(\n", " tfd.Gamma(1., 1., name='length_scale'))\n", " amplitude = 2. # Non-trainable parameters may be defined as constants.\n", @@ -163,7 +166,7 @@ }, "outputs": [], "source": [ - "def vizier_gp_coroutine(inputs=None):\n", + "def vizier_gp_coroutine(inputs: chex.Array=None):\n", " pass" ] }, @@ -184,7 +187,9 @@ }, "outputs": [], "source": [ - "def vizier_gp_coroutine(inputs=None):\n", + "data_dimensionality = 2\n", + "\n", + "def vizier_gp_coroutine(inputs: chex.Array=None):\n", " \"\"\"A coroutine that follows the `ModelCoroutine` protocol.\"\"\"\n", " signal_variance = yield spm.ModelParameter(\n", " init_fn=lambda x: tfb.Softplus()(random.normal(x)),\n", @@ -193,11 +198,10 @@ " name='signal_variance')\n", " length_scale = yield spm.ModelParameter.from_prior(\n", " tfd.Sample(\n", - " tfd.LogNormal(loc=0, scale=1.),\n", - " sample_shape=[4],\n", + " tfd.LogNormal(loc=0., scale=1.),\n", + " sample_shape=[data_dimensionality],\n", " name='length_scale'),\n", - " constraint=spm.Constraint(\n", - " bounds=(jnp.zeros([4]), 100.0 + jnp.zeros([4]))))\n", + " constraint=spm.Constraint(bounds=(0.0, None)))\n", " kernel = tfpk.MaternFiveHalves(\n", " amplitude=jnp.sqrt(signal_variance), validate_args=True)\n", " kernel = tfpk.FeatureScaled(\n", @@ -236,13 +240,12 @@ "model = spm.StochasticProcessModel(coroutine=vizier_gp_coroutine)\n", "\n", "# Sample some fake data.\n", - "# TODO: Use Branin or similar instead?\n", "# Assume we have `num_points` observations, each with `dim` features.\n", "num_points = 12\n", - "dim = 4\n", "\n", "# Sample a set of index points.\n", - "index_points = np.random.normal(size=[num_points, dim]).astype(np.float32)\n", + "index_points = np.random.normal(\n", + " size=[num_points, data_dimensionality]).astype(np.float32)\n", "\n", "# Sample function values observed at the index points\n", "observations = np.random.normal(size=[num_points]).astype(np.float32)\n", @@ -349,7 +352,7 @@ "def loss_fn(params):\n", " gp, mutables = model.apply({'params': params},\n", " index_points,\n", - " mutable=['losses', 'predictive'])\n", + " mutable=['losses'])\n", " loss = (-gp.log_prob(observations) +\n", " jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.\n", " return loss, {}" @@ -419,7 +422,7 @@ "\n", "Below we use the Vizier `JaxoptLbfgsB` optimizer to run a constrained L-BFGS-B algorithm. Unconstrainted optimizers (e.g. Adam) use a bijector function to map between the unconstrained space where the search is performed, and the constrained space where the loss function is evaluated. On the contrary, constrained optimizers (e.g. L-BGFS-B) use the constraint bounds directly in the search process.\n", "\n", - "To pass the constraints bounds to the `JaxoptLbfgsB` optimizer we use the `spm.GetConstraint` function that traverse the parameters defined in the module coroutine and extract their bounds." + "To pass the constraints bounds to the `JaxoptLbfgsB` optimizer we use the `spm.get_constraints` function that traverse the parameters defined in the module coroutine and extract their bounds." ] }, { @@ -432,11 +435,11 @@ "source": [ "setup = lambda rng: model.init(rng, index_points)['params']\n", "model_optimizer = optimizers.JaxoptLbfgsB(\n", - " random_restarts=8, best_n=None\n", + " random_restarts=20, best_n=None\n", " )\n", - "constraints = spm.GetConstraints(model.coroutine)()\n", - "optimal_params = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n", - " constraints=constraints)" + "constraints = spm.get_constraints(model.coroutine)\n", + "optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n", + " constraints=constraints)" ] }, { @@ -467,11 +470,11 @@ " method=model.precompute_predictive)\n", "\n", "# Predict on new index points.\n", - "predictive_index_points = np.random.normal(size=[5, dim]).astype(np.float32)\n", - "pp_dist, _ = model.apply(\n", + "predictive_index_points = np.random.normal(\n", + " size=[5, data_dimensionality]).astype(np.float32)\n", + "pp_dist = model.apply(\n", " {'params': optimal_params, **pp_state},\n", " predictive_index_points,\n", - " mutable=['predictive'],\n", " method=model.predict)\n", "\n", "# `predict` returns a TFP distribution, whose mean, variance, and samples we can\n", @@ -479,6 +482,154 @@ "assert pp_dist.mean().shape == (5,)\n" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "jBrxUYVjjDaL" + }, + "source": [ + "## Optimize a black-box function\n", + "\n", + "For an end-to-end example of Bayesian optimization, we'll use the GP surrogate model defined above along with an Upper Confidence Bound acquisition function to try to find the maximum of the Weierstrass function. First, visualize the function surface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9_SF1D7Vu1sp" + }, + "outputs": [], + "source": [ + "# Use the Weierstrass function from Vizier's Black-Box Optimization Benchmarking\n", + "# (BBOB) library.\n", + "bb_fun = bbob.Weierstrass\n", + "\n", + "# Sample a set of index points in a 2D space.\n", + "num_points = 6\n", + "max_x = np.array(2.).astype(np.float32)\n", + "index_points = random.uniform(\n", + " random.PRNGKey(3), \n", + " shape=[num_points, data_dimensionality], dtype=jnp.float32) * max_x\n", + "\n", + "# Compute function values observed at the index points.\n", + "observations = np.apply_along_axis(\n", + " bb_fun, axis=1, arr=index_points).astype(np.float32)\n", + "\n", + "# Define a grid of points in the function domain for plotting.\n", + "n_grid = 100\n", + "x = y = np.linspace(0, max_x, n_grid, dtype=np.float32)\n", + "X, Y = np.meshgrid(x, y)\n", + "x_grid = np.vstack([X.ravel(), Y.ravel()]).T\n", + "y_grid = np.apply_along_axis(bb_fun, axis=1, arr=x_grid)\n", + "Z = y_grid.reshape(X.shape)\n", + "\n", + "# Plot the black-box function values.\n", + "fig = plt.figure(figsize=(8, 8))\n", + "ax = fig.add_subplot(111, projection='3d')\n", + "ax.plot_surface(X, Y, Z, alpha=0.5)\n", + "ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r', \n", + " label='Initial observed data')\n", + "plt.title('Black-box (Weierstrass) function values and observed data')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VjZA9C41oCBm" + }, + "source": [ + "Next, run a few iterations of Bayesian optimization to maximize the black-box function given the observed data. A single iteration consists of the following steps:\n", + "1. Optimize the GP hyperparameters.\n", + "2. Find a suggestion that maximizes an Upper Confidence Bound acquisition function. In this example, we use grid search for the optimization.\n", + "3. Evaluate the black-box function on the suggestion and append it to the set of observed data.\n", + "\n", + "(Note that this simple Bayesopt algorithm is for educational purposes and that we'd expect Vizier's GP bandit algorithm to give better results.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1VqoyPzWOLN7" + }, + "outputs": [], + "source": [ + "num_bayesopt_iter = 5\n", + "\n", + "# At each iteration, redefine the loss function given the current observed data.\n", + "def build_loss_fn(index_points, observations):\n", + " def loss_fn(params):\n", + " gp, mutables = model.apply({'params': params},\n", + " index_points,\n", + " mutable=['losses'])\n", + " loss = (-gp.log_prob(observations) +\n", + " jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.\n", + " return loss, {}\n", + " return loss_fn\n", + "\n", + "for i in range(num_bayesopt_iter):\n", + " # Update the loss function to condition on all observed data.\n", + " loss_fn = build_loss_fn(index_points, observations)\n", + "\n", + " # Optimize the GP hyperparameters.\n", + " optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),\n", + " constraints=constraints)\n", + "\n", + " # Compute the posterior predictive distribution over a grid of points in the\n", + " # function domain (x_grid).\n", + " _, pp_state = model.apply(\n", + " {'params': optimal_params},\n", + " index_points,\n", + " observations,\n", + " mutable=['predictive'],\n", + " method=model.precompute_predictive)\n", + " pp_dist = model.apply(\n", + " {'params': optimal_params, **pp_state},\n", + " x_grid,\n", + " method=model.predict)\n", + "\n", + " # Compute the acquisition function value at each point in the grid.\n", + " pred_mean = pp_dist.mean()\n", + " ucb_vec = pred_mean + 2. * pp_dist.stddev()\n", + "\n", + " # Find the grid point with the highest acquisition function value.\n", + " ind = np.argmax(ucb_vec)\n", + "\n", + " # Evaluate the black box function at the selected point. \n", + " f_val = bb_fun(x_grid[ind])\n", + "\n", + " # Visualize the surrogate model mean and acquisition function surface at this\n", + " # iteration.\n", + " fig = plt.figure(figsize=(10, 10))\n", + " ax = fig.add_subplot(121, projection='3d')\n", + " W = pred_mean.reshape(X.shape)\n", + " ax.plot_surface(X, Y, W, alpha=0.5)\n", + " ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r',\n", + " label='Observed data')\n", + " ax.set_title('Observed data and posterior predictive GP mean')\n", + " ax.legend()\n", + " \n", + " ax = fig.add_subplot(122, projection='3d')\n", + " ucb = ucb_vec.reshape(X.shape)\n", + " ax.plot_surface(X, Y, ucb, alpha=0.5)\n", + " ax.scatter(*x_grid[ind], ucb_vec[ind], color='r', label='New suggestion')\n", + " ax.set_title('Acquisition function')\n", + " ax.legend()\n", + " plt.show()\n", + "\n", + " # Append the new suggestion and function value to the set of observations.\n", + " index_points = np.concatenate([index_points, x_grid[ind][np.newaxis]])\n", + " observations = np.concatenate(\n", + " [observations, np.array(f_val).astype(np.float32)[np.newaxis]])\n", + " \n", + " print(f'Iteration: {i}')\n", + " print(f'Acquisition function value at suggestion: {ucb_vec[ind]}')\n", + " print(f'Black-box function value at suggestion: {f_val}')" + ] + }, { "cell_type": "markdown", "metadata": { @@ -518,6 +669,14 @@ "outputs": [], "source": [ "# Build a kernel function (see \"PSD kernels\" section below) and GP.\n", + "num_points = 6\n", + "index_points = random.uniform(\n", + " random.PRNGKey(3), \n", + " shape=[num_points, data_dimensionality], dtype=jnp.float32) \n", + "observations = random.uniform(\n", + " random.PRNGKey(4), \n", + " shape=[num_points], dtype=jnp.float32) \n", + "\n", "kernel = tfpk.MaternFiveHalves(\n", " amplitude=2.,\n", " length_scale=0.3,\n", @@ -534,7 +693,7 @@ "\n", "# Take 4 samples from the GP at the index points.\n", "s = gp.sample(4, seed=random.PRNGKey(0))\n", - "assert s.shape == (4, 12)\n", + "assert s.shape == (4, num_points)\n", "\n", "# Compute the log likelihood of the sampled values.\n", "lp = gp.log_prob(s)\n",