diff --git a/README.md b/README.md
index 2477348d..0881da1d 100644
--- a/README.md
+++ b/README.md
@@ -134,6 +134,7 @@ QDax currently supports the following algorithms:
 | [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) |
 | [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) |
 | [MAP-Elites PBT (ME-PBT)](https://openreview.net/forum?id=CBfYffLqWqb) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/me_sac_pbt.ipynb) |
+| [MAP-Elites Low-Spread (ME-LS)](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb) |
diff --git a/docs/api_documentation/core/mels.md b/docs/api_documentation/core/mels.md
new file mode 100644
index 00000000..3aa212b5
--- /dev/null
+++ b/docs/api_documentation/core/mels.md
@@ -0,0 +1,7 @@
+# MAP-Elites Low-Spread (ME-LS)
+
+[ME-LS](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) is a variant of
+MAP-Elites that drives the search process towards solutions that are consistent
+in the behavior space, which makes it suited to uncertain domains.
+
+::: qdax.core.mels.MELS
diff --git a/examples/mels.ipynb b/examples/mels.ipynb
new file mode 100644
index 00000000..1fcd6c42
--- /dev/null
+++ b/examples/mels.ipynb
@@ -0,0 +1,559 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mels.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Optimizing Uncertain Domains with ME-LS in JAX\n",
+    "\n",
+    "This notebook shows how to discover controllers that achieve consistent performance in MDP domains using the [MAP-Elites Low-Spread](https://dl.acm.org/doi/abs/10.1145/3583131.3590433) algorithm. It can be run locally or on Google Colab. We recommend using a GPU. 
This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create an emitter\n", + "- how to create an ME-LS instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualise the optimization process\n", + "- how to save/load a repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.1.2 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.2.2\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax.core.mels import MELS\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", + "from qdax.core.containers.mels_repertoire import MELSRepertoire\n", + "from qdax import environments\n", + "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.emitters.standard_emitters import MixingEmitter\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "from jax.flatten_util import ravel_pytree\n", + "\n", + "from IPython.display import HTML\n", + "from brax.io import html\n", + "\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "#@markdown ---\n", + "batch_size = 100 #@param {type:\"number\"}\n", + "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "num_samples = 5 #@param {type:\"number\"}\n", + "episode_length = 100 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", + "seed = 42 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (64, 64) #@param {type:\"raw\"}\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "min_bd = 0. 
#@param {type:\"number\"}\n",
+    "max_bd = 1.0 #@param {type:\"number\"}\n",
+    "#@markdown ---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Init environment, policy, population params, init states of the env\n",
+    "\n",
+    "Define the environment in which the policies will be trained. In this notebook, we consider the problem where each controller is evaluated `num_samples` times, each time starting from a different random initial state."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Init environment\n",
+    "env = environments.create(env_name, episode_length=episode_length)\n",
+    "\n",
+    "# Init a random key\n",
+    "random_key = jax.random.PRNGKey(seed)\n",
+    "\n",
+    "# Init policy network\n",
+    "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
+    "policy_network = MLP(\n",
+    "    layer_sizes=policy_layer_sizes,\n",
+    "    kernel_init=jax.nn.initializers.lecun_uniform(),\n",
+    "    final_activation=jnp.tanh,\n",
+    ")\n",
+    "\n",
+    "# Init population of controllers. There are batch_size controllers, and each\n",
+    "# controller will be evaluated num_samples times.\n",
+    "random_key, subkey = jax.random.split(random_key)\n",
+    "keys = jax.random.split(subkey, num=batch_size)\n",
+    "fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))\n",
+    "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define the way the policy interacts with the env\n",
+    "\n",
+    "Now that the environment and policy have been defined, it is necessary to define a function that describes how the policy must be used to interact with the environment and to store transition data. This is identical to the function in the MAP-Elites tutorial."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Define the function to play a step with the policy in the environment\n",
+    "def play_step_fn(\n",
+    "    env_state,\n",
+    "    policy_params,\n",
+    "    random_key,\n",
+    "):\n",
+    "    \"\"\"Play an environment step and return the updated state and the\n",
+    "    transition.\"\"\"\n",
+    "\n",
+    "    actions = policy_network.apply(policy_params, env_state.obs)\n",
+    "\n",
+    "    state_desc = env_state.info[\"state_descriptor\"]\n",
+    "    next_state = env.step(env_state, actions)\n",
+    "\n",
+    "    transition = QDTransition(\n",
+    "        obs=env_state.obs,\n",
+    "        next_obs=next_state.obs,\n",
+    "        rewards=next_state.reward,\n",
+    "        dones=next_state.done,\n",
+    "        actions=actions,\n",
+    "        truncations=next_state.info[\"truncation\"],\n",
+    "        state_desc=state_desc,\n",
+    "        next_state_desc=next_state.info[\"state_descriptor\"],\n",
+    "    )\n",
+    "\n",
+    "    return next_state, policy_params, random_key, transition"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define the scoring function and the way metrics are computed\n",
+    "\n",
+    "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. Note that while the MAP-Elites tutorial uses `scoring_function_brax_envs` as the basis for the scoring function, we use `reset_based_scoring_function_brax_envs`. The difference is that `reset_based_scoring_function_brax_envs` generates initial states randomly instead of taking in a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states. 
If the initial states were kept the same for all evaluations, there would be no stochasticity in the behavior."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Prepare the scoring function\n",
+    "bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]\n",
+    "scoring_fn = functools.partial(\n",
+    "    reset_based_scoring_function_brax_envs,\n",
+    "    episode_length=episode_length,\n",
+    "    play_reset_fn=env.reset,\n",
+    "    play_step_fn=play_step_fn,\n",
+    "    behavior_descriptor_extractor=bd_extraction_fn,\n",
+    ")\n",
+    "\n",
+    "# Get the minimum reward value to make sure the QD score is positive\n",
+    "reward_offset = environments.reward_offset[env_name]\n",
+    "\n",
+    "# Define a metrics function\n",
+    "metrics_fn = functools.partial(\n",
+    "    default_qd_metrics,\n",
+    "    qd_offset=reward_offset * episode_length,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define the emitter\n",
+    "\n",
+    "The emitter is used to evolve the population at each mutation step."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Define emitter\n",
+    "variation_fn = functools.partial(\n",
+    "    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n",
+    ")\n",
+    "mixing_emitter = MixingEmitter(\n",
+    "    mutation_fn=None,\n",
+    "    variation_fn=variation_fn,\n",
+    "    variation_percentage=1.0,\n",
+    "    batch_size=batch_size\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Instantiate and initialise the ME-LS algorithm"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# Instantiate ME-LS.\n",
+    "mels = MELS(\n",
+    "    scoring_function=scoring_fn,\n",
+    "    emitter=mixing_emitter,\n",
+    "    metrics_function=metrics_fn,\n",
+    "    num_samples=num_samples,\n",
+    ")\n",
+    "\n",
+    "# Compute the centroids\n",
+    "centroids, random_key = compute_cvt_centroids(\n",
+    "    num_descriptors=env.behavior_descriptor_length,\n",
+    "    num_init_cvt_samples=num_init_cvt_samples,\n",
+    "    num_centroids=num_centroids,\n",
+    "    minval=min_bd,\n",
+    "    maxval=max_bd,\n",
+    "    random_key=random_key,\n",
+    ")\n",
+    "\n",
+    "# Compute initial repertoire and emitter state\n",
+    "repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Launch ME-LS iterations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "log_period = 10\n",
+    "num_loops = num_iterations // log_period\n",
+    "\n",
+    "csv_logger = CSVLogger(\n",
+    "    \"mels-logs.csv\",\n",
+    "    header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n",
+    ")\n",
+    "all_metrics = {}\n",
+    "\n",
+    "# main loop\n",
+    "mels_scan_update = mels.scan_update\n",
+    "for i in range(num_loops):\n",
+    "    start_time = time.time()\n",
+    "    # main iterations\n",
+    "    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(\n",
+    "        mels_scan_update,\n",
+    "        (repertoire, emitter_state, random_key),\n",
+    "        (),\n",
+    "        length=log_period,\n",
+    "    )\n",
+    "    timelapse = time.time() - start_time\n",
+    "\n",
+    "    # log metrics\n",
+    "    logged_metrics = {\"time\": timelapse, \"loop\": 1 + i, \"iteration\": 1 + i * log_period}\n",
+    "    for key, value in metrics.items():\n",
+    
" # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(num_iterations) * episode_length * batch_size\n", + "\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to save/load a repertoire\n", + "\n", + "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the final repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire_path = \"./last_repertoire/\"\n", + "os.makedirs(repertoire_path, exist_ok=True)\n", + "repertoire.save(path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build the reconstruction function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Init population of policies\n", + "random_key, subkey = jax.random.split(random_key)\n", + "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", + "fake_params = policy_network.init(subkey, fake_batch)\n", + "\n", + "_, reconstruction_fn = ravel_pytree(fake_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use the reconstruction function to load and re-create the repertoire" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "repertoire = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get the best individual of the repertoire\n", + "\n", + "Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its `num_samples` behavior descriptors. Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "best_idx = jnp.argmax(repertoire.fitnesses)\n",
+    "best_fitness = jnp.max(repertoire.fitnesses)\n",
+    "best_bd = repertoire.descriptors[best_idx]\n",
+    "best_spread = repertoire.spreads[best_idx]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\n",
+    "    f\"Best fitness in the repertoire: {best_fitness:.2f}\\n\"\n",
+    "    f\"Behavior descriptor of the best individual in the repertoire: {best_bd}\\n\"\n",
+    "    f\"Spread of the best individual in the repertoire: {best_spread}\\n\"\n",
+    "    f\"Index in the repertoire of this individual: {best_idx}\\n\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "my_params = jax.tree_util.tree_map(\n",
+    "    lambda x: x[best_idx],\n",
+    "    repertoire.genotypes\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Play some steps in the environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jit_env_reset = jax.jit(env.reset)\n",
+    "jit_env_step = jax.jit(env.step)\n",
+    "jit_inference_fn = jax.jit(policy_network.apply)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rollout = []\n",
+    "rng = jax.random.PRNGKey(seed=1)\n",
+    "state = jit_env_reset(rng=rng)\n",
+    "while not state.done:\n",
+    "    rollout.append(state)\n",
+    "    action = jit_inference_fn(my_params, state.obs)\n",
+    "    state = jit_env_step(state, action)\n",
+    "\n",
+    "print(f\"The trajectory of this individual contains {len(rollout)} transitions.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.16"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/mkdocs.yml b/mkdocs.yml
index 702a474a..2c0bbdb6 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -140,6 +140,7 @@ nav:
         - MOME: api_documentation/core/mome.md
         - ME ES: api_documentation/core/mees.md
         - ME PBT: api_documentation/core/me_pbt.md
+        - ME LS: api_documentation/core/mels.md
     - Baseline algorithms:
         - SMERL: api_documentation/core/smerl.md
         - DIAYN: api_documentation/core/diayn.md
diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py
new file mode 100644
index 00000000..a2e99971
--- /dev/null
+++ b/qdax/core/containers/mels_repertoire.py
@@ -0,0 +1,311 @@
+"""This file contains the class to define the repertoire used to
+store individuals in the MAP-Elites Low-Spread (ME-LS) algorithm."""
+
+from __future__ import annotations
+
+from typing import Callable, Optional
+
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+
+from qdax.core.containers.mapelites_repertoire import (
+    MapElitesRepertoire,
+    
get_cells_indices,
+)
+from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread
+
+
+def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray:
+    """Computes the dispersion of a batch of num_samples descriptors.
+
+    Args:
+        descriptors: (num_samples, num_descriptors) array of descriptors.
+    Returns:
+        The float dispersion of the descriptors (this is represented as a scalar
+        jnp.ndarray).
+    """
+
+    # Pairwise distances between the descriptors.
+    dists = jnp.linalg.norm(descriptors[:, None] - descriptors, axis=2)
+
+    # Compute dispersion -- this is the mean of the unique pairwise distances.
+    #
+    # Zero out the duplicate distances since the distance matrix is symmetric.
+    # Setting k=1 also removes the entries on the diagonal, which are zero anyway.
+    dists = jnp.triu(dists, k=1)
+
+    num_samples = len(descriptors)
+    n_pairwise = num_samples * (num_samples - 1) / 2.0
+
+    return jnp.sum(dists) / n_pairwise
+
+
+def _mode(x: jnp.ndarray) -> jnp.ndarray:
+    """Computes the mode (most common item) of an array.
+
+    The return type is a scalar ndarray.
+    """
+    unique_vals, counts = jnp.unique(x, return_counts=True, size=x.size)
+    return unique_vals[jnp.argmax(counts)]
+
+
+class MELSRepertoire(MapElitesRepertoire):
+    """Class for the repertoire in MAP-Elites Low-Spread.
+
+    This class inherits from MapElitesRepertoire. In addition to the stored data in
+    MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire
+    also maintains an array of spreads. We override the save, load, add, and
+    init_default methods of MapElitesRepertoire.
+
+    Refer to Mace 2023 for more info on MAP-Elites Low-Spread:
+    https://dl.acm.org/doi/abs/10.1145/3583131.3590433
+
+    Args:
+        genotypes: a PyTree containing all the genotypes in the repertoire ordered
+            by the centroids. Each leaf has a shape (num_centroids, num_features). The
+            PyTree can be a simple Jax array or a more complex nested structure such
+            as the parameters of a neural network in Flax.
+        fitnesses: an array that contains the fitness of solutions in each cell of the
+            repertoire, ordered by centroids. The array shape is (num_centroids,).
+        descriptors: an array that contains the descriptors of solutions in each cell
+            of the repertoire, ordered by centroids. The array shape
+            is (num_centroids, num_descriptors).
+        centroids: an array that contains the centroids of the tessellation. The array
+            shape is (num_centroids, num_descriptors).
+        spreads: an array that contains the spread of solutions in each cell of the
+            repertoire, ordered by centroids. The array shape is (num_centroids,).
+    """
+
+    spreads: Spread
+
+    def save(self, path: str = "./") -> None:
+        """Saves the repertoire on disk in the form of .npy files.
+
+        Flattens the genotypes to store them in the .npy format. Assumes that
+        the user will have access to the reconstruction function when loading
+        the genotypes.
+
+        Args:
+            path: Path where the data will be saved. Defaults to "./". 
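+                In addition to the genotypes, the fitnesses, descriptors,
+                centroids, and spreads are saved, each as a separate .npy
+                file under this path.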
+        """
+
+        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
+            flatten_genotype, _ = ravel_pytree(genotype)
+            return flatten_genotype
+
+        # flatten all the genotypes
+        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)
+
+        # save data
+        jnp.save(path + "genotypes.npy", flat_genotypes)
+        jnp.save(path + "fitnesses.npy", self.fitnesses)
+        jnp.save(path + "descriptors.npy", self.descriptors)
+        jnp.save(path + "centroids.npy", self.centroids)
+        jnp.save(path + "spreads.npy", self.spreads)
+
+    @classmethod
+    def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire:
+        """Loads a MAP-Elites Low-Spread Repertoire.
+
+        Args:
+            reconstruction_fn: Function to reconstruct a PyTree
+                from a flat array.
+            path: Path where the data is saved. Defaults to "./".
+
+        Returns:
+            A MAP-Elites Low-Spread Repertoire.
+        """
+
+        flat_genotypes = jnp.load(path + "genotypes.npy")
+        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)
+
+        fitnesses = jnp.load(path + "fitnesses.npy")
+        descriptors = jnp.load(path + "descriptors.npy")
+        centroids = jnp.load(path + "centroids.npy")
+        spreads = jnp.load(path + "spreads.npy")
+
+        return cls(
+            genotypes=genotypes,
+            fitnesses=fitnesses,
+            descriptors=descriptors,
+            centroids=centroids,
+            spreads=spreads,
+        )
+
+    @jax.jit
+    def add(
+        self,
+        batch_of_genotypes: Genotype,
+        batch_of_descriptors: Descriptor,
+        batch_of_fitnesses: Fitness,
+        batch_of_extra_scores: Optional[ExtraScores] = None,
+    ) -> MELSRepertoire:
+        """
+        Add a batch of elements to the repertoire.
+
+        The key difference between this method and the default add() in
+        MapElitesRepertoire is that it expects each individual to be evaluated
+        `num_samples` times, resulting in `num_samples` fitnesses and
+        `num_samples` descriptors per individual.
+
+        If multiple individuals in the batch land in the same cell, this method
+        arbitrarily keeps one of them -- the exact choice depends on the
+        implementation of jax.at[].set(), which can be non-deterministic:
+        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
+        We do not currently check whether one of these individuals dominates the
+        others (an individual dominates the others if it has both the highest
+        fitness and the lowest spread among the candidates for that cell).
+
+        If `num_samples` is only 1, the spreads will default to 0.
+
+        Args:
+            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
+                Similarly to the self.genotypes argument, this is a PyTree in which
+                the leaves have a shape (batch_size, num_features)
+            batch_of_descriptors: an array that contains the descriptors of the
+                aforementioned genotypes over all evals. Its shape is
+                (batch_size, num_samples, num_descriptors). Note that we "aggregate"
+                descriptors by finding the most frequent cell of each individual. Thus,
+                the actual descriptors stored in the repertoire are just the coordinates
+                of the centroid of the most frequent cell.
+            batch_of_fitnesses: an array that contains the fitnesses of the
+                aforementioned genotypes over all evals. Its shape is (batch_size,
+                num_samples)
+            batch_of_extra_scores: unused tree that contains the extra_scores of
+                the aforementioned genotypes.
+
+        Returns:
+            The updated repertoire.
+        """
+        batch_size, num_samples = batch_of_fitnesses.shape
+
+        # Compute indices/cells of all descriptors.
+        batch_of_all_indices = get_cells_indices(
+            batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
+        ).reshape((batch_size, num_samples))
+
+        # Compute most frequent cell of each solution. 
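+        # Note: jnp.unique returns sorted values and jnp.argmax returns the
+        # first occurrence of the maximal count, so _mode resolves ties between
+        # equally frequent cells in favor of the smallest cell index.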
+        batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]
+
+        # Compute dispersion / spread. The dispersion is set to zero if
+        # num_samples is 1.
+        batch_of_spreads = jax.lax.cond(
+            num_samples == 1,
+            lambda desc: jnp.zeros(batch_size),
+            lambda desc: jax.vmap(_dispersion)(
+                desc.reshape((batch_size, num_samples, -1))
+            ),
+            batch_of_descriptors,
+        )
+        batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)
+
+        # Compute canonical descriptors as the descriptor of the centroid of the most
+        # frequent cell. Note that this line redefines the earlier batch_of_descriptors.
+        batch_of_descriptors = jnp.take_along_axis(
+            self.centroids, batch_of_indices, axis=0
+        )
+
+        # Compute canonical fitnesses as the average fitness.
+        #
+        # Shape: (batch_size, 1)
+        batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)
+
+        num_centroids = self.centroids.shape[0]
+
+        # get current repertoire fitnesses and spreads
+        repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
+        current_fitnesses = jnp.take_along_axis(
+            repertoire_fitnesses, batch_of_indices, 0
+        )
+
+        repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
+        current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)
+
+        # get addition condition
+        addition_condition_fitness = batch_of_fitnesses > current_fitnesses
+        addition_condition_spread = batch_of_spreads <= current_spreads
+        addition_condition = jnp.logical_and(
+            addition_condition_fitness, addition_condition_spread
+        )
+
+        # assign a fake position when relevant: num_centroids is out of bounds
+        batch_of_indices = jnp.where(
+            addition_condition, x=batch_of_indices, y=num_centroids
+        )
+
+        # create new repertoire
+        new_repertoire_genotypes = jax.tree_util.tree_map(
+            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
+                batch_of_indices.squeeze(axis=-1)
+            ].set(new_genotypes),
+            self.genotypes,
+            batch_of_genotypes,
+        )
+
+        # compute new fitnesses, descriptors, and spreads
+        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
+            batch_of_fitnesses.squeeze(axis=-1)
+        )
+        new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
+            batch_of_descriptors
+        )
+        new_spreads = self.spreads.at[batch_of_indices.squeeze(axis=-1)].set(
+            batch_of_spreads.squeeze(axis=-1)
+        )
+
+        return MELSRepertoire(
+            genotypes=new_repertoire_genotypes,
+            fitnesses=new_fitnesses,
+            descriptors=new_descriptors,
+            centroids=self.centroids,
+            spreads=new_spreads,
+        )
+
+    @classmethod
+    def init_default(
+        cls,
+        genotype: Genotype,
+        centroids: Centroid,
+    ) -> MELSRepertoire:
+        """Initialize a MAP-Elites Low-Spread repertoire with an initial population of
+        genotypes. Requires the definition of centroids that can be computed with any
+        method such as CVT or Euclidean mapping.
+
+        Note: this function has been kept outside of the object MELS, so
+        it can easily be called from other modules.
+
+        Args:
+            genotype: the typical genotype that will be stored.
+            centroids: the centroids of the repertoire.
+
+        Returns:
+            A repertoire filled with default values. 
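+
+            Fitnesses are initialized to -inf and spreads to +inf, so any
+            evaluated solution can enter an empty cell.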
+        """
+
+        # get number of centroids
+        num_centroids = centroids.shape[0]
+
+        # default fitness is -inf
+        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
+
+        # default genotypes are all 0
+        default_genotypes = jax.tree_util.tree_map(
+            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
+            genotype,
+        )
+
+        # default descriptors are all zeros
+        default_descriptors = jnp.zeros_like(centroids)
+
+        # default spread is inf so that any spread will be less
+        default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)
+
+        return cls(
+            genotypes=default_genotypes,
+            fitnesses=default_fitnesses,
+            descriptors=default_descriptors,
+            centroids=centroids,
+            spreads=default_spreads,
+        )
diff --git a/qdax/core/mels.py b/qdax/core/mels.py
new file mode 100644
index 00000000..6c06b785
--- /dev/null
+++ b/qdax/core/mels.py
@@ -0,0 +1,104 @@
+"""Core components of the MAP-Elites Low-Spread algorithm."""
+from __future__ import annotations
+
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import jax
+
+from qdax.core.containers.mels_repertoire import MELSRepertoire
+from qdax.core.emitters.emitter import Emitter, EmitterState
+from qdax.core.map_elites import MAPElites
+from qdax.types import (
+    Centroid,
+    Descriptor,
+    ExtraScores,
+    Fitness,
+    Genotype,
+    Metrics,
+    RNGKey,
+)
+from qdax.utils.sampling import multi_sample_scoring_function
+
+
+class MELS(MAPElites):
+    """Core elements of the MAP-Elites Low-Spread algorithm.
+
+    Most methods in this class are inherited from MAPElites.
+
+    The same scoring function can be passed into both MAPElites and this class.
+    We override __init__ so that it takes in the scoring function and wraps it
+    such that every solution is evaluated `num_samples` times.
+
+    We also override the init method to use a MELSRepertoire instead of a
+    MapElitesRepertoire.
+    """
+
+    def __init__(
+        self,
+        scoring_function: Callable[
+            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
+        ],
+        emitter: Emitter,
+        metrics_function: Callable[[MELSRepertoire], Metrics],
+        num_samples: int,
+    ) -> None:
+        self._scoring_function = partial(
+            multi_sample_scoring_function,
+            scoring_fn=scoring_function,
+            num_samples=num_samples,
+        )
+        self._emitter = emitter
+        self._metrics_function = metrics_function
+        self._num_samples = num_samples
+
+    @partial(jax.jit, static_argnames=("self",))
+    def init(
+        self,
+        init_genotypes: Genotype,
+        centroids: Centroid,
+        random_key: RNGKey,
+    ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]:
+        """Initialize a MAP-Elites Low-Spread repertoire with an initial
+        population of genotypes. Requires the definition of centroids that can
+        be computed with any method such as CVT or Euclidean mapping.
+
+        Args:
+            init_genotypes: initial genotypes, pytree in which leaves
+                have shape (batch_size, num_features)
+            centroids: tessellation centroids of shape (num_centroids, num_descriptors)
+            random_key: a random key used for stochastic operations.
+
+        Returns:
+            A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter
+            state, JAX random key). 
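+
+        Note: the wrapped scoring function evaluates every genotype
+        `num_samples` times, so the fitnesses passed to the repertoire have
+        shape (batch_size, num_samples) and the descriptors have shape
+        (batch_size, num_samples, num_descriptors).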
+ """ + # score initial genotypes + fitnesses, descriptors, extra_scores, random_key = self._scoring_function( + init_genotypes, random_key + ) + + # init the repertoire + repertoire = MELSRepertoire.init( + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + centroids=centroids, + extra_scores=extra_scores, + ) + + # get initial state of the emitter + emitter_state, random_key = self._emitter.init( + init_genotypes=init_genotypes, random_key=random_key + ) + + # update emitter state + emitter_state = self._emitter.state_update( + emitter_state=emitter_state, + repertoire=repertoire, + genotypes=init_genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, + ) + return repertoire, emitter_state, random_key diff --git a/qdax/types.py b/qdax/types.py index 67fbb8a0..5000869b 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -26,6 +26,7 @@ Genotype: TypeAlias = ArrayTree Descriptor: TypeAlias = jnp.ndarray Centroid: TypeAlias = jnp.ndarray +Spread: TypeAlias = jnp.ndarray Gradient: TypeAlias = jnp.ndarray Skill: TypeAlias = jnp.ndarray diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index 88b6286e..a25e190f 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -8,7 +8,7 @@ from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey -@partial(jax.jit, static_argnames=("num_samples")) +@partial(jax.jit, static_argnames=("num_samples",)) def dummy_extra_scores_extractor( extra_scores: ExtraScores, num_samples: int, @@ -29,6 +29,60 @@ def dummy_extra_scores_extractor( return extra_scores +@partial( + jax.jit, + static_argnames=( + "scoring_fn", + "num_samples", + ), +) +def multi_sample_scoring_function( + policies_params: Genotype, + random_key: RNGKey, + scoring_fn: Callable[ + [Genotype, RNGKey], + Tuple[Fitness, Descriptor, ExtraScores, RNGKey], + ], + num_samples: int, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Wrap scoring_function to perform sampling. + + This function returns the fitnesses, descriptors, and extra_scores computed + over num_samples evaluations with the scoring_fn. + + Args: + policies_params: policies to evaluate + random_key: JAX random key + scoring_fn: scoring function used for evaluation + num_samples: number of samples to generate for each individual + + Returns: + (n, num_samples) array of fitnesses, + (n, num_samples, num_descriptors) array of descriptors, + dict with num_samples extra_scores per individual, + JAX random key + """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=num_samples) + + # evaluate + sample_scoring_fn = jax.vmap( + scoring_fn, + # vectorizing over axis 0 vectorizes over the num_samples random keys + in_axes=(None, 0), + # indicates that the vectorized axis will become axis 1, i.e., the final + # output is shape (batch_size, num_samples, ...) + out_axes=1, + ) + all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( + policies_params, keys + ) + + return all_fitnesses, all_descriptors, all_extra_scores, random_key + + @partial( jax.jit, static_argnames=( @@ -49,14 +103,16 @@ def sampling( [ExtraScores, int], ExtraScores ] = dummy_extra_scores_extractor, ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """ - Wrap scoring_function to perform sampling. + """Wrap scoring_function to perform sampling. + + This function averages the fitnesses and descriptors for each individual + over `num_samples` evaluations. 
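+    The extra_scores are not averaged; they are instead reduced with the
+    given `extra_scores_extractor`.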
Args: policies_params: policies to evaluate - random_key + random_key: JAX random key scoring_fn: scoring function used for evaluation - num_samples + num_samples: number of samples to generate for each individual extra_scores_extractor: function to extract the extra_scores from multiple samples of the same policy. @@ -65,14 +121,13 @@ def sampling( The extra_score extract from samples with extra_scores_extractor A new random key """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split(subkey, num=num_samples) - - # evaluate - sample_scoring_fn = jax.vmap(scoring_fn, (None, 0), 1) - all_fitnesses, all_descriptors, all_extra_scores, _ = sample_scoring_fn( - policies_params, keys + ( + all_fitnesses, + all_descriptors, + all_extra_scores, + random_key, + ) = multi_sample_scoring_function( + policies_params, random_key, scoring_fn, num_samples ) # average results diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py new file mode 100644 index 00000000..2fb1bd76 --- /dev/null +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -0,0 +1,236 @@ +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mels_repertoire import MELSRepertoire +from qdax.types import ExtraScores + + +def test_add_to_mels_repertoire() -> None: + """Test several additions to the MELSRepertoire, including adding a solution + and overwriting it by adding multiple solutions.""" + genotype_size = 12 + num_centroids = 4 + num_descriptors = 2 + + # create a repertoire instance + repertoire = MELSRepertoire( + genotypes=jnp.zeros(shape=(num_centroids, genotype_size)), + fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf), + descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)), + centroids=jnp.array( + [ + [1.0, 1.0], + [2.0, 1.0], + [2.0, 2.0], + [1.0, 2.0], + ] + ), + spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf), + ) + + # + # Test 1: Insert a single solution. + # + + # create fake genotypes and scores to add + fake_genotypes = jnp.ones(shape=(1, genotype_size)) + # each solution gets two fitnesses and two descriptors + fake_fitnesses = jnp.array([[0.0, 0.0]]) + fake_descriptors = jnp.array([[[0.0, 1.0], [1.0, 1.0]]]) + fake_extra_scores: ExtraScores = {} + + # do an addition + repertoire = repertoire.add( + fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores + ) + + # check that the repertoire looks as expected + expected_genotypes = jnp.array( + [ + [1.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + [0.0 for _ in range(genotype_size)], + ] + ) + expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf]) + expected_descriptors = jnp.array( + [ + [1.0, 1.0], # Centroid coordinates. + [0.0, 0.0], + [0.0, 0.0], + [0.0, 0.0], + ] + ) + expected_spreads = jnp.array([1.0, jnp.inf, jnp.inf, jnp.inf]) + + # check values + pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6)) + pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6)) + + # + # Test 2: Adding solutions into the same cell as above. 
+    #
+
+    # create fake genotypes and scores to add
+    fake_genotypes = jnp.concatenate(
+        (
+            jnp.full(shape=(1, genotype_size), fill_value=2.0),
+            jnp.full(shape=(1, genotype_size), fill_value=3.0),
+        ),
+        axis=0,
+    )
+    # Each solution gets two fitnesses and two descriptors (i.e., num_samples = 2).
+    # One solution has fitness 1.0 and spread 0.75, while the other has fitness 0.5
+    # and spread 0.5. Thus, neither solution dominates the other (by having both
+    # higher fitness and lower spread). However, both solutions would be valid
+    # candidates for the archive due to dominating the current solution there.
+    fake_fitnesses = jnp.array([[1.0, 1.0], [0.5, 0.5]])
+    fake_descriptors = jnp.array([[[1.0, 0.25], [1.0, 1.0]], [[1.0, 0.5], [1.0, 1.0]]])
+    fake_extra_scores: ExtraScores = {}
+
+    # do an addition
+    repertoire = repertoire.add(
+        fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores
+    )
+
+    # Either solution may be added due to the behavior of jax.at[].set():
+    # https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
+    # Thus, we provide expected values for both scenarios.
+
+    # check that the repertoire looks as expected
+    expected_genotypes_1 = jnp.array(
+        [
+            [2.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+        ]
+    )
+    expected_fitnesses_1 = jnp.array([1.0, -jnp.inf, -jnp.inf, -jnp.inf])
+    expected_descriptors_1 = jnp.array(
+        [
+            [1.0, 1.0],  # Centroid coordinates.
+            [0.0, 0.0],
+            [0.0, 0.0],
+            [0.0, 0.0],
+        ]
+    )
+    expected_spreads_1 = jnp.array([0.75, jnp.inf, jnp.inf, jnp.inf])
+
+    expected_genotypes_2 = jnp.array(
+        [
+            [3.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+        ]
+    )
+    expected_fitnesses_2 = jnp.array([0.5, -jnp.inf, -jnp.inf, -jnp.inf])
+    expected_descriptors_2 = jnp.array(
+        [
+            [1.0, 1.0],  # Centroid coordinates.
+            [0.0, 0.0],
+            [0.0, 0.0],
+            [0.0, 0.0],
+        ]
+    )
+    expected_spreads_2 = jnp.array([0.5, jnp.inf, jnp.inf, jnp.inf])
+
+    # check values
+    pytest.assume(
+        jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6)
+        or jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6)
+    )
+
+    if jnp.allclose(repertoire.genotypes, expected_genotypes_1, atol=1e-6):
+        pytest.assume(
+            jnp.allclose(repertoire.fitnesses, expected_fitnesses_1, atol=1e-6)
+        )
+        pytest.assume(
+            jnp.allclose(repertoire.descriptors, expected_descriptors_1, atol=1e-6)
+        )
+        pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_1, atol=1e-6))
+    elif jnp.allclose(repertoire.genotypes, expected_genotypes_2, atol=1e-6):
+        pytest.assume(
+            jnp.allclose(repertoire.fitnesses, expected_fitnesses_2, atol=1e-6)
+        )
+        pytest.assume(
+            jnp.allclose(repertoire.descriptors, expected_descriptors_2, atol=1e-6)
+        )
+        pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads_2, atol=1e-6))
+
+
+def test_add_with_single_eval() -> None:
+    """Tries adding with a single evaluation.
+
+    This is a special case because the spread defaults to 0. 
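+
+    With a single sample there are no pairwise descriptor distances, so add()
+    sets the spread to 0 rather than leaving it undefined.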
+    """
+    genotype_size = 12
+    num_centroids = 4
+    num_descriptors = 2
+
+    # create a repertoire instance
+    repertoire = MELSRepertoire(
+        genotypes=jnp.zeros(shape=(num_centroids, genotype_size)),
+        fitnesses=jnp.ones(shape=(num_centroids,)) * (-jnp.inf),
+        descriptors=jnp.zeros(shape=(num_centroids, num_descriptors)),
+        centroids=jnp.array(
+            [
+                [1.0, 1.0],
+                [2.0, 1.0],
+                [2.0, 2.0],
+                [1.0, 2.0],
+            ]
+        ),
+        spreads=jnp.full(shape=(num_centroids,), fill_value=jnp.inf),
+    )
+
+    # Insert a single solution with only one eval.
+
+    # create fake genotypes and scores to add
+    fake_genotypes = jnp.ones(shape=(1, genotype_size))
+    # the solution gets one fitness and one descriptor.
+    fake_fitnesses = jnp.array([[0.0]])
+    fake_descriptors = jnp.array([[[0.0, 1.0]]])
+    fake_extra_scores: ExtraScores = {}
+
+    # do an addition
+    repertoire = repertoire.add(
+        fake_genotypes, fake_descriptors, fake_fitnesses, fake_extra_scores
+    )
+
+    # check that the repertoire looks as expected
+    expected_genotypes = jnp.array(
+        [
+            [1.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+            [0.0 for _ in range(genotype_size)],
+        ]
+    )
+    expected_fitnesses = jnp.array([0.0, -jnp.inf, -jnp.inf, -jnp.inf])
+    expected_descriptors = jnp.array(
+        [
+            [1.0, 1.0],  # Centroid coordinates.
+            [0.0, 0.0],
+            [0.0, 0.0],
+            [0.0, 0.0],
+        ]
+    )
+    # Spread should be 0 since there's only one eval.
+    expected_spreads = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf])
+
+    # check values
+    pytest.assume(jnp.allclose(repertoire.genotypes, expected_genotypes, atol=1e-6))
+    pytest.assume(jnp.allclose(repertoire.fitnesses, expected_fitnesses, atol=1e-6))
+    pytest.assume(jnp.allclose(repertoire.descriptors, expected_descriptors, atol=1e-6))
+    pytest.assume(jnp.allclose(repertoire.spreads, expected_spreads, atol=1e-6))
diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py
new file mode 100644
index 00000000..21f90517
--- /dev/null
+++ b/tests/core_test/mels_test.py
@@ -0,0 +1,156 @@
+"""Tests MAP-Elites Low-Spread implementation."""
+
+import functools
+from typing import Dict, Tuple
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+from qdax import environments
+from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
+from qdax.core.containers.mels_repertoire import MELSRepertoire
+from qdax.core.emitters.mutation_operators import isoline_variation
+from qdax.core.emitters.standard_emitters import MixingEmitter
+from qdax.core.mels import MELS
+from qdax.core.neuroevolution.buffers.buffer import QDTransition
+from qdax.core.neuroevolution.networks.networks import MLP
+from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
+from qdax.types import EnvState, Params, RNGKey
+
+
+@pytest.mark.parametrize(
+    "env_name, batch_size",
+    [("walker2d_uni", 1), ("walker2d_uni", 10), ("hopper_uni", 10)],
+)
+def test_mels(env_name: str, batch_size: int) -> None:
+    num_samples = 5
+    episode_length = 100
+    num_iterations = 5
+    seed = 42
+    policy_hidden_layer_sizes = (64, 64)
+    num_init_cvt_samples = 1000
+    num_centroids = 50
+    min_bd = 0.0
+    max_bd = 1.0
+
+    # Init environment
+    env = environments.create(env_name, episode_length=episode_length)
+
+    # Init a random key
+    random_key = jax.random.PRNGKey(seed)
+
+    # Init policy network
+    policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
+    policy_network = MLP(
+        layer_sizes=policy_layer_sizes,
+        
kernel_init=jax.nn.initializers.lecun_uniform(),
+        final_activation=jnp.tanh,
+    )
+
+    # Init population of controllers. There are batch_size controllers, and each
+    # controller will be evaluated num_samples times.
+    random_key, subkey = jax.random.split(random_key)
+    keys = jax.random.split(subkey, num=batch_size)
+    fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
+    init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
+
+    # Define the function to play a step with the policy in the environment
+    def play_step_fn(
+        env_state: EnvState,
+        policy_params: Params,
+        random_key: RNGKey,
+    ) -> Tuple[EnvState, Params, RNGKey, QDTransition]:
+        """Play an environment step and return the updated state and the
+        transition."""
+
+        actions = policy_network.apply(policy_params, env_state.obs)
+
+        state_desc = env_state.info["state_descriptor"]
+        next_state = env.step(env_state, actions)
+
+        transition = QDTransition(
+            obs=env_state.obs,
+            next_obs=next_state.obs,
+            rewards=next_state.reward,
+            dones=next_state.done,
+            actions=actions,
+            truncations=next_state.info["truncation"],
+            state_desc=state_desc,
+            next_state_desc=next_state.info["state_descriptor"],
+        )
+
+        return next_state, policy_params, random_key, transition
+
+    # Prepare the scoring function
+    bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
+    scoring_fn = functools.partial(
+        reset_based_scoring_function_brax_envs,
+        episode_length=episode_length,
+        play_reset_fn=env.reset,
+        play_step_fn=play_step_fn,
+        behavior_descriptor_extractor=bd_extraction_fn,
+    )
+
+    # Define emitter
+    variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1)
+    mixing_emitter = MixingEmitter(
+        mutation_fn=lambda x, y: (x, y),
+        variation_fn=variation_fn,
+        variation_percentage=1.0,
+        batch_size=batch_size,
+    )
+
+    # Get the minimum reward value to make sure the QD score is positive
+    reward_offset = environments.reward_offset[env_name]
+
+    # Define a metrics function
+    def metrics_fn(repertoire: MELSRepertoire) -> Dict:
+        # Get metrics
+        grid_empty = repertoire.fitnesses == -jnp.inf
+        qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)
+        # Add offset for positive qd_score
+        qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)
+        coverage = 100 * jnp.mean(1.0 - grid_empty)
+        max_fitness = jnp.max(repertoire.fitnesses)
+
+        return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}
+
+    # Instantiate ME-LS.
+    mels = MELS(
+        scoring_function=scoring_fn,
+        emitter=mixing_emitter,
+        metrics_function=metrics_fn,
+        num_samples=num_samples,
+    )
+
+    # Compute the centroids
+    centroids, random_key = compute_cvt_centroids(
+        num_descriptors=env.behavior_descriptor_length,
+        num_init_cvt_samples=num_init_cvt_samples,
+        num_centroids=num_centroids,
+        minval=min_bd,
+        maxval=max_bd,
+        random_key=random_key,
+    )
+
+    # Compute initial repertoire
+    repertoire, emitter_state, random_key = mels.init(
+        init_variables, centroids, random_key
+    )
+
+    # Run the algorithm
+    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
+        mels.scan_update,
+        (repertoire, emitter_state, random_key),
+        (),
+        length=num_iterations,
+    )
+
+    pytest.assume(repertoire is not None)
+
+
+if __name__ == "__main__":
+    test_mels(env_name="walker2d_uni", batch_size=10)