-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(algorithms): add CMA-ME, fix CMA-ES and CMA-MEGA (#86)
Add emitters introduced in CMA-ME, add tests, notebook, update the documentation, fix CMA-ES and fix CMA-MEGA
- Loading branch information
1 parent
d7c6dc7
commit a5c19a2
Showing
29 changed files
with
2,089 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Covariance Matrix Adaptation MAP Elites (CMAME) | ||
|
||
To create an instance of CMAME, one need to use an instance of [MAP-Elites](map_elites.md) with the desired CMA Emitter - optimizing, random direction, improvement - detailed below.To use the pool of emitter mechanism, use the CMAPoolEmitter. | ||
|
||
Three emitter types: | ||
|
||
::: qdax.core.emitters.cma_emitter.CMAEmitter | ||
::: qdax.core.emitters.cma_rnd_emitter.CMARndEmitter | ||
::: qdax.core.emitters.cma_opt_emitter.CMAOptimizingEmitter | ||
|
||
Pool of homogeneous emitters: | ||
|
||
::: qdax.core.emitters.cma_pool_emitter.CMAPoolEmitter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../examples/ |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "222bbe00", | ||
"metadata": {}, | ||
"source": [ | ||
"# Optimizing with CMA-ES in Jax\n", | ||
"\n", | ||
"This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", | ||
"\n", | ||
"- how to define the problem\n", | ||
"- how to create a CMA-ES optimizer\n", | ||
"- how to launch a certain number of optimizing steps\n", | ||
"- how to visualise the optimization process" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d731f067", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from matplotlib.patches import Ellipse\n", | ||
"\n", | ||
"from qdax.core.cmaes import CMAES" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7b6e910b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Set the hyperparameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "404fb0dc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Hyperparameters\n", | ||
"#@markdown ---\n", | ||
"num_iterations = 1000 #@param {type:\"integer\"}\n", | ||
"num_dimensions = 100 #@param {type:\"integer\"}\n", | ||
"batch_size = 36 #@param {type:\"integer\"}\n", | ||
"num_best = 18 #@param {type:\"integer\"}\n", | ||
"sigma_g = 0.5 # 0.5 #@param {type:\"number\"}\n", | ||
"minval = -5.12 #@param {type:\"number\"}\n", | ||
"optim_problem = \"sphere\" #@param[\"rastrigin\", \"sphere\"]\n", | ||
"#@markdown ---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ccc7cbeb", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## Define the fitness function - choose rastrigin or sphere" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "436dccbb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def rastrigin_scoring(x: jnp.ndarray):\n", | ||
" first_term = 10 * x.shape[-1]\n", | ||
" second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n", | ||
" return -(first_term + second_term)\n", | ||
"\n", | ||
"def sphere_scoring(x: jnp.ndarray):\n", | ||
" return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n", | ||
"\n", | ||
"if optim_problem == \"sphere\":\n", | ||
" fitness_fn = sphere_scoring\n", | ||
"elif optim_problem == \"rastrigin\":\n", | ||
" fitness_fn = jax.vmap(rastrigin_scoring)\n", | ||
"else:\n", | ||
" raise Exception(\"Invalid opt function name given\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "62bdd2a4", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## Define a CMA-ES optimizer instance" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4cf03f55", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"cmaes = CMAES(\n", | ||
" population_size=batch_size,\n", | ||
" num_best=num_best,\n", | ||
" search_dim=num_dimensions,\n", | ||
" fitness_function=fitness_fn,\n", | ||
" mean_init=jnp.zeros((num_dimensions,)),\n", | ||
" init_sigma=sigma_g,\n", | ||
" delay_eigen_decomposition=True,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f1f69f50", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## Init the CMA-ES optimizer state" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1a95b74d", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"state = cmaes.init()\n", | ||
"random_key = jax.random.PRNGKey(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ac2d5c0d", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## Run optimization iterations" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "363198ca", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%%time\n", | ||
"\n", | ||
"means = [state.mean]\n", | ||
"covs = [(state.sigma**2) * state.cov_matrix]\n", | ||
"\n", | ||
"iteration_count = 0\n", | ||
"for _ in range(num_iterations):\n", | ||
" iteration_count += 1\n", | ||
" \n", | ||
" # sample\n", | ||
" samples, random_key = cmaes.sample(state, random_key)\n", | ||
" \n", | ||
" # udpate\n", | ||
" state = cmaes.update(state, samples)\n", | ||
" \n", | ||
" # check stop condition\n", | ||
" stop_condition = cmaes.stop_condition(state)\n", | ||
"\n", | ||
" if stop_condition:\n", | ||
" break\n", | ||
" \n", | ||
" # store data for plotting\n", | ||
" means.append(state.mean)\n", | ||
" covs.append((state.sigma**2) * state.cov_matrix)\n", | ||
" \n", | ||
"print(\"Num iterations before stop condition: \", iteration_count)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0e5820b8", | ||
"metadata": {}, | ||
"source": [ | ||
"## Check final fitnesses and distribution mean" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1e4a2c7b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# checking final fitness values\n", | ||
"fitnesses = fitness_fn(samples)\n", | ||
"\n", | ||
"print(\"Min fitness in the final population: \", jnp.min(fitnesses))\n", | ||
"print(\"Mean fitness in the final population: \", jnp.mean(fitnesses))\n", | ||
"print(\"Max fitness in the final population: \", jnp.max(fitnesses))\n", | ||
"\n", | ||
"# checking mean of the final distribution\n", | ||
"print(\"Final mean of the distribution: \\n\", means[-1])\n", | ||
"# print(\"Final covariance matrix of the distribution: \", covs[-1])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f3bd2b0f", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%% md\n" | ||
} | ||
}, | ||
"source": [ | ||
"## Visualization of the optimization trajectory" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ad85551c", | ||
"metadata": { | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"fig, ax = plt.subplots(figsize=(12, 6))\n", | ||
"\n", | ||
"# sample points to show fitness landscape\n", | ||
"random_key, subkey = jax.random.split(random_key)\n", | ||
"x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n", | ||
"f_x = fitness_fn(x)\n", | ||
"\n", | ||
"# plot fitness landscape\n", | ||
"points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)\n", | ||
"fig.colorbar(points)\n", | ||
"\n", | ||
"# plot cma-es trajectory\n", | ||
"traj_min = 0\n", | ||
"traj_max = iteration_count\n", | ||
"for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):\n", | ||
" ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n", | ||
" ax.add_patch(ellipse)\n", | ||
" ax.plot(mean[0], mean[1], color='k', marker='x')\n", | ||
" \n", | ||
"ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n", | ||
"plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.13" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.