diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9329a64..1414d749 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 24.8.0 hooks: - id: black - language_version: python3.9 - args: ["--target-version", "py39"] + language_version: python3.10 + args: ["--target-version", "py310"] - repo: https://github.com/PyCQA/flake8 - rev: 3.8.4 + rev: 7.1.1 hooks: - id: flake8 args: ['--max-line-length=88', '--extend-ignore=E203'] @@ -21,12 +21,12 @@ repos: - flake8-comprehensions - flake8-bugbear - repo: https://github.com/kynan/nbstripout - rev: 0.3.9 + rev: 0.7.1 hooks: - id: nbstripout args: ["examples/"] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 + rev: v4.6.0 hooks: - id: debug-statements - id: requirements-txt-fixer @@ -42,6 +42,6 @@ repos: - id: trailing-whitespace # This hook trims trailing whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 + rev: v1.11.2 hooks: - id: mypy diff --git a/.readthedocs.yaml b/.readthedocs.yaml index d9f0965b..e22967ae 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.9" + python: "3.10" apt_packages: - swig diff --git a/README.md b/README.md index 551680eb..e7955450 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,12 @@ QDax is available on PyPI and can be installed with: ```bash pip install qdax ``` + +To install QDax with CUDA 12 support, use: +```bash +pip install qdax[cuda12] +``` + Alternatively, the latest commit of QDax can be installed directly from source with: ```bash pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main @@ -129,6 +135,7 @@ QDax currently supports the following algorithms: | [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | | [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) | | [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) | +| [DCRL-ME](https://arxiv.org/abs/2401.08632) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb) | | [QDPG](https://arxiv.org/abs/2006.08505) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/qdpg.ipynb) | | [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) | | [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All 
Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) | @@ -166,13 +173,14 @@ Issues and contributions are welcome. Please refer to the [contribution guide](h ## Citing QDax If you use QDax in your research and want to cite it in your work, please use: ``` -@misc{chalumeau2023qdax, - title={QDax: A Library for Quality-Diversity and Population-based Algorithms with Hardware Acceleration}, - author={Felix Chalumeau and Bryan Lim and Raphael Boige and Maxime Allard and Luca Grillotti and Manon Flageat and Valentin Macé and Arthur Flajolet and Thomas Pierrot and Antoine Cully}, - year={2023}, - eprint={2308.03665}, - archivePrefix={arXiv}, - primaryClass={cs.AI} +@article{chalumeau2024qdax, + title={Qdax: A library for quality-diversity and population-based algorithms with hardware acceleration}, + author={Chalumeau, Felix and Lim, Bryan and Boige, Raphael and Allard, Maxime and Grillotti, Luca and Flageat, Manon and Mac{\'e}, Valentin and Richard, Guillaume and Flajolet, Arthur and Pierrot, Thomas and others}, + journal={Journal of Machine Learning Research}, + volume={25}, + number={108}, + pages={1--16}, + year={2024} } ``` diff --git a/dev.Dockerfile b/dev.Dockerfile index 458599db..305e29e0 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -16,7 +16,7 @@ RUN micromamba create -y --file /tmp/environment.yaml \ FROM python as test-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH COPY --from=conda /opt/conda/envs/. /opt/conda/envs/ @@ -26,7 +26,7 @@ RUN pip install -r requirements-dev.txt FROM nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04 as cuda-image -ENV PATH=/opt/conda/envs/qdaxpy39/bin/:$PATH APP_FOLDER=/app +ENV PATH=/opt/conda/envs/qdaxpy310/bin/:$PATH APP_FOLDER=/app ENV PYTHONPATH=$APP_FOLDER:$PYTHONPATH @@ -70,7 +70,7 @@ RUN apt-get update && \ libosmesa6-dev \ patchelf \ python3-opengl \ - python3-dev=3.9* \ + python3-dev=3.10* \ python3-pip \ screen \ sudo \ diff --git a/docs/api_documentation/core/dcrlme.md b/docs/api_documentation/core/dcrlme.md new file mode 100644 index 00000000..698d75eb --- /dev/null +++ b/docs/api_documentation/core/dcrlme.md @@ -0,0 +1,5 @@ +# Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME) + +To create an instance of DCRL-ME, one needs to use an instance of [MAP-Elites](map_elites.md) with the `DCRLMEEmitter`, detailed below. + +::: qdax.core.emitters.dcrl_me_emitter.DCRLMEEmitter diff --git a/docs/api_documentation/core/map_elites.md b/docs/api_documentation/core/map_elites.md index 5da3af58..89f1948a 100644 --- a/docs/api_documentation/core/map_elites.md +++ b/docs/api_documentation/core/map_elites.md @@ -2,7 +2,7 @@ This class implement the base mechanism of MAP-Elites. It must be used with an emitter. To get the usual MAP-Elites algorithm, one must use the [mixing emitter](emitters.md#qdax.core.emitters.standard_emitters.MixingEmitter). -The MAP-Elites class can be used with other emitters to create variants, like [PGAME](pgame.md), [CMA-MEGA](cma_mega.md) and [OMG-MEGA](omg_mega.md). +The MAP-Elites class can be used with other emitters to create variants, like [PGAME](pgame.md), [DCRL-ME](dcrlme.md), [CMA-MEGA](cma_mega.md) and [OMG-MEGA](omg_mega.md). 
::: qdax.core.map_elites.MAPElites diff --git a/docs/installation.md b/docs/installation.md index 90c62659..585af828 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -86,7 +86,7 @@ git clone git@github.com:adaptive-intelligent-robotics/QDax.git 2. Activate the environment and manually install the package qdax ```zsh - conda activate qdaxpy39 + conda activate qdaxpy310 pip install -e . ``` diff --git a/environment.yaml b/environment.yaml index 0ddf80d5..d93726af 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,9 +1,9 @@ -name: qdaxpy39 +name: qdaxpy310 channels: - defaults - conda-forge dependencies: -- python=3.9 +- python=3.10 - pip>=20.3.3 - conda>=4.9.2 - pip: diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index e4b86238..fb955a98 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -93,7 +93,7 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", "\n", - "from qdax.types import Observation\n", + "from qdax.custom_types import Observation\n", "from qdax.utils import train_seq2seq\n", "\n", "\n", @@ -512,11 +512,8 @@ } ], "metadata": { - "interpreter": { - "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -530,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index c8e2a9fe..8b87473d 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "5c4ab97a", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmaes.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "222bbe00", + "id": "1", "metadata": {}, "source": [ "# Optimizing with CMA-ES in Jax\n", @@ -26,7 +26,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d731f067", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -36,19 +36,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " 
import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -71,7 +71,7 @@ }, { "cell_type": "markdown", - "id": "7b6e910b", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -80,7 +80,7 @@ { "cell_type": "code", "execution_count": null, - "id": "404fb0dc", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "markdown", - "id": "ccc7cbeb", + "id": "5", "metadata": { "pycharm": { "name": "#%% md\n" @@ -111,7 +111,7 @@ { "cell_type": "code", "execution_count": null, - "id": "436dccbb", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +133,7 @@ }, { "cell_type": "markdown", - "id": "62bdd2a4", + "id": "7", "metadata": { "pycharm": { "name": "#%% md\n" @@ -146,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4cf03f55", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -167,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "f1f69f50", + "id": "9", "metadata": { "pycharm": { "name": "#%% md\n" @@ -180,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1a95b74d", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "ac2d5c0d", + "id": "11", "metadata": { "pycharm": { "name": "#%% md\n" @@ -207,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "363198ca", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -245,7 +245,7 @@ }, { "cell_type": "markdown", - "id": "0e5820b8", + "id": "13", "metadata": {}, "source": [ "## Check final fitnesses and distribution mean" @@ -254,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e4a2c7b", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +272,7 @@ }, { "cell_type": "markdown", - "id": "f3bd2b0f", + "id": "15", "metadata": { "pycharm": { "name": "#%% md\n" @@ -285,7 +285,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ad85551c", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -333,7 +333,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 3c355eea..7da832eb 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -41,19 +41,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -141,9 +141,9 @@ "def clip(x: jnp.ndarray):\n", " in_bound = (x <= maxval) * (x >= minval)\n", " return jnp.where(\n", - " condition=in_bound,\n", - " x=x,\n", - " y=(maxval / x)\n", + " in_bound,\n", + " x,\n", + " (maxval / x)\n", " )\n", "\n", "def _behavior_descriptor_1(x: 
jnp.ndarray):\n", @@ -387,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index e5749993..a90f8309 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -35,19 +35,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -315,7 +315,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/dads.ipynb b/examples/dads.ipynb index b3cc43b5..19348348 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -554,7 +548,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb new file mode 100644 index 00000000..5c367e37 --- /dev/null +++ b/examples/dcrlme.ipynb @@ -0,0 +1,480 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Optimizing with DCRL-ME in Jax\n", + "\n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632), also known as *Descriptor-Conditioned Gradients MAP-Elites with Actor Injection (DCG-ME-AI)*. 
\n", + "This algorithm extends and improves upon [Descriptor-Conditioned Gradients MAP-Elites (DCG-ME)](https://dl.acm.org/doi/abs/10.1145/3583131.3590503)\n", + "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", + "\n", + "- how to define the problem\n", + "- how to create the DCRL emitter\n", + "- how to create a Map-elites instance\n", + "- which functions must be defined before training\n", + "- how to launch a certain number of training steps\n", + "- how to visualize the results of the training process" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Installs and Imports\n", + "!pip install ipympl |tail -n 1\n", + "# %matplotlib widget\n", + "# from google.colab import output\n", + "# output.enable_custom_widget_manager()\n", + "\n", + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "import functools\n", + "import time\n", + "from typing import Any, Tuple\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import pytest\n", + "from brax.envs import Env, State, Wrapper\n", + "\n", + "try:\n", + " import brax\n", + "except:\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", + " import brax\n", + "\n", + "try:\n", + " import flax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", + " import flax\n", + "\n", + "try:\n", + " import chex\n", + "except:\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", + " import chex\n", + "\n", + "try:\n", + " import jumanji\n", + "except:\n", + " !pip install \"jumanji==0.3.1\"\n", + " import jumanji\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " import qdax\n", + "\n", + "\n", + "from qdax import environments\n", + "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", + "from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter\n", + "from qdax.core.emitters.mutation_operators import isoline_variation\n", + "from qdax.core.map_elites import MAPElites\n", + "from qdax.core.neuroevolution.buffers.buffer import DCRLTransition\n", + "from qdax.core.neuroevolution.networks.networks import MLP, MLPDC\n", + "from qdax.custom_types import EnvState, Params, RNGKey\n", + "from qdax.environments import behavior_descriptor_extractor\n", + "from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper\n", + "from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs\n", + "from qdax.utils.plotting import plot_map_elites_results\n", + "\n", + "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", + "\n", + "\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " from jax.tools import colab_tpu\n", + " colab_tpu.setup_tpu()\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title QD Training Definitions Fields\n", + "seed = 42 #@param {type:\"integer\"}\n", + "\n", + "env_name = \"ant_omni\" #@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", + "episode_length = 250 #@param {type:\"integer\"}\n", + "min_bd = -30.0 #@param {type:\"number\"}\n", + "max_bd = 30.0 #@param {type:\"number\"}\n", + "\n", + "num_iterations = 1000 
#@param {type:\"integer\"}\n", + "batch_size = 256 #@param {type:\"integer\"}\n", + "\n", + "# Archive\n", + "num_init_cvt_samples = 50000 #@param {type:\"integer\"}\n", + "num_centroids = 1024 #@param {type:\"integer\"}\n", + "policy_hidden_layer_sizes = (128, 128) #@param {type:\"raw\"}\n", + "\n", + "# DCRL-ME\n", + "ga_batch_size = 128 #@param {type:\"integer\"}\n", + "dcrl_batch_size = 64 #@param {type:\"integer\"}\n", + "ai_batch_size = 64 #@param {type:\"integer\"}\n", + "lengthscale = 0.1 #@param {type:\"number\"}\n", + "\n", + "# GA emitter\n", + "iso_sigma = 0.005 #@param {type:\"number\"}\n", + "line_sigma = 0.05 #@param {type:\"number\"}\n", + "\n", + "# DCRL emitter\n", + "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", + "num_critic_training_steps = 3000\n", + "num_pg_training_steps = 150\n", + "replay_buffer_size = 1_000_000\n", + "discount = 0.99 #@param {type:\"number\"}\n", + "reward_scaling = 1.0 #@param {type:\"number\"}\n", + "critic_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "actor_learning_rate = 3e-4 #@param {type:\"number\"}\n", + "policy_learning_rate = 5e-3 #@param {type:\"number\"}\n", + "noise_clip = 0.5 #@param {type:\"number\"}\n", + "policy_noise = 0.2 #@param {type:\"number\"}\n", + "soft_tau_update = 0.005 #@param {type:\"number\"}\n", + "policy_delay = 2 #@param {type:\"number\"}\n", + "#@markdown ---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Init environment, policy, population params, init states of the env\n", + "\n", + "Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy, that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, that corresponds to their genotype." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Init a random key\n", + "random_key = jax.random.PRNGKey(seed)\n", + "\n", + "# Init environment\n", + "env = environments.create(env_name, episode_length=episode_length)\n", + "env = OffsetRewardWrapper(\n", + " env, offset=environments.reward_offset[env_name]\n", + ") # apply reward offset as DCRL needs positive rewards\n", + "env = ClipRewardWrapper(\n", + " env, clip_min=0.,\n", + ") # apply reward clip as DCRL needs positive rewards\n", + "\n", + "reset_fn = jax.jit(env.reset)\n", + "\n", + "# Init policy network\n", + "policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n", + "policy_network = MLP(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "actor_dc_network = MLPDC(\n", + " layer_sizes=policy_layer_sizes,\n", + " kernel_init=jax.nn.initializers.lecun_uniform(),\n", + " final_activation=jnp.tanh,\n", + ")\n", + "\n", + "# Init population of controllers\n", + "random_key, subkey = jax.random.split(random_key)\n", + "keys = jax.random.split(subkey, num=batch_size)\n", + "fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))\n", + "init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the way the policy interacts with the env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the function to play a step with the policy in the environment\n", + "def play_step_fn(\n", + " env_state: EnvState, policy_params: Params, random_key: RNGKey\n", + ") -> Tuple[EnvState, Params, RNGKey, DCRLTransition]:\n", + " actions = policy_network.apply(policy_params, env_state.obs)\n", + " state_desc = env_state.info[\"state_descriptor\"]\n", + " next_state = env.step(env_state, actions)\n", + "\n", + " transition = DCRLTransition(\n", + " obs=env_state.obs,\n", + " next_obs=next_state.obs,\n", + " rewards=next_state.reward,\n", + " dones=next_state.done,\n", + " truncations=next_state.info[\"truncation\"],\n", + " actions=actions,\n", + " state_desc=state_desc,\n", + " next_state_desc=next_state.info[\"state_descriptor\"],\n", + " desc=jnp.zeros(\n", + " env.behavior_descriptor_length,\n", + " )\n", + " * jnp.nan,\n", + " desc_prime=jnp.zeros(\n", + " env.behavior_descriptor_length,\n", + " )\n", + " * jnp.nan,\n", + " )\n", + "\n", + " return next_state, policy_params, random_key, transition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the scoring function and the way metrics are computed\n", + "\n", + "The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the scoring function\n", + "bd_extraction_fn = behavior_descriptor_extractor[env_name]\n", + "scoring_fn = functools.partial(\n", + " reset_based_scoring_function_brax_envs,\n", + " episode_length=episode_length,\n", + " play_reset_fn=reset_fn,\n", + " play_step_fn=play_step_fn,\n", + " behavior_descriptor_extractor=bd_extraction_fn,\n", + ")\n", + "\n", + "# Get minimum reward value to make sure qd_score are positive\n", + "reward_offset = environments.reward_offset[env_name]\n", + "\n", + "# Define a metrics function\n", + "metrics_function = functools.partial(\n", + " default_qd_metrics,\n", + " qd_offset=reward_offset * episode_length,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the emitter: DCRL Emitter\n", + "\n", + "The emitter is used to evolve the population at each mutation step. In this example, the emitter is the Descriptor-Conditioned RL emitter, the one used in DCRL-ME. It trains a critic with the transitions experienced in the environment and uses the critic to apply Descriptor-Conditioned gradients updates to the policies evolved." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dcrl_emitter_config = DCRLMEConfig(\n", + " ga_batch_size=ga_batch_size,\n", + " dcrl_batch_size=dcrl_batch_size,\n", + " ai_batch_size=ai_batch_size,\n", + " lengthscale=lengthscale,\n", + " critic_hidden_layer_size=critic_hidden_layer_size,\n", + " num_critic_training_steps=num_critic_training_steps,\n", + " num_pg_training_steps=num_pg_training_steps,\n", + " batch_size=batch_size,\n", + " replay_buffer_size=replay_buffer_size,\n", + " discount=discount,\n", + " reward_scaling=reward_scaling,\n", + " critic_learning_rate=critic_learning_rate,\n", + " actor_learning_rate=actor_learning_rate,\n", + " policy_learning_rate=policy_learning_rate,\n", + " noise_clip=noise_clip,\n", + " policy_noise=policy_noise,\n", + " soft_tau_update=soft_tau_update,\n", + " policy_delay=policy_delay,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the emitter\n", + "variation_fn = functools.partial(\n", + " isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma\n", + ")\n", + "\n", + "dcrl_emitter = DCRLMEEmitter(\n", + " config=dcrl_emitter_config,\n", + " policy_network=policy_network,\n", + " actor_network=actor_dc_network,\n", + " env=env,\n", + " variation_fn=variation_fn,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiate and initialise the MAP Elites algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate MAP Elites\n", + "map_elites = MAPElites(\n", + " scoring_function=scoring_fn,\n", + " emitter=dcrl_emitter,\n", + " metrics_function=metrics_function,\n", + ")\n", + "\n", + "# Compute the centroids\n", + "centroids, random_key = compute_cvt_centroids(\n", + " num_descriptors=env.behavior_descriptor_length,\n", + " num_init_cvt_samples=num_init_cvt_samples,\n", + " num_centroids=num_centroids,\n", + " minval=min_bd,\n", + " maxval=max_bd,\n", + " random_key=random_key,\n", + ")\n", + "\n", + "# compute initial repertoire\n", + "repertoire, emitter_state, random_key = map_elites.init(\n", + " init_params, centroids, random_key\n", + ")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def update_scan_fn(carry: Any, unused: Any) -> Any:\n", + " # iterate over grid\n", + " repertoire, emitter_state, metrics, random_key = map_elites.update(*carry)\n", + "\n", + " return (repertoire, emitter_state, random_key), metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "log_period = 10\n", + "num_loops = int(num_iterations / log_period)\n", + "\n", + "csv_logger = CSVLogger(\n", + " \"dcrlme-logs.csv\",\n", + " header=[\"loop\", \"iteration\", \"qd_score\", \"max_fitness\", \"coverage\", \"time\"]\n", + ")\n", + "all_metrics = {}\n", + "\n", + "# main loop\n", + "map_elites_scan_update = map_elites.scan_update\n", + "for i in range(num_loops):\n", + " start_time = time.time()\n", + " # main iterations\n", + " (\n", + " repertoire,\n", + " emitter_state,\n", + " random_key,\n", + " ), metrics = jax.lax.scan(\n", + " update_scan_fn,\n", + " (repertoire, emitter_state, random_key),\n", + " (),\n", + " length=log_period,\n", + " )\n", + " timelapse = time.time() - start_time\n", + "\n", + " # log metrics\n", + " logged_metrics = {\"time\": timelapse, \"loop\": 1+i, \"iteration\": 1 + i*log_period}\n", + " for key, value in metrics.items():\n", + " # take last value\n", + " logged_metrics[key] = value[-1]\n", + "\n", + " # take all values\n", + " if key in all_metrics.keys():\n", + " all_metrics[key] = jnp.concatenate([all_metrics[key], value])\n", + " else:\n", + " all_metrics[key] = value\n", + "\n", + " csv_logger.log(logged_metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "# create the x-axis array\n", + "env_steps = jnp.arange(740) * episode_length * batch_size\n", + "\n", + "%matplotlib inline\n", + "# create the plots and the grid\n", + "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "vscode": { + "interpreter": { + "hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index 0562e7c2..fba2055f 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", 
@@ -67,12 +67,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -544,7 +538,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.12" }, "vscode": { "interpreter": { diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 2e6fd991..0fe1094c 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -40,7 +40,12 @@ "from IPython.display import clear_output\n", "import functools\n", "\n", - "from tqdm import tqdm\n", + "try:\n", + " from tqdm import tqdm\n", + "except:\n", + " !pip install tqdm | tail -n 1\n", + " from tqdm import tqdm\n", + "\n", "import time\n", "\n", "import jax\n", @@ -49,19 +54,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -128,8 +133,7 @@ "outputs": [], "source": [ "# Get devices (change gpu by tpu if needed)\n", - "# devices = jax.devices('gpu')\n", - "devices = jax.devices('tpu')\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f'Detected the following {num_devices} device(s): {devices}')" ] @@ -351,7 +355,7 @@ "random_key = jnp.stack(random_key)\n", "\n", "# add a dimension for devices\n", - "init_variables = jax.tree_map(\n", + "init_variables = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (num_devices, batch_size_per_device,) + x.shape[1:]),\n", " init_variables\n", ")\n", @@ -360,7 +364,7 @@ "repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n", " centroids=centroids,\n", " devices=devices,\n", - ")(init_genotypes=init_variables, random_key=random_key)" + ")(genotypes=init_variables, random_key=random_key)" ] }, { @@ -397,7 +401,7 @@ " repertoire, emitter_state, random_key, metrics = update_fn(repertoire, emitter_state, random_key)\n", "\n", " # get metrics\n", - " metrics = jax.tree_map(lambda x: x[0], metrics)\n", + " metrics = jax.tree_util.tree_map(lambda x: x[0], metrics)\n", " timelapse = time.time() - start_time\n", "\n", " # log metrics\n", @@ -454,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index a6a140fd..dc915524 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "233e0f03", + "id": "0", "metadata": {}, "source": [ "# Training a population on Jumanji-Snake with QDax\n", @@ -15,7 +15,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "47b46c2f", + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -28,19 +28,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -72,13 +72,13 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "\n", "from qdax.core.emitters.standard_emitters import MixingEmitter\n", - "from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor\n", + "from qdax.custom_types import ExtraScores, Fitness, RNGKey, Descriptor\n", "from qdax.utils.metrics import default_ga_metrics, default_qd_metrics" ] }, { "cell_type": "markdown", - "id": "03c2f1f7", + "id": "2", "metadata": {}, "source": [ "## Define hyperparameters" @@ -87,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52dd1e3b", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ "population_size = 100\n", "batch_size = population_size\n", "\n", - "num_iterations = 5000\n", + "num_iterations = 1000\n", "\n", "iso_sigma = 0.005\n", "line_sigma = 0.05" @@ -105,7 +105,7 @@ }, { "cell_type": "markdown", - "id": "8b8c890a", + "id": "4", "metadata": {}, "source": [ "## Instantiate the snake environment" @@ -114,7 +114,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a842cccc", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "markdown", - "id": "776862f1", + "id": "6", "metadata": {}, "source": [ "## Define the type of policy that will be used to solve the problem" @@ -141,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a1ce7d0", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "49586b07", + "id": "8", "metadata": {}, "source": [ "## Utils to interact with the environment\n", @@ -172,12 +172,12 @@ { "cell_type": "code", "execution_count": null, - "id": "d1ff7827", + "id": "9", "metadata": {}, "outputs": [], "source": [ "def observation_processing(observation):\n", - " network_input = jnp.ravel(observation)\n", + " network_input = jnp.concatenate([jnp.ravel(observation.grid), jnp.array([observation.step_count]), observation.action_mask.ravel()])\n", " return network_input\n", "\n", "\n", @@ -207,7 +207,7 @@ " obs=timestep.observation,\n", " next_obs=next_timestep.observation,\n", " rewards=next_timestep.reward,\n", - " dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),\n", + " dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),\n", " actions=action,\n", " truncations=jnp.array(0),\n", " state_desc=state_desc,\n", @@ -219,7 +219,7 @@ }, { "cell_type": "markdown", - "id": "0078bc01", + "id": "10", "metadata": {}, "source": [ "## Init a population of policies\n", @@ -230,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6cbd2065", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -240,7 
+240,7 @@ "\n", "# compute observation size from observation spec\n", "obs_spec = env.observation_spec()\n", - "observation_size = np.prod(np.array(obs_spec.grid.shape + obs_spec.step_count.shape + obs_spec.action_mask.shape))\n", + "observation_size = int(np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape))\n", "\n", "fake_batch = jnp.zeros(shape=(batch_size, observation_size))\n", "init_variables = jax.vmap(policy_network.init)(keys, fake_batch)\n", @@ -255,7 +255,7 @@ }, { "cell_type": "markdown", - "id": "fe6bf07f", + "id": "12", "metadata": {}, "source": [ "## Define a method to extract behavior descriptor when relevant" @@ -264,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a264b672", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +311,7 @@ }, { "cell_type": "markdown", - "id": "1cdc5f87", + "id": "14", "metadata": {}, "source": [ "## Define the scoring function" @@ -320,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b77d826", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "6555491a", + "id": "16", "metadata": {}, "source": [ "## Define the emitter used" @@ -342,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30061ff4", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "da7e9b74", + "id": "18", "metadata": {}, "source": [ "## Define the algorithm used and apply the initial step\n", @@ -371,7 +371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7b5c2d6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "9b1bfee5", + "id": "20", "metadata": {}, "source": [ "## Run the optimization loop" @@ -424,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1af3a35", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "114ea4a8", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92a35bf0", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79ada2d5", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +472,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe5da301", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +489,7 @@ }, { "cell_type": "markdown", - "id": "93d8154e", + "id": "26", "metadata": {}, "source": [ "## Play snake with the best policy\n", @@ -500,7 +500,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3ff882f4", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -511,7 +511,7 @@ { "cell_type": "code", "execution_count": null, - "id": "762c167e", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -524,7 +524,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07523e33", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +537,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c75ce088", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -550,7 +550,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50ef95f6", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +563,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"40a03409", + "id": "32", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index b1fea651..575ee0c0 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index c387643a..778a7a5f 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -16,19 +16,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -38,12 +38,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -60,7 +54,7 @@ "from qdax.core.distributed_map_elites import DistributedMAPElites\n", "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn\n", - "from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", + "from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey\n", "from qdax.utils.metrics import CSVLogger, default_qd_metrics\n", "from qdax.utils.plotting import plot_map_elites_results" ] @@ -80,7 +74,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -261,7 +256,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { @@ -348,7 +343,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - 
")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 924d550f..da8d7311 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -17,19 +17,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -39,12 +39,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,7 +56,7 @@ "from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig\n", "from qdax.core.emitters.pbt_variation_operators import td3_pbt_variation_fn\n", "from qdax.core.distributed_map_elites import DistributedMAPElites\n", - "from qdax.types import RNGKey\n", + "from qdax.custom_types import RNGKey\n", "from qdax.utils.metrics import default_qd_metrics\n", "from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_map_elites_results" ] @@ -82,7 +76,8 @@ "metadata": {}, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -264,7 +259,7 @@ " lambda x: jnp.repeat(x, population_size, axis=0), first_states\n", " )\n", " population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)\n", - " return population_returns, population_bds, None, random_key" + " return population_returns, population_bds, {}, random_key" ] }, { @@ -351,7 +346,7 @@ "# initialize map-elites\n", "repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n", " devices=devices, centroids=centroids\n", - ")(init_genotypes=training_states, random_key=keys)" + ")(genotypes=training_states, random_key=keys)" ] }, { @@ -443,7 +438,7 @@ "num_cols = 5\n", "\n", "fig, axes = plt.subplots(\n", - " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30)\n", + " nrows=math.ceil(num_repertoires / num_cols), ncols=num_cols, figsize=(30, 30), squeeze=False,\n", ")\n", "for i, repertoire in enumerate(repertoires):\n", "\n", @@ -492,7 +487,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/examples/mees.ipynb b/examples/mees.ipynb index ad1a4740..e9e37c2a 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -54,19 +54,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install 
git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index bd489ca2..4f3fdc74 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -50,19 +50,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 555381e6..4661d406 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "59f748d3", + "id": "0", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb)" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "a5e13ff6", + "id": "1", "metadata": {}, "source": [ "# Optimizing multiple objectives with MOME in Jax\n", @@ -28,7 +28,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af063418", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -41,19 +41,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -81,12 +81,12 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "from qdax.types import Fitness, Descriptor, RNGKey, ExtraScores" + "from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores" ] }, { "cell_type": "markdown", - "id": "22495c16", + "id": "3", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b96b5d07", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "c2850d54", + "id": "5", "metadata": {}, "source": [ "## Define the 
scoring function: rastrigin multi-objective\n", @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5effe11", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231d273d", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "29250e72", + "id": "8", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab5d6334", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "a4828ca8", + "id": "10", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,14 +211,14 @@ { "cell_type": "code", "execution_count": null, - "id": "ebf3bd27", + "id": "11", "metadata": {}, "outputs": [], "source": [ "# initial population\n", "random_key = jax.random.PRNGKey(42)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "c904664b", + "id": "12", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76547c4c", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "markdown", - "id": "15936d15", + "id": "14", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +282,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07a0d1d9", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +295,7 @@ }, { "cell_type": "markdown", - "id": "f7ec5a77", + "id": "16", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,12 +304,12 @@ { "cell_type": "code", "execution_count": null, - "id": "c05cbf1e", + "id": "17", "metadata": {}, "outputs": [], "source": [ "repertoire, emitter_state, random_key = mome.init(\n", - " init_genotypes,\n", + " genotypes,\n", " centroids,\n", " pareto_front_max_length,\n", " random_key\n", @@ -318,7 +318,7 @@ }, { "cell_type": "markdown", - "id": "6de4cedf", + "id": "18", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +327,7 @@ { "cell_type": "code", "execution_count": null, - "id": "96ea04e6", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +344,7 @@ }, { "cell_type": "markdown", - "id": "3ff9ca98", + "id": "20", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6766dc4f", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28ab56c9", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab36cb7", + "id": "23", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index be662981..d6385291 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -42,19 +42,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " 
import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -88,7 +88,7 @@ "from qdax.utils.plotting import plot_global_pareto_front\n", "from qdax.utils.metrics import default_ga_metrics\n", "\n", - "from qdax.types import Genotype, Fitness, Descriptor" + "from qdax.custom_types import Genotype, Fitness, Descriptor" ] }, { @@ -195,7 +195,7 @@ "# Initial population\n", "random_key = jax.random.PRNGKey(0)\n", "random_key, subkey = jax.random.split(random_key)\n", - "init_genotypes = jax.random.uniform(\n", + "genotypes = jax.random.uniform(\n", " subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n", ")\n", "\n", @@ -244,7 +244,7 @@ "\n", "# init nsga2\n", "repertoire, emitter_state, random_key = nsga2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " random_key\n", ")" @@ -309,7 +309,7 @@ "\n", "# init spea2\n", "repertoire, emitter_state, random_key = spea2.init(\n", - " init_genotypes,\n", + " genotypes,\n", " population_size,\n", " num_neighbours,\n", " random_key\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index 8d417cc0..c09deefe 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -37,19 +37,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 6152ce63..29f4cc74 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -49,19 +49,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -93,7 +93,7 @@ "from qdax.core.emitters.mutation_operators import isoline_variation\n", "from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter\n", "\n", - "from qdax.types import Observation\n", + "from qdax.custom_types import Observation\n", "from qdax.utils 
import train_seq2seq\n", "\n", "\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index 9b638b2d..d60e246a 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -48,19 +48,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -106,7 +106,7 @@ "#@markdown ---\n", "env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']\n", "episode_length = 250 #@param {type:\"integer\"}\n", - "num_iterations = 4000 #@param {type:\"integer\"}\n", + "num_iterations = 1000 #@param {type:\"integer\"}\n", "seed = 42 #@param {type:\"integer\"}\n", "policy_hidden_layer_sizes = (256, 256) #@param {type:\"raw\"}\n", "iso_sigma = 0.005 #@param {type:\"number\"}\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 102d5262..2082dcaa 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -48,19 +48,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index 7762083f..c71559f7 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1606cdf6", + "id": "0", "metadata": { "jupyter": { "outputs_hidden": false @@ -22,19 +22,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -44,12 +44,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - 
"try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -67,7 +61,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df61dcc5", + "id": "1", "metadata": { "jupyter": { "outputs_hidden": false @@ -84,7 +78,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7d0b0ef", + "id": "2", "metadata": { "jupyter": { "outputs_hidden": false @@ -95,7 +89,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"tpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -103,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f342948", + "id": "3", "metadata": { "jupyter": { "outputs_hidden": false @@ -123,8 +118,8 @@ "buffer_size = 100000\n", "\n", "# PBT Config\n", - "num_best_to_replace_from = 20\n", - "num_worse_to_replace = 40\n", + "num_best_to_replace_from = 1\n", + "num_worse_to_replace = 1\n", "\n", "# SAC config\n", "batch_size = 256\n", @@ -144,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "090f8d4d", + "id": "4", "metadata": { "jupyter": { "outputs_hidden": false @@ -175,7 +170,7 @@ { "cell_type": "code", "execution_count": null, - "id": "efac713a", + "id": "5", "metadata": { "jupyter": { "outputs_hidden": false @@ -193,7 +188,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x,\n", " (\n", @@ -214,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bccfc6d", + "id": "6", "metadata": { "jupyter": { "outputs_hidden": false @@ -237,7 +232,7 @@ { "cell_type": "code", "execution_count": null, - "id": "708eea0a", + "id": "7", "metadata": { "jupyter": { "outputs_hidden": false @@ -266,7 +261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e6e2bec", + "id": "8", "metadata": { "jupyter": { "outputs_hidden": false @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f09fe1e", + "id": "9", "metadata": { "jupyter": { "outputs_hidden": false @@ -311,7 +306,7 @@ { "cell_type": "code", "execution_count": null, - "id": "66ba826a", + "id": "10", "metadata": { "jupyter": { "outputs_hidden": false @@ -336,7 +331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a49af55e", + "id": "11", "metadata": { "jupyter": { "outputs_hidden": false @@ -362,7 +357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b137c8e5", + "id": "12", "metadata": { "jupyter": { "outputs_hidden": false @@ -384,7 +379,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1dbbb855", + "id": "13", "metadata": { "jupyter": { "outputs_hidden": false @@ -397,8 +392,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -407,7 +402,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a052ba2e", + "id": "14", "metadata": { "jupyter": { "outputs_hidden": false @@ -447,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"4f7354f2", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" @@ -461,7 +456,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cd1c27e8-1fe7-464d-8af3-72fa8d61852d", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -471,13 +466,13 @@ "source": [ "training_states = unshard_fn(training_states)\n", "best_idx = jnp.argmax(population_returns)\n", - "best_training_state = jax.tree_map(lambda x: x[best_idx], training_states)" + "best_training_state = jax.tree_util.tree_map(lambda x: x[best_idx], training_states)" ] }, { "cell_type": "code", "execution_count": null, - "id": "60e8ee82-27cf-4fa6-b189-66e6e10e2177", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" @@ -491,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84bd809c-0127-4241-9556-9e81e550bbd2", + "id": "18", "metadata": { "pycharm": { "name": "#%%\n" @@ -509,7 +504,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2954be53-ffa5-42cf-8696-7f40f139edaf", + "id": "19", "metadata": { "pycharm": { "name": "#%%\n" @@ -523,7 +518,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14026ff2-e7f2-46eb-91d5-6e7394136e96", + "id": "20", "metadata": { "pycharm": { "name": "#%%\n" @@ -537,7 +532,7 @@ "rng = jax.random.PRNGKey(seed=1)\n", "env_state = jax.jit(env.reset)(rng=rng)\n", "\n", - "training_state, env_state = jax.tree_map(\n", + "training_state, env_state = jax.tree_util.tree_map(\n", " lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)\n", ")\n", "\n", @@ -552,7 +547,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5701084-e876-43f8-8de0-4216361ef5b4", + "id": "21", "metadata": { "pycharm": { "name": "#%%\n" @@ -561,7 +556,7 @@ "outputs": [], "source": [ "rollout = [\n", - " jax.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", + " jax.tree_util.tree_map(lambda x: jax.device_put(x[0], jax.devices(\"cpu\")[0]), env_state)\n", " for env_state in rollout\n", "]" ] @@ -569,7 +564,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85bb7556-37bb-4a20-88b3-28b298c8b0a9", + "id": "22", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py index 699c6aba..433bc1d2 100644 --- a/examples/scripts/me_example.py +++ b/examples/scripts/me_example.py @@ -79,7 +79,12 @@ def run_me() -> None: # Run MAP-Elites loop for _ in range(num_iterations): - (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = map_elites.update( repertoire, emitter_state, random_key, diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index fe655fe2..5f08e582 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -45,19 +45,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -67,12 +67,6 @@ " import jumanji\n", 
"\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index ec98b9da..b3d2cbe1 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf95707f", + "id": "0", "metadata": { "pycharm": { "name": "#%%\n" @@ -19,19 +19,19 @@ "try:\n", " import brax\n", "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1\n", + " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", " import brax\n", "\n", "try:\n", " import flax\n", "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", " import flax\n", "\n", "try:\n", " import chex\n", "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1\n", + " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", " import chex\n", "\n", "try:\n", @@ -41,12 +41,6 @@ " import jumanji\n", "\n", "try:\n", - " import haiku\n", - "except:\n", - " !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1\n", - " import haiku\n", - "\n", - "try:\n", " import qdax\n", "except:\n", " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", @@ -62,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15a43429", + "id": "1", "metadata": { "pycharm": { "name": "#%%\n" @@ -76,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32d15301", + "id": "2", "metadata": { "pycharm": { "name": "#%%\n" @@ -84,7 +78,8 @@ }, "outputs": [], "source": [ - "devices = jax.devices(\"gpu\")\n", + "# Get devices (change gpu by tpu if needed)\n", + "devices = jax.devices('gpu')\n", "num_devices = len(devices)\n", "print(f\"Detected the following {num_devices} device(s): {devices}\")" ] @@ -92,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7520b673", + "id": "3", "metadata": { "pycharm": { "name": "#%%\n" @@ -129,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c3718a4c", + "id": "4", "metadata": { "pycharm": { "name": "#%%\n" @@ -157,7 +152,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5485a16c", + "id": "5", "metadata": { "pycharm": { "name": "#%%\n" @@ -172,7 +167,7 @@ " eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key)\n", "\n", " reshape_fn = jax.jit(\n", - " lambda tree: jax.tree_map(\n", + " lambda tree: jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(\n", " x, (population_size_per_device, env_batch_size,) + x.shape[1:]\n", " ),\n", @@ -188,7 +183,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4dc22ec4", + "id": "6", "metadata": { "pycharm": { "name": "#%%\n" @@ -208,7 +203,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9c610ba", + "id": "7", "metadata": { "pycharm": { "name": "#%%\n" @@ -232,7 +227,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f6fd3b9", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -256,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a412cd4f", + "id": "9", "metadata": { "pycharm": 
{ "name": "#%%\n" @@ -271,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "535250a8", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -293,7 +288,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d24156f4", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -316,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23037e97", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -335,7 +330,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d9ebb235", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -345,8 +340,8 @@ "source": [ "@jax.jit\n", "def unshard_fn(sharded_tree):\n", - " tree = jax.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", - " tree = jax.tree_map(\n", + " tree = jax.tree_util.tree_map(lambda x: jax.device_put(x, \"cpu\"), sharded_tree)\n", + " tree = jax.tree_util.tree_map(\n", " lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree\n", " )\n", " return tree" @@ -355,7 +350,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a58253e5", + "id": "14", "metadata": { "pycharm": { "name": "#%%\n" @@ -392,7 +387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6111e836", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/mkdocs.yml b/mkdocs.yml index 168c6ef5..9207b4f2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -114,6 +114,7 @@ nav: - Examples: - MAPElites: examples/mapelites.ipynb - PGAME: examples/pgame.ipynb + - DCRL-ME: examples/dcrlme.ipynb - CMA ME: examples/cmame.ipynb - QDPG: examples/qdpg.ipynb - OMG MEGA: examples/omgmega.ipynb @@ -136,6 +137,7 @@ nav: - Core algorithms: - MAP Elites: api_documentation/core/map_elites.md - PGAME: api_documentation/core/pgame.md + - DCRLME: api_documentation/core/dcrlme.md - QDPG: api_documentation/core/qdpg.md - CMA ME: api_documentation/core/cmame.md - OMG MEGA: api_documentation/core/omg_mega.md diff --git a/qdax/__init__.py b/qdax/__init__.py index 260c070a..6a9beea8 100644 --- a/qdax/__init__.py +++ b/qdax/__init__.py @@ -1 +1 @@ -__version__ = "0.3.1" +__version__ = "0.4.0" diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py index 41f2ff08..bd4f4534 100644 --- a/qdax/baselines/dads.py +++ b/qdax/baselines/dads.py @@ -25,7 +25,7 @@ update_running_mean_std, ) from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DadsTrainingState(TrainingState): @@ -430,12 +430,17 @@ def _update_dynamics( """ training_state, transitions = operand - dynamics_loss, dynamics_gradient = jax.value_and_grad(self._dynamics_loss_fn,)( + dynamics_loss, dynamics_gradient = jax.value_and_grad( + self._dynamics_loss_fn, + )( training_state.dynamics_params, transitions=transitions, ) - (dynamics_updates, dynamics_optimizer_state,) = self._dynamics_optimizer.update( + ( + dynamics_updates, + dynamics_optimizer_state, + ) = self._dynamics_optimizer.update( dynamics_gradient, training_state.dynamics_optimizer_state ) dynamics_params = optax.apply_updates( @@ -483,7 +488,11 @@ def _update_networks( random_key = training_state.random_key # Update skill-dynamics - (dynamics_params, dynamics_loss, dynamics_optimizer_state,) = jax.lax.cond( + ( + dynamics_params, + dynamics_loss, + dynamics_optimizer_state, + ) = jax.lax.cond( training_state.steps % self._config.dynamics_update_freq == 0, self._update_dynamics, 
self._not_update_dynamics, diff --git a/qdax/baselines/dads_smerl.py b/qdax/baselines/dads_smerl.py index 206f0012..5bd8274d 100644 --- a/qdax/baselines/dads_smerl.py +++ b/qdax/baselines/dads_smerl.py @@ -14,7 +14,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py index c03cfb3f..0ebdfc32 100644 --- a/qdax/baselines/diayn.py +++ b/qdax/baselines/diayn.py @@ -20,7 +20,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor +from qdax.custom_types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor class DiaynTrainingState(TrainingState): diff --git a/qdax/baselines/diayn_smerl.py b/qdax/baselines/diayn_smerl.py index 2966a692..daacaa74 100644 --- a/qdax/baselines/diayn_smerl.py +++ b/qdax/baselines/diayn_smerl.py @@ -13,7 +13,7 @@ from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer -from qdax.types import Metrics, Reward +from qdax.custom_types import Metrics, Reward @dataclass diff --git a/qdax/baselines/genetic_algorithm.py b/qdax/baselines/genetic_algorithm.py index 0714fb6c..b4c6a32f 100644 --- a/qdax/baselines/genetic_algorithm.py +++ b/qdax/baselines/genetic_algorithm.py @@ -1,4 +1,5 @@ """Core components of a basic genetic algorithm.""" + from functools import partial from typing import Any, Callable, Optional, Tuple @@ -6,7 +7,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ExtraScores, Fitness, Genotype, Metrics, RNGKey +from qdax.custom_types import ExtraScores, Fitness, Genotype, Metrics, RNGKey class GeneticAlgorithm: @@ -39,12 +40,12 @@ def __init__( @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]: """Initialize a GARepertoire with an initial population of genotypes. 
Args: - init_genotypes: the initial population of genotypes + genotypes: the initial population of genotypes population_size: the maximal size of the repertoire random_key: a random key to handle stochastic operations @@ -54,26 +55,21 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = GARepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=None, extra_scores=extra_scores, @@ -108,7 +104,7 @@ def update( """ # generate offsprings - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) @@ -127,7 +123,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=None, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics diff --git a/qdax/baselines/nsga2.py b/qdax/baselines/nsga2.py index 663d6f0e..afd587af 100644 --- a/qdax/baselines/nsga2.py +++ b/qdax/baselines/nsga2.py @@ -13,7 +13,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from qdax.core.containers.nsga2_repertoire import NSGA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class NSGA2(GeneticAlgorithm): @@ -28,31 +28,36 @@ class NSGA2(GeneticAlgorithm): @partial(jax.jit, static_argnames=("self", "population_size")) def init( - self, init_genotypes: Genotype, population_size: int, random_key: RNGKey + self, genotypes: Genotype, population_size: int, random_key: RNGKey ) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]: # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = NSGA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, ) # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/baselines/pbt.py b/qdax/baselines/pbt.py index 65d1a950..6555c537 100644 --- a/qdax/baselines/pbt.py +++ b/qdax/baselines/pbt.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class PBTTrainingState(PyTreeNode): diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py index a5ce15c5..482c5715 100644 --- a/qdax/baselines/sac.py +++ b/qdax/baselines/sac.py @@ -32,7 +32,7 @@ update_running_mean_std, ) from 
qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -449,7 +449,10 @@ def _update_alpha( random_key=subkey, ) alpha_optimizer = optax.adam(learning_rate=alpha_lr) - (alpha_updates, alpha_optimizer_state,) = alpha_optimizer.update( + ( + alpha_updates, + alpha_optimizer_state, + ) = alpha_optimizer.update( alpha_gradient, training_state.alpha_optimizer_state ) alpha_params = optax.apply_updates( @@ -503,7 +506,10 @@ def _update_critic( random_key=subkey, ) critic_optimizer = optax.adam(learning_rate=critic_lr) - (critic_updates, critic_optimizer_state,) = critic_optimizer.update( + ( + critic_updates, + critic_optimizer_state, + ) = critic_optimizer.update( critic_gradient, training_state.critic_optimizer_state ) critic_params = optax.apply_updates( @@ -556,7 +562,10 @@ def _update_actor( random_key=subkey, ) policy_optimizer = optax.adam(learning_rate=policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/sac_pbt.py b/qdax/baselines/sac_pbt.py index 9aa2ff4c..947a7183 100644 --- a/qdax/baselines/sac_pbt.py +++ b/qdax/baselines/sac_pbt.py @@ -22,7 +22,7 @@ ) from qdax.core.neuroevolution.normalization_utils import normalize_with_rmstd from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, RNGKey class PBTSacTrainingState(PBTTrainingState, SacTrainingState): diff --git a/qdax/baselines/spea2.py b/qdax/baselines/spea2.py index 72ec2791..10d195ad 100644 --- a/qdax/baselines/spea2.py +++ b/qdax/baselines/spea2.py @@ -15,7 +15,7 @@ from qdax.baselines.genetic_algorithm import GeneticAlgorithm from qdax.core.containers.spea2_repertoire import SPEA2Repertoire from qdax.core.emitters.emitter import EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class SPEA2(GeneticAlgorithm): @@ -40,7 +40,7 @@ class SPEA2(GeneticAlgorithm): ) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, population_size: int, num_neighbours: int, random_key: RNGKey, @@ -48,12 +48,12 @@ def init( # score initial genotypes fitnesses, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = SPEA2Repertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, population_size=population_size, num_neighbours=num_neighbours, @@ -61,14 +61,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=None, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, extra_scores=extra_scores, ) diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py index e09b5254..97f37893 100644 --- a/qdax/baselines/td3.py +++ b/qdax/baselines/td3.py @@ -23,7 +23,7 @@ from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode from 
qdax.core.neuroevolution.networks.td3_networks import make_td3_networks from qdax.core.neuroevolution.sac_td3_utils import generate_unroll -from qdax.types import ( +from qdax.custom_types import ( Action, Descriptor, Mask, @@ -76,7 +76,10 @@ class TD3: def __init__(self, config: TD3Config, action_size: int): self._config = config - self._policy, self._critic, = make_td3_networks( + ( + self._policy, + self._critic, + ) = make_td3_networks( action_size=action_size, critic_hidden_layer_sizes=self._config.critic_hidden_layer_size, policy_hidden_layer_sizes=self._config.policy_hidden_layer_size, @@ -421,7 +424,10 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam( learning_rate=self._config.policy_learning_rate ) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/baselines/td3_pbt.py b/qdax/baselines/td3_pbt.py index 60cd8a38..5762956d 100644 --- a/qdax/baselines/td3_pbt.py +++ b/qdax/baselines/td3_pbt.py @@ -25,7 +25,7 @@ td3_policy_loss_fn, ) from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn -from qdax.types import Descriptor, Mask, Metrics, Params, RNGKey +from qdax.custom_types import Descriptor, Mask, Metrics, Params, RNGKey class PBTTD3TrainingState(PBTTrainingState, TD3TrainingState): @@ -291,7 +291,10 @@ def update( def update_policy_step() -> Tuple[Params, Params, optax.OptState]: policy_optimizer = optax.adam(learning_rate=training_state.policy_lr) - (policy_updates, policy_optimizer_state,) = policy_optimizer.update( + ( + policy_updates, + policy_optimizer_state, + ) = policy_optimizer.update( policy_gradient, training_state.policy_optimizer_state ) policy_params = optax.apply_updates( diff --git a/qdax/core/aurora.py b/qdax/core/aurora.py index fed716e3..f67d7b4f 100644 --- a/qdax/core/aurora.py +++ b/qdax/core/aurora.py @@ -12,8 +12,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.environments.bd_extractors import AuroraExtraInfo -from qdax.types import ( +from qdax.custom_types import ( Descriptor, Fitness, Genotype, @@ -22,6 +21,7 @@ Params, RNGKey, ) +from qdax.environments.bd_extractors import AuroraExtraInfo class AURORA: @@ -118,7 +118,7 @@ def container_size_control( def init( self, - init_genotypes: Genotype, + genotypes: Genotype, aurora_extra_info: AuroraExtraInfo, l_value: jnp.ndarray, max_size: int, @@ -128,7 +128,7 @@ def init( genotypes. Also performs the first training of the AURORA encoder. 
Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) aurora_extra_info: information to perform AURORA encodings, such as the encoder parameters @@ -141,7 +141,7 @@ def init( the emitter, and the updated information to perform AURORA encodings """ fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, + genotypes, random_key, ) @@ -150,7 +150,7 @@ def init( descriptors = self._encoder_fn(observations, aurora_extra_info) repertoire = UnstructuredRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, observations=observations, @@ -160,13 +160,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, - genotypes=init_genotypes, + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -208,9 +204,10 @@ def update( a new key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, @@ -232,10 +229,11 @@ def update( # update emitter state after scoring is made emitter_state = self._emitter.state_update( emitter_state=emitter_state, + repertoire=repertoire, genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores=extra_scores | extra_info, ) # update the metrics diff --git a/qdax/core/cmaes.py b/qdax/core/cmaes.py index 481a49bf..0e9b4084 100644 --- a/qdax/core/cmaes.py +++ b/qdax/core/cmaes.py @@ -2,6 +2,7 @@ Definition of CMAES class, containing main functions necessary to build a CMA optimization script. 
Link to the paper: https://arxiv.org/abs/1604.00772 """ + from functools import partial from typing import Callable, Optional, Tuple @@ -9,7 +10,7 @@ import jax import jax.numpy as jnp -from qdax.types import Fitness, Genotype, Mask, RNGKey +from qdax.custom_types import Fitness, Genotype, Mask, RNGKey class CMAESState(flax.struct.PyTreeNode): diff --git a/qdax/core/containers/archive.py b/qdax/core/containers/archive.py index 8af808f3..036c5892 100644 --- a/qdax/core/containers/archive.py +++ b/qdax/core/containers/archive.py @@ -40,7 +40,7 @@ def size(self) -> float: fake_data = jnp.isnan(self.data) # count number of real data - return sum(~fake_data) + return float(sum(~fake_data)) @classmethod def create( @@ -161,9 +161,7 @@ def insert(self, state_descriptors: jnp.ndarray) -> Archive: values, _indices = knn(self.data, state_descriptors, 1) # get indices where distance bigger than threshold - relevant_indices = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 - ) + relevant_indices = jnp.where(values.squeeze() > self.acceptance_threshold, 0, 1) def iterate_fn( carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict @@ -192,7 +190,7 @@ def iterate_fn( # get indices where distance bigger than threshold not_too_close = jnp.where( - values.squeeze() > self.acceptance_threshold, x=0, y=1 + values.squeeze() > self.acceptance_threshold, 0, 1 ) second_condition = not_too_close.sum() condition = (first_condition + second_condition) == 0 @@ -280,7 +278,7 @@ def knn( dist = jnp.nan_to_num(dist, nan=jnp.inf) # clipping necessary - numerical approx make some distancies negative - dist = jnp.sqrt(jnp.clip(dist, a_min=0.0)) + dist = jnp.sqrt(jnp.clip(dist, min=0.0)) # return values, indices values, indices = qdax_top_k(-dist, k) diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py index 87ade54f..403331ff 100644 --- a/qdax/core/containers/ga_repertoire.py +++ b/qdax/core/containers/ga_repertoire.py @@ -10,7 +10,7 @@ from jax.flatten_util import ravel_pytree from qdax.core.containers.repertoire import Repertoire -from qdax.types import Fitness, Genotype, RNGKey +from qdax.custom_types import Fitness, Genotype, RNGKey class GARepertoire(Repertoire): diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py index aed74c78..b473d4b3 100644 --- a/qdax/core/containers/mapelites_repertoire.py +++ b/qdax/core/containers/mapelites_repertoire.py @@ -15,7 +15,14 @@ from numpy.random import RandomState from sklearn.cluster import KMeans -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) def compute_cvt_centroids( @@ -228,6 +235,38 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey return samples, random_key + @partial(jax.jit, static_argnames=("num_samples",)) + def sample_with_descs( + self, + random_key: RNGKey, + num_samples: int, + ) -> Tuple[Genotype, Descriptor, RNGKey]: + """Sample elements in the repertoire. 
+ + Args: + random_key: a jax PRNG random key + num_samples: the number of elements to be sampled + + Returns: + samples: a batch of genotypes sampled in the repertoire + random_key: an updated jax PRNG random key + """ + + repertoire_empty = self.fitnesses == -jnp.inf + p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) + + random_key, subkey = jax.random.split(random_key) + samples = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.genotypes, + ) + descs = jax.tree_util.tree_map( + lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p), + self.descriptors, + ) + + return samples, descs, random_key + @jax.jit def add( self, @@ -271,7 +310,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -283,7 +322,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mels_repertoire.py b/qdax/core/containers/mels_repertoire.py index a2e99971..7ef57bb9 100644 --- a/qdax/core/containers/mels_repertoire.py +++ b/qdax/core/containers/mels_repertoire.py @@ -14,7 +14,14 @@ MapElitesRepertoire, get_cells_indices, ) -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, Spread +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + Spread, +) def _dispersion(descriptors: jnp.ndarray) -> jnp.ndarray: @@ -232,7 +239,7 @@ def add( # assign fake position when relevant : num_centroids is out of bound batch_of_indices = jnp.where( - addition_condition, x=batch_of_indices, y=num_centroids + addition_condition, batch_of_indices, num_centroids ) # create new repertoire diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py index 0e2b6d3e..43be3835 100644 --- a/qdax/core/containers/mome_repertoire.py +++ b/qdax/core/containers/mome_repertoire.py @@ -15,7 +15,7 @@ MapElitesRepertoire, get_cells_indices, ) -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py index 74b0f454..331ef153 100644 --- a/qdax/core/containers/nsga2_repertoire.py +++ b/qdax/core/containers/nsga2_repertoire.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype from qdax.utils.pareto_front import compute_masked_pareto_front @@ -56,9 +56,9 @@ def _compute_crowding_distances( norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0) # get the distances - dists = jnp.row_stack( + dists = jnp.vstack( [srt_fitnesses, jnp.full(num_objective, jnp.inf)] - ) - jnp.row_stack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) + ) - jnp.vstack([jnp.full(num_objective, -jnp.inf), srt_fitnesses]) # Prepare the distance to last and next vectors dist_to_last, dist_to_next = dists, dists @@ -228,7 +228,7 @@ def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool: # get rid of the zeros (that correspond to the False from the mask) fake_indice = num_candidates + 1 # bigger than 
all the other indices - indices = jnp.where(indices == 0, x=fake_indice, y=indices) + indices = jnp.where(indices == 0, fake_indice, indices) # sort the indices to remove the fake indices indices = jnp.sort(indices)[: self.size] diff --git a/qdax/core/containers/repertoire.py b/qdax/core/containers/repertoire.py index f50d53b7..77c91683 100644 --- a/qdax/core/containers/repertoire.py +++ b/qdax/core/containers/repertoire.py @@ -4,11 +4,11 @@ from __future__ import annotations -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod import flax -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey class Repertoire(flax.struct.PyTreeNode, ABC): @@ -19,7 +19,8 @@ class Repertoire(flax.struct.PyTreeNode, ABC): to keep the parent classes explicit and transparent. """ - @abstractclassmethod + @classmethod + @abstractmethod def init(cls) -> Repertoire: # noqa: N805 """Create a repertoire.""" pass diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py index 54870db4..33c31547 100644 --- a/qdax/core/containers/spea2_repertoire.py +++ b/qdax/core/containers/spea2_repertoire.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.containers.ga_repertoire import GARepertoire -from qdax.types import Fitness, Genotype +from qdax.custom_types import Fitness, Genotype class SPEA2Repertoire(GARepertoire): diff --git a/qdax/core/containers/uniform_replacement_archive.py b/qdax/core/containers/uniform_replacement_archive.py index d6f233db..830878cf 100644 --- a/qdax/core/containers/uniform_replacement_archive.py +++ b/qdax/core/containers/uniform_replacement_archive.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from qdax.core.containers.archive import Archive -from qdax.types import RNGKey +from qdax.custom_types import RNGKey class UniformReplacementArchive(Archive): @@ -74,7 +74,7 @@ def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive: subkey, shape=(1,), minval=0, maxval=self.max_size ) - index = jnp.where(condition=is_full, x=random_index, y=new_current_position) + index = jnp.where(is_full, random_index, new_current_position) new_data = self.data.at[index].set(state_descriptor) diff --git a/qdax/core/containers/unstructured_repertoire.py b/qdax/core/containers/unstructured_repertoire.py index f4cc0c98..8512d3d6 100644 --- a/qdax/core/containers/unstructured_repertoire.py +++ b/qdax/core/containers/unstructured_repertoire.py @@ -8,7 +8,14 @@ import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from qdax.types import Centroid, Descriptor, Fitness, Genotype, Observation, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + Fitness, + Genotype, + Observation, + RNGKey, +) @partial(jax.jit, static_argnames=("k_nn",)) @@ -300,7 +307,7 @@ def add( # ReIndexing of all the inputs to the correct sorted way batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get() - batch_of_genotypes = jax.tree_map( + batch_of_genotypes = jax.tree_util.tree_map( lambda x: x.at[sorted_bds].get(), batch_of_genotypes ) batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get() @@ -333,7 +340,7 @@ def add( # put dominated fitness to -jnp.inf batch_of_fitnesses = jnp.where( - batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf + batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf ) # get addition condition @@ -347,12 +354,12 @@ def add( # assign fake position when relevant : num_centroids is out of bounds batch_of_indices = jnp.where( 
addition_condition, - x=batch_of_indices, - y=self.max_size, + batch_of_indices, + self.max_size, ) # create new grid - new_grid_genotypes = jax.tree_map( + new_grid_genotypes = jax.tree_util.tree_map( lambda grid_genotypes, new_genotypes: grid_genotypes.at[ batch_of_indices.squeeze() ].set(new_genotypes), @@ -398,7 +405,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey grid_empty = self.fitnesses == -jnp.inf p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p), self.genotypes, ) @@ -435,7 +442,7 @@ def init( # Initialize grid with default values default_fitnesses = -jnp.inf * jnp.ones(shape=max_size) - default_genotypes = jax.tree_map( + default_genotypes = jax.tree_util.tree_map( lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan), genotypes, ) diff --git a/qdax/core/distributed_map_elites.py b/qdax/core/distributed_map_elites.py index c8a1ea44..dbc6522b 100644 --- a/qdax/core/distributed_map_elites.py +++ b/qdax/core/distributed_map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -10,14 +11,14 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, Genotype, Metrics, RNGKey +from qdax.custom_types import Centroid, Genotype, Metrics, RNGKey class DistributedMAPElites(MAPElites): @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -30,7 +31,7 @@ def init( devices. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
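
The hunks above repeat the same two mechanical migrations throughout the containers and emitters: keyword arguments to `jnp.where` are replaced by positional ones, and the deprecated `jax.tree_map` alias is replaced by `jax.tree_util.tree_map`. The following is a minimal standalone sketch of the new-style calls for illustration only; it is not part of the patch, and the array/dict names are made up for the example.

```python
import jax
import jax.numpy as jnp

fitnesses = jnp.array([1.0, -jnp.inf, 0.5])
genotypes = {"w": jnp.ones((3, 2)), "b": jnp.zeros((3,))}

# Positional form of jnp.where: the x=/y= keywords removed by the patch.
masked = jnp.where(fitnesses == -jnp.inf, -1.0, fitnesses)

# jax.tree_util.tree_map replaces the deprecated jax.tree_map alias,
# applying the function to every leaf of the pytree.
first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes)
```
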
@@ -41,7 +42,7 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # gather across all devices @@ -51,7 +52,7 @@ def init( gathered_descriptors, ) = jax.tree_util.tree_map( lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0), - (init_genotypes, fitnesses, descriptors), + (genotypes, fitnesses, descriptors), ) # init the repertoire @@ -64,14 +65,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -108,7 +114,7 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) # scores the offsprings @@ -138,7 +144,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics @@ -184,7 +190,7 @@ def get_distributed_update_fn( of MAP-Elites updates. """ - @partial(jax.jit, static_argnames=("self",)) + @jax.jit def _scan_update( carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], unused: Any, @@ -195,7 +201,12 @@ def _scan_update( repertoire, emitter_state, random_key = carry # apply one step of update - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, @@ -209,7 +220,11 @@ def update_fn( random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]: """Apply num_iterations of update.""" - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( _scan_update, (repertoire, emitter_state, random_key), (), diff --git a/qdax/core/emitters/cma_emitter.py b/qdax/core/emitters/cma_emitter.py index f9d58caa..315dcd9b 100644 --- a/qdax/core/emitters/cma_emitter.py +++ b/qdax/core/emitters/cma_emitter.py @@ -13,7 +13,14 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class CMAEmitterState(EmitterState): @@ -99,14 +106,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. 
Returns: @@ -135,7 +148,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. Hence the @@ -154,7 +167,7 @@ def emit( cmaes_state=emitter_state.cmaes_state, random_key=random_key ) - return offsprings, random_key + return offsprings, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_improvement_emitter.py b/qdax/core/emitters/cma_improvement_emitter.py index 28424f3f..7c3fc98c 100644 --- a/qdax/core/emitters/cma_improvement_emitter.py +++ b/qdax/core/emitters/cma_improvement_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAImprovementEmitter(CMAEmitter): @@ -62,13 +62,13 @@ def _ranking_criteria( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py index f63654fd..c3f87fed 100644 --- a/qdax/core/emitters/cma_mega_emitter.py +++ b/qdax/core/emitters/cma_mega_emitter.py @@ -12,7 +12,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -100,14 +100,20 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAMEGAState, RNGKey]: """ Initializes the CMA-MEGA emitter. Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -117,7 +123,7 @@ def init( # define init theta as 0 theta = jax.tree_util.tree_map( lambda x: jnp.zeros_like(x[:1, ...]), - init_genotypes, + genotypes, ) # score it @@ -147,7 +153,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAMEGAState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. Interestingly, this method does not directly modifies individuals from the repertoire but sample from a distribution. 
Hence the @@ -181,7 +187,7 @@ def emit( # Compute new candidates new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad) - return new_thetas, random_key + return new_thetas, {}, random_key @partial( jax.jit, @@ -232,13 +238,13 @@ def state_update( condition = improvements == jnp.inf # criteria: fitness if new cell, improvement else - ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements) + ranking_criteria = jnp.where(condition, fitnesses, improvements) # make sure to have all the new cells first new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria) ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) # sort indices according to the criteria @@ -276,12 +282,12 @@ def state_update( # update theta in case of reinit theta = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta + lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta ) # update cmaes state in case of reinit cmaes_state = jax.tree_util.tree_map( - lambda x, y: jnp.where(reinitialize, x=x, y=y), + lambda x, y: jnp.where(reinitialize, x, y), self._cma_initial_state, cmaes_state, ) diff --git a/qdax/core/emitters/cma_opt_emitter.py b/qdax/core/emitters/cma_opt_emitter.py index d9c5bf71..9a783585 100644 --- a/qdax/core/emitters/cma_opt_emitter.py +++ b/qdax/core/emitters/cma_opt_emitter.py @@ -6,7 +6,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype class CMAOptimizingEmitter(CMAEmitter): diff --git a/qdax/core/emitters/cma_pool_emitter.py b/qdax/core/emitters/cma_pool_emitter.py index d5424a01..55ccaa4f 100644 --- a/qdax/core/emitters/cma_pool_emitter.py +++ b/qdax/core/emitters/cma_pool_emitter.py @@ -9,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMAPoolEmitterState(EmitterState): @@ -49,14 +49,20 @@ def batch_size(self) -> int: @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMAPoolEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. 
Returns: @@ -67,7 +73,14 @@ def scan_emitter_init( carry: RNGKey, unused: Any ) -> Tuple[RNGKey, CMAEmitterState]: random_key = carry - emitter_state, random_key = self._emitter.init(init_genotypes, random_key) + emitter_state, random_key = self._emitter.init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) return random_key, emitter_state # init all the emitter states @@ -91,7 +104,7 @@ def emit( repertoire: Optional[MapElitesRepertoire], emitter_state: CMAPoolEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emits new individuals. @@ -111,11 +124,11 @@ def emit( ) # use it to emit offsprings - offsprings, random_key = self._emitter.emit( + offsprings, extra_info, random_key = self._emitter.emit( repertoire, used_emitter_state, random_key ) - return offsprings, random_key + return offsprings, extra_info, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/cma_rnd_emitter.py b/qdax/core/emitters/cma_rnd_emitter.py index 4afb2f5d..27e4f0db 100644 --- a/qdax/core/emitters/cma_rnd_emitter.py +++ b/qdax/core/emitters/cma_rnd_emitter.py @@ -9,7 +9,7 @@ from qdax.core.cmaes import CMAESState from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class CMARndEmitterState(CMAEmitterState): @@ -35,14 +35,20 @@ class CMARndEmitterState(CMAEmitterState): class CMARndEmitter(CMAEmitter): @partial(jax.jit, static_argnames=("self",)) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[CMARndEmitterState, RNGKey]: """ Initializes the CMA-MEGA emitter Args: - init_genotypes: initial genotypes to add to the grid. + genotypes: initial genotypes to add to the grid. random_key: a random key to handle stochastic operations. Returns: @@ -162,7 +168,7 @@ def _ranking_criteria( condition = improvements == jnp.inf ranking_criteria = jnp.where( - condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria + condition, ranking_criteria + new_cell_offset, ranking_criteria ) return ranking_criteria # type: ignore diff --git a/qdax/core/emitters/dcrl_emitter.py b/qdax/core/emitters/dcrl_emitter.py new file mode 100644 index 00000000..b353a22f --- /dev/null +++ b/qdax/core/emitters/dcrl_emitter.py @@ -0,0 +1,769 @@ +"""Implements the DCRL Emitter from DCRL-MAP-Elites algorithm +in JAX for Brax environments. 
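The pool initialisation above builds one emitter state per pool member by scanning the sub-emitter's `init` over a random key. The same pattern on a toy state (the dict below is a stand-in for a `CMAEmitterState`, not the real class):

```python
# Building a stacked pytree of N states with jax.lax.scan.
import jax
import jax.numpy as jnp

num_states = 5

def init_one(carry, _):
    key = carry
    key, subkey = jax.random.split(key)
    state = {"mean": jax.random.normal(subkey, (3,)), "step": jnp.array(0)}
    return key, state

key = jax.random.PRNGKey(0)
key, stacked_states = jax.lax.scan(init_one, key, None, length=num_states)
# every leaf now has a leading axis of size num_states
print(jax.tree_util.tree_map(lambda x: x.shape, stacked_states))
```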
+""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Tuple + +import flax.linen as nn +import jax +import optax +from jax import numpy as jnp + +from qdax.core.containers.repertoire import Repertoire +from qdax.core.emitters.emitter import Emitter, EmitterState +from qdax.core.neuroevolution.buffers.buffer import DCRLTransition, ReplayBuffer +from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn +from qdax.core.neuroevolution.networks.networks import QModuleDC +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey +from qdax.environments.base_wrappers import QDEnv + + +@dataclass +class DCRLConfig: + """Configuration for DCRL Emitter""" + + dcrl_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + critic_hidden_layer_size: Tuple[int, ...] = (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCRLEmitterState(EmitterState): + """Contains training state for the learner.""" + + critic_params: Params + critic_opt_state: optax.OptState + actor_params: Params + actor_opt_state: optax.OptState + target_critic_params: Params + target_actor_params: Params + replay_buffer: ReplayBuffer + key: RNGKey + steps: jnp.ndarray + + +class DCRLEmitter(Emitter): + """ + A descriptor-conditioned reinforcement learning emitter used to implement + DCRL-MAP-Elites algorithm. + """ + + def __init__( + self, + config: DCRLConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + ) -> None: + self._config = config + self._env = env + self._policy_network = policy_network + self._actor_network = actor_network + + # Init Critics + critic_network = QModuleDC( + n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size + ) + self._critic_network = critic_network + + # Set up the losses and optimizers - return the opt states + ( + self._policy_loss_fn, + self._actor_loss_fn, + self._critic_loss_fn, + ) = make_td3_loss_dc_fn( + policy_fn=policy_network.apply, + actor_fn=actor_network.apply, + critic_fn=critic_network.apply, + reward_scaling=self._config.reward_scaling, + discount=self._config.discount, + noise_clip=self._config.noise_clip, + policy_noise=self._config.policy_noise, + ) + + # Init optimizers + self._actor_optimizer = optax.adam( + learning_rate=self._config.actor_learning_rate + ) + self._critic_optimizer = optax.adam( + learning_rate=self._config.critic_learning_rate + ) + self._policies_optimizer = optax.adam( + learning_rate=self._config.policy_learning_rate + ) + + @property + def batch_size(self) -> int: + """ + Returns: + the batch size emitted by the emitter. + """ + return self._config.dcrl_batch_size + self._config.ai_batch_size + + @property + def use_all_data(self) -> bool: + """Whether to use all data or not when used along other emitters. + + QualityPGEmitter uses the transitions from the genotypes that were generated + by other emitters. 
+ """ + return True + + def init( + self, + key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> Tuple[DCRLEmitterState, RNGKey]: + """Initializes the emitter state. + + Args: + genotypes: The initial population. + key: A random key. + + Returns: + The initial state of the PGAMEEmitter, a new random key. + """ + + observation_size = jax.tree_util.tree_leaves(genotypes)[1].shape[1] + descriptor_size = self._env.behavior_descriptor_length + action_size = self._env.action_size + + # Initialise critic, greedy actor and population + key, subkey = jax.random.split(key) + fake_obs = jnp.zeros(shape=(observation_size,)) + fake_desc = jnp.zeros(shape=(descriptor_size,)) + fake_action = jnp.zeros(shape=(action_size,)) + + critic_params = self._critic_network.init( + subkey, obs=fake_obs, actions=fake_action, desc=fake_desc + ) + target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) + + key, subkey = jax.random.split(key) + actor_params = self._actor_network.init(subkey, obs=fake_obs, desc=fake_desc) + target_actor_params = jax.tree_util.tree_map(lambda x: x, actor_params) + + # Prepare init optimizer states + critic_opt_state = self._critic_optimizer.init(critic_params) + actor_opt_state = self._actor_optimizer.init(actor_params) + + # Initialize replay buffer + dummy_transition = DCRLTransition.init_dummy( + observation_dim=self._env.observation_size, + action_dim=action_size, + descriptor_dim=descriptor_size, + ) + + replay_buffer = ReplayBuffer.init( + buffer_size=self._config.replay_buffer_size, transition=dummy_transition + ) + + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_normalized + ) + replay_buffer = replay_buffer.insert(transitions) + + # Initial training state + key, subkey = jax.random.split(key) + emitter_state = DCRLEmitterState( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + replay_buffer=replay_buffer, + key=subkey, + steps=jnp.array(0), + ) + + return emitter_state, key + + @partial(jax.jit, static_argnames=("self",)) + def _similarity(self, descs_1: Descriptor, descs_2: Descriptor) -> jnp.array: + """Compute the similarity between two batches of descriptors. + Args: + descs_1: batch of descriptors. + descs_2: batch of descriptors. + Returns: + batch of similarity measures. 
+ """ + return jnp.exp( + -jnp.linalg.norm(descs_1 - descs_2, axis=-1) / self._config.lengthscale + ) + + @partial(jax.jit, static_argnames=("self",)) + def _normalize_desc(self, desc: Descriptor) -> Descriptor: + return ( + 2 + * (desc - self._env.behavior_descriptor_limits[0]) + / ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) + - 1 + ) + + @partial(jax.jit, static_argnames=("self",)) + def _unnormalize_desc(self, desc_normalized: Descriptor) -> Descriptor: + return 0.5 * ( + self._env.behavior_descriptor_limits[1] + - self._env.behavior_descriptor_limits[0] + ) * desc_normalized + 0.5 * ( + self._env.behavior_descriptor_limits[1] + + self._env.behavior_descriptor_limits[0] + ) + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_kernel_bias_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Tuple[Params, Params]: + """ + Compute the equivalent bias of the first layer of the actor network + given a descriptor. + """ + # Extract kernel and bias of the first layer + kernel = actor_dc_params["params"]["Dense_0"]["kernel"] + bias = actor_dc_params["params"]["Dense_0"]["bias"] + + # Compute the equivalent bias + equivalent_kernel = kernel[: -desc.shape[0], :] + equivalent_bias = bias + jnp.dot(desc, kernel[-desc.shape[0] :]) + + return equivalent_kernel, equivalent_bias + + @partial(jax.jit, static_argnames=("self",)) + def _compute_equivalent_params_with_desc( + self, actor_dc_params: Params, desc: Descriptor + ) -> Params: + desc_normalized = self._normalize_desc(desc) + ( + equivalent_kernel, + equivalent_bias, + ) = self._compute_equivalent_kernel_bias_with_desc( + actor_dc_params, desc_normalized + ) + actor_dc_params["params"]["Dense_0"]["kernel"] = equivalent_kernel + actor_dc_params["params"]["Dense_0"]["bias"] = equivalent_bias + return actor_dc_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit( + self, + repertoire: Repertoire, + emitter_state: DCRLEmitterState, + key: RNGKey, + ) -> Tuple[Genotype, ExtraScores, RNGKey]: + """Do a step of PG emission. + + Args: + repertoire: the current repertoire of genotypes + emitter_state: the state of the emitter used + key: a random key + + Returns: + A batch of offspring, the new emitter state and a new key. + """ + # PG emitter + parents_pg, descs_pg, key = repertoire.sample_with_descs( + key, self._config.dcrl_batch_size + ) + genotypes_pg = self.emit_pg(emitter_state, parents_pg, descs_pg) + + # Actor injection emitter + _, descs_ai, key = repertoire.sample_with_descs(key, self._config.ai_batch_size) + descs_ai = descs_ai.reshape( + descs_ai.shape[0], self._env.behavior_descriptor_length + ) + genotypes_ai = self.emit_ai(emitter_state, descs_ai) + + # Concatenate PG and AI genotypes + genotypes = jax.tree_util.tree_map( + lambda x1, x2: jnp.concatenate((x1, x2), axis=0), genotypes_pg, genotypes_ai + ) + + return ( + genotypes, + {"desc_prime": jnp.concatenate([descs_pg, descs_ai], axis=0)}, + key, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_pg( + self, + emitter_state: DCRLEmitterState, + parents: Genotype, + descs: Descriptor, + ) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. 
+ """ + mutation_fn = partial( + self._mutation_function_pg, + emitter_state=emitter_state, + ) + offsprings = jax.vmap(mutation_fn)(parents, descs) + + return offsprings + + @partial( + jax.jit, + static_argnames=("self",), + ) + def emit_ai(self, emitter_state: DCRLEmitterState, descs: Descriptor) -> Genotype: + """Emit the offsprings generated through pg mutation. + + Args: + emitter_state: current emitter state, contains critic and + replay buffer. + parents: the parents selected to be applied gradients in order + to mutate towards better performance. + descs: the descriptors of the parents. + + Returns: + A new set of offsprings. + """ + offsprings = jax.vmap( + self._compute_equivalent_params_with_desc, in_axes=(None, 0) + )(emitter_state.actor_params, descs) + + return offsprings + + @partial(jax.jit, static_argnames=("self",)) + def emit_actor(self, emitter_state: DCRLEmitterState) -> Genotype: + """Emit the greedy actor. + + Simply needs to be retrieved from the emitter state. + + Args: + emitter_state: the current emitter state, it stores the + greedy actor. + + Returns: + The parameters of the actor. + """ + return emitter_state.actor_params + + @partial( + jax.jit, + static_argnames=("self",), + ) + def state_update( + self, + emitter_state: DCRLEmitterState, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, + ) -> DCRLEmitterState: + """This function gives an opportunity to update the emitter state + after the genotypes have been scored. + + Here it is used to fill the Replay Buffer with the transitions + from the scoring of the genotypes, and then the training of the + critic/actor happens. Hence the params of critic/actor are updated, + as well as their optimizer states. + + Args: + emitter_state: current emitter state. + repertoire: the current genotypes repertoire + genotypes: unused here - but compulsory in the signature. + fitnesses: unused here - but compulsory in the signature. + descriptors: unused here - but compulsory in the signature. + extra_scores: extra information coming from the scoring function, + this contains the transitions added to the replay buffer. + + Returns: + New emitter state where the replay buffer has been filled with + the new experienced transitions. 
+ """ + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + episode_length = transitions.obs.shape[1] + + desc_prime = jnp.concatenate( + [ + extra_scores["desc_prime"], + descriptors[ + self._config.dcrl_batch_size + self._config.ai_batch_size : + ], + ], + axis=0, + ) + desc_prime = jnp.repeat(desc_prime[:, jnp.newaxis, :], episode_length, axis=1) + desc = jnp.repeat(descriptors[:, jnp.newaxis, :], episode_length, axis=1) + + desc_prime_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc_prime) + desc_normalized = jax.vmap(jax.vmap(self._normalize_desc))(desc) + transitions = transitions.replace( + desc=desc_normalized, desc_prime=desc_prime_normalized + ) + + # Add transitions to replay buffer + replay_buffer = emitter_state.replay_buffer.insert(transitions) + emitter_state = emitter_state.replace(replay_buffer=replay_buffer) + + # sample transitions from the replay buffer + key, subkey = jax.random.split(emitter_state.key) + transitions, key = replay_buffer.sample( + subkey, self._config.num_critic_training_steps * self._config.batch_size + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_critic_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, transitions.desc_prime) + * transitions.rewards + ) + emitter_state = emitter_state.replace(key=key) + + def scan_train_critics( + carry: DCRLEmitterState, + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Any]: + emitter_state = carry + new_emitter_state = self._train_critics(emitter_state, transitions) + return new_emitter_state, () + + # Train critics and greedy actor + emitter_state, _ = jax.lax.scan( + scan_train_critics, + emitter_state, + transitions, + length=self._config.num_critic_training_steps, + ) + + return emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _train_critics( + self, emitter_state: DCRLEmitterState, transitions: DCRLTransition + ) -> DCRLEmitterState: + """Apply one gradient step to critics and to the greedy actor + (contained in carry in training_state), then soft update target critics + and target actor. + + Those updates are very similar to those made in TD3. + + Args: + emitter_state: actual emitter state + + Returns: + New emitter state where the critic and the greedy actor have been + updated. Optimizer states have also been updated in the process. 
+ """ + # Update Critic + ( + critic_opt_state, + critic_params, + target_critic_params, + key, + ) = self._update_critic( + critic_params=emitter_state.critic_params, + target_critic_params=emitter_state.target_critic_params, + target_actor_params=emitter_state.target_actor_params, + critic_opt_state=emitter_state.critic_opt_state, + transitions=transitions, + key=emitter_state.key, + ) + + # Update greedy actor + ( + actor_opt_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( + emitter_state.steps % self._config.policy_delay == 0, + lambda x: self._update_actor(*x), + lambda _: ( + emitter_state.actor_opt_state, + emitter_state.actor_params, + emitter_state.target_actor_params, + ), + operand=( + emitter_state.actor_params, + emitter_state.actor_opt_state, + emitter_state.target_actor_params, + emitter_state.critic_params, + transitions, + ), + ) + + # Create new training state + new_emitter_state = emitter_state.replace( + critic_params=critic_params, + critic_opt_state=critic_opt_state, + actor_params=actor_params, + actor_opt_state=actor_opt_state, + target_critic_params=target_critic_params, + target_actor_params=target_actor_params, + key=key, + steps=emitter_state.steps + 1, + ) + + return new_emitter_state # type: ignore + + @partial(jax.jit, static_argnames=("self",)) + def _update_critic( + self, + critic_params: Params, + target_critic_params: Params, + target_actor_params: Params, + critic_opt_state: Params, + transitions: DCRLTransition, + key: RNGKey, + ) -> Tuple[Params, Params, Params, RNGKey]: + + # compute loss and gradients + key, subkey = jax.random.split(key) + critic_gradient = jax.grad(self._critic_loss_fn)( + critic_params, + target_actor_params, + target_critic_params, + transitions, + subkey, + ) + critic_updates, critic_opt_state = self._critic_optimizer.update( + critic_gradient, critic_opt_state + ) + + # update critic + critic_params = optax.apply_updates(critic_params, critic_updates) + + # Soft update of target critic network + target_critic_params = jax.tree_util.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_critic_params, + critic_params, + ) + + return critic_opt_state, critic_params, target_critic_params, key + + @partial(jax.jit, static_argnames=("self",)) + def _update_actor( + self, + actor_params: Params, + actor_opt_state: optax.OptState, + target_actor_params: Params, + critic_params: Params, + transitions: DCRLTransition, + ) -> Tuple[optax.OptState, Params, Params]: + + # Update greedy actor + policy_gradient = jax.grad(self._actor_loss_fn)( + actor_params, + critic_params, + transitions, + ) + ( + policy_updates, + actor_opt_state, + ) = self._actor_optimizer.update(policy_gradient, actor_opt_state) + actor_params = optax.apply_updates(actor_params, policy_updates) + + # Soft update of target greedy actor + target_actor_params = jax.tree_util.tree_map( + lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + + self._config.soft_tau_update * x2, + target_actor_params, + actor_params, + ) + + return ( + actor_opt_state, + actor_params, + target_actor_params, + ) + + @partial( + jax.jit, + static_argnames=("self",), + ) + def _mutation_function_pg( + self, + policy_params: Genotype, + descs: Descriptor, + emitter_state: DCRLEmitterState, + ) -> Genotype: + """Apply pg mutation to a policy via multiple steps of gradient descent. + First, update the rewards to be diversity rewards, then apply the gradient + steps. 
+ + Args: + policy_params: a policy, supposed to be a differentiable neural + network. + emitter_state: the current state of the emitter, containing among others, + the replay buffer, the critic. + + Returns: + The updated params of the neural network. + """ + # Get transitions + transitions, key = emitter_state.replay_buffer.sample( + emitter_state.key, + sample_size=self._config.num_pg_training_steps * self._config.batch_size, + ) + descs_prime = jnp.tile( + descs, (self._config.num_pg_training_steps * self._config.batch_size, 1) + ) + descs_prime_normalized = jax.vmap(self._normalize_desc)(descs_prime) + transitions = transitions.replace( + rewards=self._similarity(transitions.desc, descs_prime_normalized) + * transitions.rewards, + desc_prime=descs_prime_normalized, + ) + transitions = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, + ( + self._config.num_pg_training_steps, + self._config.batch_size, + *x.shape[1:], + ), + ), + transitions, + ) + + # Replace key + emitter_state = emitter_state.replace(key=key) + + # Define new policy optimizer state + policy_opt_state = self._policies_optimizer.init(policy_params) + + def scan_train_policy( + carry: Tuple[DCRLEmitterState, Genotype, optax.OptState], + transitions: DCRLTransition, + ) -> Tuple[Tuple[DCRLEmitterState, Genotype, optax.OptState], Any]: + emitter_state, policy_params, policy_opt_state = carry + ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ) = self._train_policy( + emitter_state, + policy_params, + policy_opt_state, + transitions, + ) + return ( + new_emitter_state, + new_policy_params, + new_policy_opt_state, + ), () + + ( + emitter_state, + policy_params, + policy_opt_state, + ), _ = jax.lax.scan( + scan_train_policy, + (emitter_state, policy_params, policy_opt_state), + transitions, + length=self._config.num_pg_training_steps, + ) + + return policy_params + + @partial(jax.jit, static_argnames=("self",)) + def _train_policy( + self, + emitter_state: DCRLEmitterState, + policy_params: Params, + policy_opt_state: optax.OptState, + transitions: DCRLTransition, + ) -> Tuple[DCRLEmitterState, Params, optax.OptState]: + """Apply one gradient step to a policy (called policy_params). + + Args: + emitter_state: current state of the emitter. + policy_params: parameters corresponding to the weights and bias of + the neural network that defines the policy. + + Returns: + The new emitter state and new params of the NN. 
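Each step inside `scan_train_policy` boils down to one gradient computation followed by an optax update, as in `_update_policy`. A self-contained sketch with a placeholder loss and illustrative shapes:

```python
# One gradient step with jax.grad + optax, mirroring the policy update pattern.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    return jnp.mean((batch @ params) ** 2)  # placeholder loss

params = jnp.ones((3,))
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
batch = jax.random.normal(jax.random.PRNGKey(0), (8, 3))

grads = jax.grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```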
+ """ + # update policy + policy_opt_state, policy_params = self._update_policy( + critic_params=emitter_state.critic_params, + policy_opt_state=policy_opt_state, + policy_params=policy_params, + transitions=transitions, + ) + + return emitter_state, policy_params, policy_opt_state + + @partial(jax.jit, static_argnames=("self",)) + def _update_policy( + self, + critic_params: Params, + policy_opt_state: optax.OptState, + policy_params: Params, + transitions: DCRLTransition, + ) -> Tuple[optax.OptState, Params]: + + # compute loss + policy_gradient = jax.grad(self._policy_loss_fn)( + policy_params, + critic_params, + transitions, + ) + # Compute gradient and update policies + ( + policy_updates, + policy_opt_state, + ) = self._policies_optimizer.update(policy_gradient, policy_opt_state) + policy_params = optax.apply_updates(policy_params, policy_updates) + + return policy_opt_state, policy_params diff --git a/qdax/core/emitters/dcrl_me_emitter.py b/qdax/core/emitters/dcrl_me_emitter.py new file mode 100644 index 00000000..89ffebc2 --- /dev/null +++ b/qdax/core/emitters/dcrl_me_emitter.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Callable, Tuple + +import flax.linen as nn + +from qdax.core.emitters.dcrl_emitter import DCRLConfig, DCRLEmitter +from qdax.core.emitters.multi_emitter import MultiEmitter +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey +from qdax.environments.base_wrappers import QDEnv + + +@dataclass +class DCRLMEConfig: + """Configuration for DCRL-MAP-Elites Algorithm""" + + ga_batch_size: int = 128 + dcrl_batch_size: int = 64 + ai_batch_size: int = 64 + lengthscale: float = 0.1 + + # PG emitter + critic_hidden_layer_size: Tuple[int, ...] = (256, 256) + num_critic_training_steps: int = 3000 + num_pg_training_steps: int = 150 + batch_size: int = 100 + replay_buffer_size: int = 1_000_000 + discount: float = 0.99 + reward_scaling: float = 1.0 + critic_learning_rate: float = 3e-4 + actor_learning_rate: float = 3e-4 + policy_learning_rate: float = 1e-3 + noise_clip: float = 0.5 + policy_noise: float = 0.2 + soft_tau_update: float = 0.005 + policy_delay: int = 2 + + +class DCRLMEEmitter(MultiEmitter): + def __init__( + self, + config: DCRLMEConfig, + policy_network: nn.Module, + actor_network: nn.Module, + env: QDEnv, + variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]], + ) -> None: + self._config = config + self._env = env + self._variation_fn = variation_fn + + dcrl_config = DCRLConfig( + dcrl_batch_size=config.dcrl_batch_size, + ai_batch_size=config.ai_batch_size, + lengthscale=config.lengthscale, + critic_hidden_layer_size=config.critic_hidden_layer_size, + num_critic_training_steps=config.num_critic_training_steps, + num_pg_training_steps=config.num_pg_training_steps, + batch_size=config.batch_size, + replay_buffer_size=config.replay_buffer_size, + discount=config.discount, + reward_scaling=config.reward_scaling, + critic_learning_rate=config.critic_learning_rate, + actor_learning_rate=config.actor_learning_rate, + policy_learning_rate=config.policy_learning_rate, + noise_clip=config.noise_clip, + policy_noise=config.policy_noise, + soft_tau_update=config.soft_tau_update, + policy_delay=config.policy_delay, + ) + + # define the quality emitter + dcrl_emitter = DCRLEmitter( + config=dcrl_config, + policy_network=policy_network, + actor_network=actor_network, + env=env, + ) + + # define the GA emitter + ga_emitter = MixingEmitter( + mutation_fn=lambda x, r: (x, r), + 
variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=config.ga_batch_size, + ) + + super().__init__(emitters=(dcrl_emitter, ga_emitter)) diff --git a/qdax/core/emitters/dpg_emitter.py b/qdax/core/emitters/dpg_emitter.py index 8b858db4..ea921237 100644 --- a/qdax/core/emitters/dpg_emitter.py +++ b/qdax/core/emitters/dpg_emitter.py @@ -1,6 +1,7 @@ """ Implements the Diversity PG inspired by QDPG algorithm in jax for brax environments, based on: https://arxiv.org/abs/2006.08505 """ + from dataclasses import dataclass from functools import partial from typing import Any, Callable, Optional, Tuple @@ -17,8 +18,7 @@ QualityPGEmitterState, ) from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.environments.base_wrappers import QDEnv -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -28,6 +28,7 @@ RNGKey, StateDescriptor, ) +from qdax.environments.base_wrappers import QDEnv @dataclass @@ -77,12 +78,18 @@ def __init__( self._score_novelty = score_novelty def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[DiversityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -90,7 +97,14 @@ def init( """ # init elements of diversity emitter state with QualityEmitterState.init() - diversity_emitter_state, random_key = super().init(init_genotypes, random_key) + diversity_emitter_state, random_key = super().init( + random_key, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) # store elements in a dictionary attributes_dict = vars(diversity_emitter_state) @@ -102,6 +116,12 @@ def init( max_size=self._config.archive_max_size, ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + archive = archive.insert(transitions.state_desc) + # init emitter state emitter_state = DiversityPGEmitterState( # retrieve all attributes from the QualityPGEmitterState @@ -161,7 +181,10 @@ def scan_train_critics( return new_emitter_state, () # sample transitions - (transitions, random_key,) = emitter_state.replay_buffer.sample( + ( + transitions, + random_key, + ) = emitter_state.replay_buffer.sample( random_key=emitter_state.random_key, sample_size=self._config.num_critic_training_steps * self._config.batch_size, @@ -230,7 +253,11 @@ def _train_critics( ) # Update greedy policy - (policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + policy_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -329,7 +356,11 @@ def scan_train_policy( transitions, ) - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (transitions), diff --git a/qdax/core/emitters/emitter.py b/qdax/core/emitters/emitter.py index d32ed981..21139356 100644 --- a/qdax/core/emitters/emitter.py +++ b/qdax/core/emitters/emitter.py @@ -6,7 +6,7 @@ from flax.struct import PyTreeNode from 
qdax.core.containers.repertoire import Repertoire -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class EmitterState(PyTreeNode): @@ -30,14 +30,20 @@ class EmitterState(PyTreeNode): class Emitter(ABC): def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """Initialises the state of the emitter. Some emitters do not need a state, in which case, the value None can be outputted. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -51,7 +57,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Function used to emit a population of offspring by any possible mean. New population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire. diff --git a/qdax/core/emitters/mees_emitter.py b/qdax/core/emitters/mees_emitter.py index b5bb1ada..4d51326a 100644 --- a/qdax/core/emitters/mees_emitter.py +++ b/qdax/core/emitters/mees_emitter.py @@ -3,6 +3,7 @@ from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217 """ + from __future__ import annotations from dataclasses import dataclass @@ -19,7 +20,7 @@ get_cells_indices, ) from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class NoveltyArchive(flax.struct.PyTreeNode): @@ -236,26 +237,32 @@ def batch_size(self) -> int: static_argnames=("self",), ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[MEESEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: The initial state of the MEESEmitter, a new random key. 
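The base `Emitter` hunks above change the contract in two ways: `init` now receives the repertoire and the first scoring results, and `emit` returns an extra-scores dictionary alongside the offspring. A toy emitter following the new signatures (the class and all names below are placeholders, not QDax classes):

```python
# Minimal emitter obeying the updated init/emit contract.
from typing import Any, Dict, Optional, Tuple

import jax
import jax.numpy as jnp


class ToyEmitter:
    def init(
        self,
        random_key: jax.Array,
        repertoire: Any,
        genotypes: jnp.ndarray,
        fitnesses: jnp.ndarray,
        descriptors: jnp.ndarray,
        extra_scores: Dict[str, Any],
    ) -> Tuple[Optional[Any], jax.Array]:
        # stateless toy emitter: nothing to store
        return None, random_key

    def emit(
        self,
        repertoire: Any,
        emitter_state: Optional[Any],
        random_key: jax.Array,
    ) -> Tuple[jnp.ndarray, Dict[str, Any], jax.Array]:
        random_key, subkey = jax.random.split(random_key)
        offspring = jax.random.normal(subkey, (4, 2))
        # the extra dict travels to state_update through the main loop
        return offspring, {}, random_key


emitter = ToyEmitter()
state, key = emitter.init(
    jax.random.PRNGKey(0), None, jnp.zeros((4, 2)), jnp.zeros(4), jnp.zeros((4, 2)), {}
)
offspring, extra_info, key = emitter.emit(None, state, key)
```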
""" # Initialisation requires one initial genotype - if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1: - init_genotypes = jax.tree_util.tree_map( + if jax.tree_util.tree_leaves(genotypes)[0].shape[0] > 1: + genotypes = jax.tree_util.tree_map( lambda x: x[0], - init_genotypes, + genotypes, ) # Initialise optimizer - initial_optimizer_state = self._optimizer.init(init_genotypes) + initial_optimizer_state = self._optimizer.init(genotypes) # Create empty Novelty archive if self._config.use_explore: @@ -270,7 +277,7 @@ def init( # Create empty updated genotypes and fitness last_updated_genotypes = jax.tree_util.tree_map( lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]), - init_genotypes, + genotypes, ) last_updated_fitnesses = -jnp.inf * jnp.ones( shape=self._config.last_updated_size @@ -280,7 +287,7 @@ def init( MEESEmitterState( initial_optimizer_state=initial_optimizer_state, optimizer_state=initial_optimizer_state, - offspring=init_genotypes, + offspring=genotypes, generation_count=0, novelty_archive=novelty_archive, last_updated_genotypes=last_updated_genotypes, @@ -300,7 +307,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: MEESEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Return the offspring generated through gradient update. Params: @@ -313,7 +320,7 @@ def emit( a new jax PRNG key """ - return emitter_state.offspring, random_key + return emitter_state.offspring, {}, random_key @partial( jax.jit, @@ -356,7 +363,7 @@ def _sample( genotypes_empty = fitnesses < min_fitness p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), genotypes, ) @@ -423,7 +430,7 @@ def _sample_explore( repertoire_empty = novelties < min_novelty p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty) random_key, subkey = jax.random.split(random_key) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jax.random.choice(subkey, x, shape=(1,), p=p), repertoire.genotypes, ) @@ -480,7 +487,7 @@ def _es_emitter( # Sampling non-mirror noise else: sample_number = total_sample_number - sample_noise = jax.tree_map( + sample_noise = jax.tree_util.tree_map( lambda x: jax.random.normal( key=subkey, shape=jnp.repeat(x, sample_number, axis=0).shape, @@ -490,11 +497,11 @@ def _es_emitter( gradient_noise = sample_noise # Applying noise - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda x: jnp.repeat(x, total_sample_number, axis=0), parent, ) - samples = jax.tree_map( + samples = jax.tree_util.tree_map( lambda mean, noise: mean + self._config.sample_sigma * noise, samples, sample_noise, @@ -520,7 +527,7 @@ def _es_emitter( if self._config.sample_mirror: ranks = jnp.reshape(ranks, (sample_number, 2)) ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks) - ranks = jax.tree_map( + ranks = jax.tree_util.tree_map( lambda x: jnp.reshape( jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape ), @@ -528,16 +535,16 @@ def _es_emitter( ) # Computing the gradients - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda noise, rank: jnp.multiply(noise, rank), gradient_noise, ranks, ) - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (sample_number, -1)), gradient, ) - gradient = jax.tree_map( + gradient = 
jax.tree_util.tree_map( lambda g, p: jnp.reshape( -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma), p.shape, @@ -547,7 +554,7 @@ def _es_emitter( ) # Adding regularisation - gradient = jax.tree_map( + gradient = jax.tree_util.tree_map( lambda g, p: g + self._config.l2_coefficient * p, gradient, parent, @@ -620,7 +627,7 @@ def _buffers_update( last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set( fitnesses[0] ) - last_updated_genotypes = jax.tree_map( + last_updated_genotypes = jax.tree_util.tree_map( lambda last_gen, gen: last_gen.at[ jnp.expand_dims(last_updated_position, axis=0) ].set(gen), diff --git a/qdax/core/emitters/multi_emitter.py b/qdax/core/emitters/multi_emitter.py index 2da46639..17cb8ace 100644 --- a/qdax/core/emitters/multi_emitter.py +++ b/qdax/core/emitters/multi_emitter.py @@ -8,7 +8,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey class MultiEmitterState(EmitterState): @@ -56,13 +56,19 @@ def get_indexes_separation_batches( return tuple(indexes_separation_batches) def init( - self, init_genotypes: Optional[Genotype], random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[Optional[EmitterState], RNGKey]: """ Initialize the state of the emitter. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: @@ -76,7 +82,14 @@ def init( # init all emitter states - gather them emitter_states = [] for emitter, subkey_emitter in zip(self.emitters, subkeys): - emitter_state, _ = emitter.init(init_genotypes, subkey_emitter) + emitter_state, _ = emitter.init( + subkey_emitter, + repertoire, + genotypes, + fitnesses, + descriptors, + extra_scores, + ) emitter_states.append(emitter_state) return MultiEmitterState(tuple(emitter_states)), random_key @@ -87,7 +100,7 @@ def emit( repertoire: Optional[Repertoire], emitter_state: Optional[MultiEmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Emit new population. Use all the sub emitters to emit subpopulation and gather them. 
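The updated `MultiEmitter.emit` (next hunk) gathers sub-emitter outputs by concatenating offspring pytrees along the batch axis and merging the extra-info dictionaries. The same gathering on toy data (names and shapes are illustrative):

```python
# Concatenating offspring pytrees and merging extra-info dicts.
import jax
import jax.numpy as jnp

offspring_a = {"weights": jnp.ones((3, 2))}
offspring_b = {"weights": jnp.zeros((5, 2))}
extra_a = {"desc_prime": jnp.zeros((3, 2))}
extra_b = {}

all_offspring = [offspring_a, offspring_b]
all_extra_info = {**extra_a, **extra_b}

offspring = jax.tree_util.tree_map(
    lambda *x: jnp.concatenate(x, axis=0), *all_offspring
)
print(offspring["weights"].shape)  # (8, 2)
```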
@@ -108,21 +121,25 @@ def emit( # emit from all emitters and gather offsprings all_offsprings = [] + all_extra_info: ExtraScores = {} for emitter, sub_emitter_state, subkey_emitter in zip( self.emitters, emitter_state.emitter_states, subkeys, ): - genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter) + genotype, extra_info, _ = emitter.emit( + repertoire, sub_emitter_state, subkey_emitter + ) batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0] assert batch_size == emitter.batch_size all_offsprings.append(genotype) + all_extra_info = {**all_extra_info, **extra_info} # concatenate offsprings together offsprings = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=0), *all_offsprings ) - return offsprings, random_key + return offsprings, all_extra_info, random_key @partial(jax.jit, static_argnames=("self",)) def state_update( diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py index f39b8060..bda2daca 100644 --- a/qdax/core/emitters/mutation_operators.py +++ b/qdax/core/emitters/mutation_operators.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp -from qdax.types import Genotype, RNGKey +from qdax.custom_types import Genotype, RNGKey def _polynomial_mutation( diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py index 7336750d..580bd151 100644 --- a/qdax/core/emitters/omg_mega_emitter.py +++ b/qdax/core/emitters/omg_mega_emitter.py @@ -6,7 +6,14 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Centroid, Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import ( + Centroid, + Descriptor, + ExtraScores, + Fitness, + Genotype, + RNGKey, +) class OMGMEGAEmitterState(EmitterState): @@ -84,20 +91,26 @@ def __init__( self._num_descriptors = num_descriptors def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: MapElitesRepertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[OMGMEGAEmitterState, RNGKey]: """Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals. Args: - init_genotypes: The genotypes of the initial population. + genotypes: The genotypes of the initial population. random_key: a random key to handle stochastic operations. Returns: The initial emitter state. 
""" # retrieve one genotype from the population - first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes) # add a dimension of size num descriptors + 1 gradient_genotype = jax.tree_util.tree_map( @@ -112,6 +125,18 @@ def init( genotype=gradient_genotype, centroids=self._centroids ) + # get gradients out of the extra scores + assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key" + gradients = extra_scores["gradients"] + + # update the gradients repertoire + gradients_repertoire = gradients_repertoire.add( + gradients, + descriptors, + fitnesses, + extra_scores, + ) + return ( OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire), random_key, @@ -126,7 +151,7 @@ def emit( repertoire: MapElitesRepertoire, emitter_state: OMGMEGAEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates. @@ -190,7 +215,7 @@ def emit( lambda x, y: x + y, genotypes, update_grad ) - return new_genotypes, random_key + return new_genotypes, {}, random_key @partial( jax.jit, diff --git a/qdax/core/emitters/pbt_me_emitter.py b/qdax/core/emitters/pbt_me_emitter.py index 3fdb4418..55bded4e 100644 --- a/qdax/core/emitters/pbt_me_emitter.py +++ b/qdax/core/emitters/pbt_me_emitter.py @@ -12,8 +12,8 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey class PBTEmitterState(EmitterState): @@ -91,12 +91,18 @@ def __init__( ) def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[PBTEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. + genotypes: The initial population. random_key: A random key. Returns: @@ -145,13 +151,13 @@ def init( # Create emitter state # keep only pg population size training states if more are provided - init_genotypes = jax.tree_util.tree_map( - lambda x: x[: self._config.pg_population_size_per_device], init_genotypes + genotypes = jax.tree_util.tree_map( + lambda x: x[: self._config.pg_population_size_per_device], genotypes ) emitter_state = PBTEmitterState( replay_buffers=replay_buffers, env_states=env_states, - training_states=init_genotypes, + training_states=genotypes, random_key=subkey2, ) @@ -166,7 +172,7 @@ def emit( repertoire: Repertoire, emitter_state: PBTEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a single PGA-ME iteration: train critics and greedy policy, make mutations (evo and pg), score solution, fill replay buffer and insert back in the MAP-Elites grid. 
@@ -199,7 +205,7 @@ def emit( else: genotypes = x_mutation_pg - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/emitters/pbt_variation_operators.py b/qdax/core/emitters/pbt_variation_operators.py index bd76ecd1..c8537003 100644 --- a/qdax/core/emitters/pbt_variation_operators.py +++ b/qdax/core/emitters/pbt_variation_operators.py @@ -3,7 +3,7 @@ from qdax.baselines.sac_pbt import PBTSacTrainingState from qdax.baselines.td3_pbt import PBTTD3TrainingState from qdax.core.emitters.mutation_operators import isoline_variation -from qdax.types import RNGKey +from qdax.custom_types import RNGKey def sac_pbt_variation_fn( @@ -94,7 +94,10 @@ def td3_pbt_variation_fn( training_state1.critic_params, training_state2.critic_params, ) - (policy_params, critic_params,), random_key = isoline_variation( + ( + policy_params, + critic_params, + ), random_key = isoline_variation( x1=(policy_params1, critic_params1), x2=(policy_params2, critic_params2), random_key=random_key, diff --git a/qdax/core/emitters/pga_me_emitter.py b/qdax/core/emitters/pga_me_emitter.py index e93eb696..a4f8b33f 100644 --- a/qdax/core/emitters/pga_me_emitter.py +++ b/qdax/core/emitters/pga_me_emitter.py @@ -6,8 +6,8 @@ from qdax.core.emitters.multi_emitter import MultiEmitter from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Params, RNGKey @dataclass diff --git a/qdax/core/emitters/qdpg_emitter.py b/qdax/core/emitters/qdpg_emitter.py index eefd1566..b9de6090 100644 --- a/qdax/core/emitters/qdpg_emitter.py +++ b/qdax/core/emitters/qdpg_emitter.py @@ -5,6 +5,7 @@ it has been updated to work better with Jax in term of time cost. Those changes have been made in accordance with the authors of this algorithm. """ + import functools from dataclasses import dataclass from typing import Callable @@ -17,8 +18,8 @@ from qdax.core.emitters.mutation_operators import isoline_variation from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.custom_types import Reward, StateDescriptor from qdax.environments.base_wrappers import QDEnv -from qdax.types import Reward, StateDescriptor @dataclass diff --git a/qdax/core/emitters/qpg_emitter.py b/qdax/core/emitters/qpg_emitter.py index c07e3b18..63373494 100644 --- a/qdax/core/emitters/qpg_emitter.py +++ b/qdax/core/emitters/qpg_emitter.py @@ -17,8 +17,8 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn from qdax.core.neuroevolution.networks.networks import QModule +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey from qdax.environments.base_wrappers import QDEnv -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey @dataclass @@ -119,12 +119,18 @@ def use_all_data(self) -> bool: return True def init( - self, init_genotypes: Genotype, random_key: RNGKey + self, + random_key: RNGKey, + repertoire: Repertoire, + genotypes: Genotype, + fitnesses: Fitness, + descriptors: Descriptor, + extra_scores: ExtraScores, ) -> Tuple[QualityPGEmitterState, RNGKey]: """Initializes the emitter state. Args: - init_genotypes: The initial population. 
+ genotypes: The initial population. random_key: A random key. Returns: @@ -144,8 +150,8 @@ def init( ) target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params) - actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) - target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes) + actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) + target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes) # Prepare init optimizer states critic_optimizer_state = self._critic_optimizer.init(critic_params) @@ -162,6 +168,13 @@ def init( buffer_size=self._config.replay_buffer_size, transition=dummy_transition ) + # get the transitions out of the dictionary + assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key" + transitions = extra_scores["transitions"] + + # add transitions in the replay buffer + replay_buffer = replay_buffer.insert(transitions) + # Initial training state random_key, subkey = jax.random.split(random_key) emitter_state = QualityPGEmitterState( @@ -171,9 +184,9 @@ def init( actor_opt_state=actor_optimizer_state, target_critic_params=target_critic_params, target_actor_params=target_actor_params, + replay_buffer=replay_buffer, random_key=subkey, steps=jnp.array(0), - replay_buffer=replay_buffer, ) return emitter_state, random_key @@ -187,7 +200,7 @@ def emit( repertoire: Repertoire, emitter_state: QualityPGEmitterState, random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """Do a step of PG emission. Args: @@ -223,7 +236,7 @@ def emit( offspring_actor, ) - return genotypes, random_key + return genotypes, {}, random_key @partial( jax.jit, @@ -366,7 +379,11 @@ def _train_critics( ) # Update greedy actor - (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond( + ( + actor_optimizer_state, + actor_params, + target_actor_params, + ) = jax.lax.cond( emitter_state.steps % self._config.policy_delay == 0, lambda x: self._update_actor(*x), lambda _: ( @@ -411,7 +428,7 @@ def _update_critic( # compute loss and gradients random_key, subkey = jax.random.split(random_key) - critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)( + critic_gradient = jax.grad(self._critic_loss_fn)( critic_params, target_actor_params, target_critic_params, @@ -426,7 +443,7 @@ def _update_critic( critic_params = optax.apply_updates(critic_params, critic_updates) # Soft update of target critic network - target_critic_params = jax.tree_map( + target_critic_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_critic_params, @@ -446,7 +463,7 @@ def _update_actor( ) -> Tuple[optax.OptState, Params, Params]: # Update greedy actor - policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( actor_params, critic_params, transitions, @@ -458,7 +475,7 @@ def _update_actor( actor_params = optax.apply_updates(actor_params, policy_updates) # Soft update of target greedy actor - target_actor_params = jax.tree_map( + target_actor_params = jax.tree_util.tree_map( lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1 + self._config.soft_tau_update * x2, target_actor_params, @@ -514,7 +531,11 @@ def scan_train_policy( new_policy_optimizer_state, ), () - (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan( + ( + emitter_state, + policy_params, + policy_optimizer_state, + ), _ = jax.lax.scan( 
scan_train_policy, (emitter_state, policy_params, policy_optimizer_state), (), @@ -574,7 +595,7 @@ def _update_policy( ) -> Tuple[optax.OptState, Params]: # compute loss - _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)( + policy_gradient = jax.grad(self._policy_loss_fn)( policy_params, critic_params, transitions, diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py index 8b877792..1d949b2d 100644 --- a/qdax/core/emitters/standard_emitters.py +++ b/qdax/core/emitters/standard_emitters.py @@ -6,7 +6,7 @@ from qdax.core.containers.repertoire import Repertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import Genotype, RNGKey +from qdax.custom_types import ExtraScores, Genotype, RNGKey class MixingEmitter(Emitter): @@ -31,7 +31,7 @@ def emit( repertoire: Repertoire, emitter_state: Optional[EmitterState], random_key: RNGKey, - ) -> Tuple[Genotype, RNGKey]: + ) -> Tuple[Genotype, ExtraScores, RNGKey]: """ Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled in the repertoire, @@ -75,7 +75,7 @@ def emit( x_mutation, ) - return genotypes, random_key + return genotypes, {}, random_key @property def batch_size(self) -> int: diff --git a/qdax/core/map_elites.py b/qdax/core/map_elites.py index c71b0013..d0b075a9 100644 --- a/qdax/core/map_elites.py +++ b/qdax/core/map_elites.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites algorithm.""" + from __future__ import annotations from functools import partial @@ -8,7 +9,7 @@ from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -23,7 +24,7 @@ class MAPElites: """Core elements of the MAP-Elites algorithm. Note: Although very similar to the GeneticAlgorithm, we decided to keep the - MAPElites class independent of the GeneticAlgorithm class at the moment to keep + MAPElites class independant of the GeneticAlgorithm class at the moment to keep elements explicit. Args: @@ -52,7 +53,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]: @@ -62,9 +63,9 @@ def init( such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) - centroids: tessellation centroids of shape (batch_size, num_descriptors) + centroids: tesselation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
Returns: @@ -73,12 +74,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MapElitesRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -87,14 +88,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, @@ -129,9 +125,10 @@ def update( a new jax PRNG key """ # generate offsprings with the emitter - genotypes, random_key = self._emitter.emit( + genotypes, extra_info, random_key = self._emitter.emit( repertoire, emitter_state, random_key ) + # scores the offsprings fitnesses, descriptors, extra_scores, random_key = self._scoring_function( genotypes, random_key @@ -147,7 +144,7 @@ def update( genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, - extra_scores=extra_scores, + extra_scores={**extra_scores, **extra_info}, ) # update the metrics @@ -173,7 +170,12 @@ def scan_update( The updated repertoire and emitter state, with a new random key and metrics. """ repertoire, emitter_state, random_key = carry - (repertoire, emitter_state, metrics, random_key,) = self.update( + ( + repertoire, + emitter_state, + metrics, + random_key, + ) = self.update( repertoire, emitter_state, random_key, diff --git a/qdax/core/mels.py b/qdax/core/mels.py index 6c06b785..8b0e7511 100644 --- a/qdax/core/mels.py +++ b/qdax/core/mels.py @@ -1,4 +1,5 @@ """Core components of the MAP-Elites Low-Spread algorithm.""" + from __future__ import annotations from functools import partial @@ -9,7 +10,7 @@ from qdax.core.containers.mels_repertoire import MELSRepertoire from qdax.core.emitters.emitter import Emitter, EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import ( +from qdax.custom_types import ( Centroid, Descriptor, ExtraScores, @@ -55,7 +56,7 @@ def __init__( @partial(jax.jit, static_argnames=("self",)) def init( self, - init_genotypes: Genotype, + genotypes: Genotype, centroids: Centroid, random_key: RNGKey, ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]: @@ -64,7 +65,7 @@ def init( be computed with any method such as CVT or Euclidean mapping. Args: - init_genotypes: initial genotypes, pytree in which leaves + genotypes: initial genotypes, pytree in which leaves have shape (batch_size, num_features) centroids: tessellation centroids of shape (batch_size, num_descriptors) random_key: a random key used for stochastic operations. 
@@ -75,12 +76,12 @@ def init( """ # score initial genotypes fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MELSRepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -89,14 +90,19 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key + random_key=random_key, + repertoire=repertoire, + genotypes=genotypes, + fitnesses=fitnesses, + descriptors=descriptors, + extra_scores=extra_scores, ) # update emitter state emitter_state = self._emitter.state_update( emitter_state=emitter_state, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/mome.py b/qdax/core/mome.py index 2a004f59..c239bd1f 100644 --- a/qdax/core/mome.py +++ b/qdax/core/mome.py @@ -9,7 +9,7 @@ from qdax.core.containers.mome_repertoire import MOMERepertoire from qdax.core.emitters.emitter import EmitterState from qdax.core.map_elites import MAPElites -from qdax.types import Centroid, RNGKey +from qdax.custom_types import Centroid, RNGKey class MOME(MAPElites): @@ -23,7 +23,7 @@ class MOME(MAPElites): @partial(jax.jit, static_argnames=("self", "pareto_front_max_length")) def init( self, - init_genotypes: jnp.ndarray, + genotypes: jnp.ndarray, centroids: Centroid, pareto_front_max_length: int, random_key: RNGKey, @@ -33,7 +33,7 @@ def init( CVT or Euclidean mapping. Args: - init_genotypes: genotypes of the initial population. + genotypes: genotypes of the initial population. centroids: centroids of the repertoire. pareto_front_max_length: maximum size of the pareto front. This is necessary to respect jax.jit fixed shape size constraint. 
@@ -45,12 +45,12 @@ def init( # first score fitnesses, descriptors, extra_scores, random_key = self._scoring_function( - init_genotypes, random_key + genotypes, random_key ) # init the repertoire repertoire = MOMERepertoire.init( - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, centroids=centroids, @@ -60,14 +60,9 @@ def init( # get initial state of the emitter emitter_state, random_key = self._emitter.init( - init_genotypes=init_genotypes, random_key=random_key - ) - - # update emitter state - emitter_state = self._emitter.state_update( - emitter_state=emitter_state, + random_key=random_key, repertoire=repertoire, - genotypes=init_genotypes, + genotypes=genotypes, fitnesses=fitnesses, descriptors=descriptors, extra_scores=extra_scores, diff --git a/qdax/core/neuroevolution/buffers/buffer.py b/qdax/core/neuroevolution/buffers/buffer.py index 42ed7552..81f1e896 100644 --- a/qdax/core/neuroevolution/buffers/buffer.py +++ b/qdax/core/neuroevolution/buffers/buffer.py @@ -7,7 +7,15 @@ import jax import jax.numpy as jnp -from qdax.types import Action, Done, Observation, Reward, RNGKey, StateDescriptor +from qdax.custom_types import ( + Action, + Descriptor, + Done, + Observation, + Reward, + RNGKey, + StateDescriptor, +) class Transition(flax.struct.PyTreeNode): @@ -262,6 +270,155 @@ def init_dummy( # type: ignore return dummy_transition +class DCRLTransition(QDTransition): + """Stores data corresponding to a transition collected by a QD algorithm.""" + + desc: Descriptor + desc_prime: Descriptor + + @property + def descriptor_dim(self) -> int: + """ + Returns: + the dimension of the descriptors. + """ + return self.state_desc.shape[-1] # type: ignore + + @property + def flatten_dim(self) -> int: + """ + Returns: + the dimension of the transition once flattened. + """ + flatten_dim = ( + 2 * self.observation_dim + + self.action_dim + + 3 + + 2 * self.state_descriptor_dim + + 2 * self.descriptor_dim + ) + return flatten_dim + + def flatten(self) -> jnp.ndarray: + """ + Returns: + a jnp.ndarray that corresponds to the flattened transition. + """ + flatten_transition = jnp.concatenate( + [ + self.obs, + self.next_obs, + jnp.expand_dims(self.rewards, axis=-1), + jnp.expand_dims(self.dones, axis=-1), + jnp.expand_dims(self.truncations, axis=-1), + self.actions, + self.state_desc, + self.next_state_desc, + self.desc, + self.desc_prime, + ], + axis=-1, + ) + return flatten_transition + + @classmethod + def from_flatten( + cls, + flattened_transition: jnp.ndarray, + transition: DCRLTransition, + ) -> DCRLTransition: + """ + Creates a transition from a flattened transition in a jnp.ndarray. 
+ Args: + flattened_transition: flattened transition in a jnp.ndarray of shape + (batch_size, flatten_dim) + transition: a transition object (might be a dummy one) to + get the dimensions right + Returns: + a Transition object + """ + obs_dim = transition.observation_dim + action_dim = transition.action_dim + state_desc_dim = transition.state_descriptor_dim + desc_dim = transition.descriptor_dim + + obs = flattened_transition[:, :obs_dim] + next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)] + rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)]) + dones = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)] + ) + truncations = jnp.ravel( + flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)] + ) + actions = flattened_transition[ + :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim) + ] + state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim) : ( + 2 * obs_dim + 3 + action_dim + state_desc_dim + ), + ] + next_state_desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + ), + ] + desc = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim + ), + ] + desc_prime = flattened_transition[ + :, + (2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + desc_dim) : ( + 2 * obs_dim + 3 + action_dim + 2 * state_desc_dim + 2 * desc_dim + ), + ] + return cls( + obs=obs, + next_obs=next_obs, + rewards=rewards, + dones=dones, + truncations=truncations, + actions=actions, + state_desc=state_desc, + next_state_desc=next_state_desc, + desc=desc, + desc_prime=desc_prime, + ) + + @classmethod + def init_dummy( # type: ignore + cls, observation_dim: int, action_dim: int, descriptor_dim: int + ) -> DCRLTransition: + """ + Initialize a dummy transition that then can be passed to constructors to get + all shapes right. + Args: + observation_dim: observation dimension + action_dim: action dimension + Returns: + a dummy transition + """ + dummy_transition = DCRLTransition( + obs=jnp.zeros(shape=(1, observation_dim)), + next_obs=jnp.zeros(shape=(1, observation_dim)), + rewards=jnp.zeros(shape=(1,)), + dones=jnp.zeros(shape=(1,)), + truncations=jnp.zeros(shape=(1,)), + actions=jnp.zeros(shape=(1, action_dim)), + state_desc=jnp.zeros(shape=(1, descriptor_dim)), + next_state_desc=jnp.zeros(shape=(1, descriptor_dim)), + desc=jnp.zeros(shape=(1, descriptor_dim)), + desc_prime=jnp.zeros(shape=(1, descriptor_dim)), + ) + return dummy_transition + + class ReplayBuffer(flax.struct.PyTreeNode): """ A replay buffer where transitions are flattened before being stored. 
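A quick round-trip through the new `DCRLTransition` helpers defined above; the dimensions are arbitrary and only serve to check the flattened layout.

```python
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition

obs_dim, action_dim, desc_dim = 8, 2, 2

dummy = DCRLTransition.init_dummy(
    observation_dim=obs_dim, action_dim=action_dim, descriptor_dim=desc_dim
)

flat = dummy.flatten()
assert flat.shape == (1, dummy.flatten_dim)  # 2*8 + 2 + 3 + 2*2 + 2*2 = 29 here

rebuilt = DCRLTransition.from_flatten(flat, dummy)
assert rebuilt.obs.shape == (1, obs_dim)
assert rebuilt.desc_prime.shape == (1, desc_dim)
```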
diff --git a/qdax/core/neuroevolution/buffers/trajectory_buffer.py b/qdax/core/neuroevolution/buffers/trajectory_buffer.py index 2cc4ab69..93e1b2f9 100644 --- a/qdax/core/neuroevolution/buffers/trajectory_buffer.py +++ b/qdax/core/neuroevolution/buffers/trajectory_buffer.py @@ -8,7 +8,7 @@ from flax import struct from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Reward, RNGKey +from qdax.custom_types import Reward, RNGKey class TrajectoryBuffer(struct.PyTreeNode): diff --git a/qdax/core/neuroevolution/losses/dads_loss.py b/qdax/core/neuroevolution/losses/dads_loss.py index b42ca416..60edfee1 100644 --- a/qdax/core/neuroevolution/losses/dads_loss.py +++ b/qdax/core/neuroevolution/losses/dads_loss.py @@ -6,7 +6,14 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, Skill, StateDescriptor +from qdax.custom_types import ( + Action, + Observation, + Params, + RNGKey, + Skill, + StateDescriptor, +) def make_dads_loss_fn( diff --git a/qdax/core/neuroevolution/losses/diayn_loss.py b/qdax/core/neuroevolution/losses/diayn_loss.py index 8bca3b4b..e25a73bd 100644 --- a/qdax/core/neuroevolution/losses/diayn_loss.py +++ b/qdax/core/neuroevolution/losses/diayn_loss.py @@ -7,7 +7,7 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.losses.sac_loss import make_sac_loss_fn -from qdax.types import Action, Observation, Params, RNGKey, StateDescriptor +from qdax.custom_types import Action, Observation, Params, RNGKey, StateDescriptor def make_diayn_loss_fn( diff --git a/qdax/core/neuroevolution/losses/sac_loss.py b/qdax/core/neuroevolution/losses/sac_loss.py index b3656b18..d7289292 100644 --- a/qdax/core/neuroevolution/losses/sac_loss.py +++ b/qdax/core/neuroevolution/losses/sac_loss.py @@ -6,7 +6,7 @@ from brax.training.distribution import ParametricDistribution from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.custom_types import Action, Observation, Params, RNGKey def make_sac_loss_fn( diff --git a/qdax/core/neuroevolution/losses/td3_loss.py b/qdax/core/neuroevolution/losses/td3_loss.py index 7f34a036..964c2c4f 100644 --- a/qdax/core/neuroevolution/losses/td3_loss.py +++ b/qdax/core/neuroevolution/losses/td3_loss.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Action, Observation, Params, RNGKey +from qdax.custom_types import Action, Descriptor, Observation, Params, RNGKey def make_td3_loss_fn( @@ -94,6 +94,110 @@ def _critic_loss_fn( return _policy_loss_fn, _critic_loss_fn +def make_td3_loss_dc_fn( + policy_fn: Callable[[Params, Observation], jnp.ndarray], + actor_fn: Callable[[Params, Observation, Descriptor], jnp.ndarray], + critic_fn: Callable[[Params, Observation, Action, Descriptor], jnp.ndarray], + reward_scaling: float, + discount: float, + noise_clip: float, + policy_noise: float, +) -> Tuple[ + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Transition], jnp.ndarray], + Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray], +]: + """Creates the loss functions for TD3. + Args: + policy_fn: forward pass through the neural network defining the policy. + actor_fn: forward pass through the neural network defining the + descriptor-conditioned policy. 
+ critic_fn: forward pass through the neural network defining the + descriptor-conditioned critic. + reward_scaling: value to multiply the reward given by the environment. + discount: discount factor. + noise_clip: value that clips the noise to avoid extreme values. + policy_noise: noise applied to smooth the bootstrapping. + Returns: + Return the loss functions used to train the policy and the critic in TD3. + """ + + @jax.jit + def _policy_loss_fn( + policy_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Policy loss function for TD3 agent""" + action = policy_fn(policy_params, transitions.obs) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _actor_loss_fn( + actor_params: Params, + critic_params: Params, + transitions: Transition, + ) -> jnp.ndarray: + """Descriptor-conditioned policy loss function for TD3 agent""" + action = actor_fn(actor_params, transitions.obs, transitions.desc_prime) + q_value = critic_fn( + critic_params, transitions.obs, action, transitions.desc_prime + ) + q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) + policy_loss = -jnp.mean(q1_action) + return policy_loss + + @jax.jit + def _critic_loss_fn( + critic_params: Params, + target_actor_params: Params, + target_critic_params: Params, + transitions: Transition, + random_key: RNGKey, + ) -> jnp.ndarray: + """Descriptor-conditioned critic loss function for TD3 agent""" + noise = ( + jax.random.normal(random_key, shape=transitions.actions.shape) + * policy_noise + ).clip(-noise_clip, noise_clip) + + next_action = ( + actor_fn(target_actor_params, transitions.next_obs, transitions.desc_prime) + + noise + ).clip(-1.0, 1.0) + next_q = critic_fn( + target_critic_params, + transitions.next_obs, + next_action, + transitions.desc_prime, + ) + next_v = jnp.min(next_q, axis=-1) + target_q = jax.lax.stop_gradient( + transitions.rewards * reward_scaling + + (1.0 - transitions.dones) * discount * next_v + ) + q_old_action = critic_fn( + critic_params, transitions.obs, transitions.actions, transitions.desc_prime + ) + q_error = q_old_action - jnp.expand_dims(target_q, -1) + + # Better bootstrapping for truncated episodes. 
+ q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1) + + # compute the loss + q_losses = jnp.mean(jnp.square(q_error), axis=-2) + q_loss = jnp.sum(q_losses, axis=-1) + + return q_loss + + return _policy_loss_fn, _actor_loss_fn, _critic_loss_fn + + def td3_policy_loss_fn( policy_params: Params, critic_params: Params, @@ -115,9 +219,7 @@ def td3_policy_loss_fn( """ action = policy_fn(policy_params, transitions.obs) - q_value = critic_fn( - critic_params, obs=transitions.obs, actions=action # type: ignore - ) + q_value = critic_fn(critic_params, transitions.obs, action) # type: ignore q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1) policy_loss = -jnp.mean(q1_action) return policy_loss diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 984d1aeb..f269a22b 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -9,7 +9,7 @@ from flax.struct import PyTreeNode from qdax.core.neuroevolution.buffers.buffer import Transition -from qdax.types import Genotype, Params, RNGKey +from qdax.custom_types import Descriptor, Genotype, Params, RNGKey class TrainingState(PyTreeNode): @@ -67,6 +67,60 @@ def _scan_play_step_fn( return state, transitions +@partial(jax.jit, static_argnames=("play_step_actor_dc_fn", "episode_length")) +def generate_unroll_actor_dc( + init_state: EnvState, + actor_dc_params: Params, + desc: Descriptor, + random_key: RNGKey, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[ + EnvState, + Descriptor, + Params, + RNGKey, + Transition, + ], + ], +) -> Tuple[EnvState, Transition]: + """Generates an episode according to the agent's policy and descriptor, + returns the final state of the episode and the transitions of the episode. + + Args: + init_state: first state of the rollout. + policy_dc_params: descriptor-conditioned policy params. + desc: descriptor the policy attempts to achieve. + random_key: random key for stochasiticity handling. + episode_length: length of the rollout. + play_step_fn: function describing how a step need to be taken. + + Returns: + A new state, the experienced transition. 
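The three losses returned by `make_td3_loss_dc_fn` can be wired to the Flax modules introduced later in this diff (`MLP`, `MLPDC`, `QModuleDC`). The sketch below is hedged: network sizes are illustrative and a dummy transition batch stands in for real replay data.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_dc_fn
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC, QModuleDC

obs_dim, action_dim, desc_dim = 8, 2, 2

policy = MLP(layer_sizes=(64, 64, action_dim), final_activation=nn.tanh)
actor_dc = MLPDC(layer_sizes=(64, 64, action_dim), final_activation=nn.tanh)
critic_dc = QModuleDC(hidden_layer_sizes=(64, 64), n_critics=2)

key = jax.random.PRNGKey(0)
fake_obs = jnp.zeros((1, obs_dim))
fake_action = jnp.zeros((1, action_dim))
fake_desc = jnp.zeros((1, desc_dim))

policy_params = policy.init(key, fake_obs)
actor_params = actor_dc.init(key, fake_obs, fake_desc)
critic_params = critic_dc.init(key, fake_obs, fake_action, fake_desc)

policy_loss_fn, actor_loss_fn, critic_loss_fn = make_td3_loss_dc_fn(
    policy_fn=policy.apply,
    actor_fn=actor_dc.apply,
    critic_fn=critic_dc.apply,
    reward_scaling=1.0,
    discount=0.99,
    noise_clip=0.5,
    policy_noise=0.2,
)

# dummy batch of descriptor-conditioned transitions, standing in for replay data
transitions = DCRLTransition.init_dummy(
    observation_dim=obs_dim, action_dim=action_dim, descriptor_dim=desc_dim
)

policy_grad = jax.grad(policy_loss_fn)(policy_params, critic_params, transitions)
actor_grad = jax.grad(actor_loss_fn)(actor_params, critic_params, transitions)
critic_grad = jax.grad(critic_loss_fn)(
    critic_params, actor_params, critic_params, transitions, jax.random.PRNGKey(1)
)
```

In the DCRL-ME emitter these gradients would typically feed optax updates for the descriptor-conditioned critic and actor; here they only demonstrate the plumbing.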
+ """ + + def _scan_play_step_fn( + carry: Tuple[EnvState, Params, Descriptor, RNGKey], unused_arg: Any + ) -> Tuple[Tuple[EnvState, Params, Descriptor, RNGKey], Transition]: + ( + env_state, + actor_dc_params, + desc, + random_key, + transitions, + ) = play_step_actor_dc_fn(*carry) + return (env_state, actor_dc_params, desc, random_key), transitions + + (state, _, _, _), transitions = jax.lax.scan( + _scan_play_step_fn, + (init_state, actor_dc_params, desc, random_key), + (), + length=episode_length, + ) + return state, transitions + + @jax.jit def get_first_episode(transition: Transition) -> Transition: """Extracts the first episode from a batch of transitions, returns the batch of @@ -80,7 +134,7 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray: # the double transpose trick is here to allow easy broadcasting return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T - return jax.tree_map(mask_episodes, transition) # type: ignore + return jax.tree_util.tree_map(mask_episodes, transition) # type: ignore def init_population_controllers( diff --git a/qdax/core/neuroevolution/networks/dads_networks.py b/qdax/core/neuroevolution/networks/dads_networks.py index beb4b77a..863bdab5 100644 --- a/qdax/core/neuroevolution/networks/dads_networks.py +++ b/qdax/core/neuroevolution/networks/dads_networks.py @@ -1,128 +1,129 @@ from typing import Optional, Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp -from haiku.initializers import Initializer, VarianceScaling - -from qdax.types import Action, Observation, Skill, StateDescriptor - - -class GaussianMixture(hk.Module): - """Module that outputs a Gaussian Mixture Distribution.""" - - def __init__( - self, - num_dimensions: int, - num_components: int, - reinterpreted_batch_ndims: Optional[int] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - name: str = "GaussianMixture", - ): - """Module that outputs a Gaussian Mixture Distribution - with identity covariance matrix.""" - - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") - self._num_dimensions = num_dimensions - self._num_components = num_components - self._reinterpreted_batch_ndims = reinterpreted_batch_ndims - self._identity_covariance = identity_covariance - self.initializer = initializer - logits_size = self._num_components - - self.logit_layer = hk.Linear(logits_size, w_init=self.initializer) - - # Create two layers that outputs a location and a scale, respectively, for - # each dimension and each component. - self.loc_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) - if not self._identity_covariance: - self.scale_layer = hk.Linear( - self._num_dimensions * self._num_components, w_init=self.initializer - ) +from jax.nn import initializers + +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation, Skill, StateDescriptor + +class GaussianMixture(nn.Module): + num_dimensions: int + num_components: int + reinterpreted_batch_ndims: Optional[int] = None + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None + + @nn.compact def __call__(self, inputs: jnp.ndarray) -> tfp.distributions.Distribution: - # Compute logits, locs, and scales if necessary. 
- logits = self.logit_layer(inputs) - locs = self.loc_layer(inputs) + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - shape = [-1, self._num_components, self._num_dimensions] # [B, D, C] + logits = nn.Dense(self.num_components, kernel_init=init)(inputs) + locs = nn.Dense(self.num_dimensions * self.num_components, kernel_init=init)( + inputs + ) - # Reshape the mixture's location and scale parameters appropriately. + shape = [-1, self.num_components, self.num_dimensions] # [B, D, C] locs = locs.reshape(shape) - if not self._identity_covariance: - scales = self.scale_layer(inputs) + if not self.identity_covariance: + scales = nn.Dense( + self.num_dimensions * self.num_components, kernel_init=init + )(inputs) scales = scales.reshape(shape) else: scales = jnp.ones_like(locs) - # Create the mixture distribution components = tfp.distributions.MultivariateNormalDiag( loc=locs, scale_diag=scales ) mixture = tfp.distributions.Categorical(logits=logits) - distribution = tfp.distributions.MixtureSameFamily( + return tfp.distributions.MixtureSameFamily( mixture_distribution=mixture, components_distribution=components ) - return distribution - -class DynamicsNetwork(hk.Module): - """Dynamics network (used in DADS).""" +class DynamicsNetwork(nn.Module): + hidden_layer_sizes: Tuple[int, ...] + output_size: int + omit_input_dynamics_dim: int = 2 + identity_covariance: bool = True + initializer: Optional[initializers.Initializer] = None - def __init__( - self, - hidden_layer_sizes: tuple, - output_size: int, - omit_input_dynamics_dim: int = 2, - name: Optional[str] = None, - identity_covariance: bool = True, - initializer: Optional[Initializer] = None, - ): - super().__init__(name=name) - if initializer is None: - initializer = VarianceScaling(1.0, "fan_in", "uniform") + @nn.compact + def __call__( + self, obs: StateDescriptor, skill: Skill, target: StateDescriptor + ) -> jnp.ndarray: + if self.initializer is None: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + else: + init = self.initializer - self.distribution = GaussianMixture( - output_size, + distribution = GaussianMixture( + self.output_size, num_components=4, reinterpreted_batch_ndims=None, - identity_covariance=identity_covariance, - initializer=initializer, - ) - self.network = hk.Sequential( - [ - hk.nets.MLP( - list(hidden_layer_sizes), - w_init=initializer, - activation=jax.nn.relu, - activate_final=True, - ), - ] + identity_covariance=self.identity_covariance, + initializer=init, ) - self._omit_input_dynamics_dim = omit_input_dynamics_dim - def __call__( - self, obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - """Normalizes the observation, predicts a distribution probability conditioned - on (obs,skill) and returns the log_prob of the target. - """ - - obs = obs[:, self._omit_input_dynamics_dim :] + obs = obs[:, self.omit_input_dynamics_dim :] obs = jnp.concatenate((obs, skill), axis=1) - out = self.network(obs) - dist = self.distribution(out) + + x = MLP( + layer_sizes=self.hidden_layer_sizes, + kernel_init=init, + activation=nn.relu, + final_activation=nn.relu, + )(obs) + + dist = distribution(x) return dist.log_prob(target) +class Actor(nn.Module): + action_size: int + hidden_layer_sizes: Tuple[int, ...] 
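With the Haiku modules replaced by Flax ones, `DynamicsNetwork` is now created and parameterised like any other `nn.Module`, and its call returns the log-probability of the target state descriptor. A small sketch with illustrative dimensions:

```python
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.dads_networks import DynamicsNetwork

descriptor_size, num_skills = 4, 5

dynamics = DynamicsNetwork(
    hidden_layer_sizes=(256, 256),
    output_size=descriptor_size,
    omit_input_dynamics_dim=2,
    identity_covariance=True,
)

key = jax.random.PRNGKey(0)
state_desc = jnp.zeros((1, descriptor_size))
skill = jnp.zeros((1, num_skills))
target = jnp.zeros((1, descriptor_size))

params = dynamics.init(key, state_desc, skill, target)
log_prob = dynamics.apply(params, state_desc, skill, target)  # shape (1,)
```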
+ + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + + return MLP( + layer_sizes=self.hidden_layer_sizes + (2 * self.action_size,), + kernel_init=init, + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_sizes: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + init = initializers.variance_scaling(1.0, "fan_in", "uniform") + input_ = jnp.concatenate([obs, action], axis=-1) + + value_1 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + kernel_init=init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + def make_dads_networks( action_size: int, descriptor_size: int, @@ -130,78 +131,16 @@ def make_dads_networks( policy_hidden_layer_size: Tuple[int, ...] = (256, 256), omit_input_dynamics_dim: int = 2, identity_covariance: bool = True, - dynamics_initializer: Optional[Initializer] = None, -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: - """Creates networks used in DADS. - - Args: - action_size: the size of the environment's action space - descriptor_size: the size of the environment's descriptor space (i.e. the - dimension of the dynamics network's input) - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). - omit_input_dynamics_dim: how many descriptors we omit when creating the input - of the dynamics networks. Defaults to 2. - identity_covariance: whether to fix the covariance matrix of the Gaussian models - to identity. Defaults to True. - dynamics_initializer: the initializer of the dynamics layers. Defaults to None. 
- - Returns: - the policy network - the critic network - the dynamics network - """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _dynamics_fn( - obs: StateDescriptor, skill: Skill, target: StateDescriptor - ) -> jnp.ndarray: - dynamics_network = DynamicsNetwork( - critic_hidden_layer_size, - descriptor_size, - omit_input_dynamics_dim=omit_input_dynamics_dim, - identity_covariance=identity_covariance, - initializer=dynamics_initializer, - ) - return dynamics_network(obs, skill, target) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - dynamics = hk.without_apply_rng(hk.transform(_dynamics_fn)) + dynamics_initializer: Optional[initializers.Initializer] = None, +) -> Tuple[nn.Module, nn.Module, nn.Module]: + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + dynamics = DynamicsNetwork( + critic_hidden_layer_size, + descriptor_size, + omit_input_dynamics_dim=omit_input_dynamics_dim, + identity_covariance=identity_covariance, + initializer=dynamics_initializer, + ) return policy, critic, dynamics diff --git a/qdax/core/neuroevolution/networks/diayn_networks.py b/qdax/core/neuroevolution/networks/diayn_networks.py index c656cace..e292e131 100644 --- a/qdax/core/neuroevolution/networks/diayn_networks.py +++ b/qdax/core/neuroevolution/networks/diayn_networks.py @@ -1,10 +1,60 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) + + +class Discriminator(nn.Module): + num_skills: int + hidden_layer_size: Tuple[int, ...] 
+ + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (self.num_skills,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + activation=nn.relu, + )(obs) def make_diayn_networks( @@ -12,71 +62,22 @@ def make_diayn_networks( num_skills: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module, nn.Module]: """Creates networks used in DIAYN. Args: action_size: the size of the environment's action space num_skills: the number of skills set - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. Returns: the policy network the critic network the discriminator network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - def _discriminator_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [num_skills], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) - discriminator = hk.without_apply_rng(hk.transform(_discriminator_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) + discriminator = Discriminator(num_skills, critic_hidden_layer_size) return policy, critic, discriminator diff --git a/qdax/core/neuroevolution/networks/networks.py b/qdax/core/neuroevolution/networks/networks.py index b2b176ef..365c8d56 100644 --- a/qdax/core/neuroevolution/networks/networks.py +++ b/qdax/core/neuroevolution/networks/networks.py @@ -5,31 +5,51 @@ import flax.linen as nn import jax import jax.numpy as jnp -from brax.training import networks -class QModule(nn.Module): - """Q Module.""" +class MLP(nn.Module): + """MLP module.""" - hidden_layer_sizes: Tuple[int, ...] - n_critics: int = 2 + layer_sizes: Tuple[int, ...] 
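`make_diayn_networks` now returns plain Flax modules, so parameters come from `init` and forward passes from `apply`. A shape-checking sketch with illustrative dimensions (in DIAYN the one-hot skill is typically concatenated to the observation before it reaches the policy and critic, but the modules themselves accept any feature size):

```python
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks

action_size, num_skills, obs_dim = 2, 5, 10
policy, critic, discriminator = make_diayn_networks(
    action_size=action_size, num_skills=num_skills
)

key = jax.random.PRNGKey(0)
obs = jnp.zeros((1, obs_dim))
action = jnp.zeros((1, action_size))

policy_params = policy.init(key, obs)
critic_params = critic.init(key, obs, action)
discriminator_params = discriminator.init(key, obs)

# 2 * action_size outputs (mean and log-std in SAC-style agents)
assert policy.apply(policy_params, obs).shape == (1, 2 * action_size)
# twin Q-values
assert critic.apply(critic_params, obs, action).shape == (1, 2)
# one logit per skill
assert discriminator.apply(discriminator_params, obs).shape == (1, num_skills)
```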
+ activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform() + final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None + bias: bool = True + kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: - hidden = jnp.concatenate([obs, actions], axis=-1) - res = [] - for _ in range(self.n_critics): - q = networks.MLP( - layer_sizes=self.hidden_layer_sizes + (1,), - activation=nn.relu, - kernel_init=jax.nn.initializers.lecun_uniform(), - )(hidden) - res.append(q) - return jnp.concatenate(res, axis=-1) + def __call__(self, obs: jnp.ndarray) -> jnp.ndarray: + hidden = obs + for i, hidden_size in enumerate(self.layer_sizes): + if i != len(self.layer_sizes) - 1: + hidden = nn.Dense( + hidden_size, + kernel_init=self.kernel_init, + use_bias=self.bias, + )(hidden) + hidden = self.activation(hidden) # type: ignore -class MLP(nn.Module): - """MLP module.""" + else: + if self.kernel_init_final is not None: + kernel_init = self.kernel_init_final + else: + kernel_init = self.kernel_init + + hidden = nn.Dense( + hidden_size, + kernel_init=kernel_init, + use_bias=self.bias, + )(hidden) + + if self.final_activation is not None: + hidden = self.final_activation(hidden) + + return hidden + + +class MLPDC(nn.Module): + """Descriptor-conditioned MLP module.""" layer_sizes: Tuple[int, ...] activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -39,15 +59,13 @@ class MLP(nn.Module): kernel_init_final: Optional[Callable[..., Any]] = None @nn.compact - def __call__(self, data: jnp.ndarray) -> jnp.ndarray: - hidden = data + def __call__(self, obs: jnp.ndarray, desc: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, desc], axis=-1) for i, hidden_size in enumerate(self.layer_sizes): if i != len(self.layer_sizes) - 1: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", with this version of flax, changing the name - # changes the initialization kernel_init=self.kernel_init, use_bias=self.bias, )(hidden) @@ -61,7 +79,6 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = nn.Dense( hidden_size, - # name=f"hidden_{i}", kernel_init=kernel_init, use_bias=self.bias, )(hidden) @@ -70,3 +87,45 @@ def __call__(self, data: jnp.ndarray) -> jnp.ndarray: hidden = self.final_activation(hidden) return hidden + + +class QModule(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] + n_critics: int = 2 + + @nn.compact + def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLP( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden) + res.append(q) + return jnp.concatenate(res, axis=-1) + + +class QModuleDC(nn.Module): + """Q Module.""" + + hidden_layer_sizes: Tuple[int, ...] 
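The in-house `MLP` replaces `brax.training.networks.MLP`: the last entry of `layer_sizes` is the output width, and `final_activation`, when given, is applied to that last layer only. Minimal illustration:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.networks import MLP

mlp = MLP(layer_sizes=(64, 64, 3), final_activation=nn.tanh)
params = mlp.init(jax.random.PRNGKey(0), jnp.zeros((1, 10)))
out = mlp.apply(params, jnp.zeros((1, 10)))  # shape (1, 3), squashed to [-1, 1] by tanh
```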
+ n_critics: int = 2 + + @nn.compact + def __call__( + self, obs: jnp.ndarray, actions: jnp.ndarray, desc: jnp.ndarray + ) -> jnp.ndarray: + hidden = jnp.concatenate([obs, actions], axis=-1) + res = [] + for _ in range(self.n_critics): + q = MLPDC( + layer_sizes=self.hidden_layer_sizes + (1,), + activation=nn.relu, + kernel_init=jax.nn.initializers.lecun_uniform(), + )(hidden, desc) + res.append(q) + return jnp.concatenate(res, axis=-1) diff --git a/qdax/core/neuroevolution/networks/sac_networks.py b/qdax/core/neuroevolution/networks/sac_networks.py index dcadfaa2..a236afd4 100644 --- a/qdax/core/neuroevolution/networks/sac_networks.py +++ b/qdax/core/neuroevolution/networks/sac_networks.py @@ -1,66 +1,65 @@ from typing import Tuple -import haiku as hk -import jax +import flax.linen as nn import jax.numpy as jnp -from qdax.types import Action, Observation +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Action, Observation + + +class Actor(nn.Module): + action_size: int + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation) -> jnp.ndarray: + return MLP( + layer_sizes=self.hidden_layer_size + (2 * self.action_size,), + kernel_init=nn.initializers.variance_scaling(1.0, "fan_in", "uniform"), + )(obs) + + +class Critic(nn.Module): + hidden_layer_size: Tuple[int, ...] + + @nn.compact + def __call__(self, obs: Observation, action: Action) -> jnp.ndarray: + input_ = jnp.concatenate([obs, action], axis=-1) + + kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "uniform") + + value_1 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + value_2 = MLP( + layer_sizes=self.hidden_layer_size + (1,), + kernel_init=kernel_init, + activation=nn.relu, + )(input_) + + return jnp.concatenate([value_1, value_2], axis=-1) def make_sac_networks( action_size: int, critic_hidden_layer_size: Tuple[int, ...] = (256, 256), policy_hidden_layer_size: Tuple[int, ...] = (256, 256), -) -> Tuple[hk.Transformed, hk.Transformed]: +) -> Tuple[nn.Module, nn.Module]: """Creates networks used in SAC. Args: action_size: the size of the environment's action space - hidden_layer_sizes: the number of neurons for hidden layers. - Defaults to (256, 256). + critic_hidden_layer_size: the number of neurons for critic hidden layers. + policy_hidden_layer_size: the number of neurons for policy hidden layers. 
Returns: the policy network the critic network """ - - def _actor_fn(obs: Observation) -> jnp.ndarray: - network = hk.Sequential( - [ - hk.nets.MLP( - list(policy_hidden_layer_size) + [2 * action_size], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - return network(obs) - - def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray: - network1 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - network2 = hk.Sequential( - [ - hk.nets.MLP( - list(critic_hidden_layer_size) + [1], - w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"), - activation=jax.nn.relu, - ), - ] - ) - input_ = jnp.concatenate([obs, action], axis=-1) - value1 = network1(input_) - value2 = network2(input_) - return jnp.concatenate([value1, value2], axis=-1) - - policy = hk.without_apply_rng(hk.transform(_actor_fn)) - critic = hk.without_apply_rng(hk.transform(_critic_fn)) + policy = Actor(action_size, policy_hidden_layer_size) + critic = Critic(critic_hidden_layer_size) return policy, critic diff --git a/qdax/core/neuroevolution/networks/seq2seq_networks.py b/qdax/core/neuroevolution/networks/seq2seq_networks.py index ea7618ba..3cb52a3e 100644 --- a/qdax/core/neuroevolution/networks/seq2seq_networks.py +++ b/qdax/core/neuroevolution/networks/seq2seq_networks.py @@ -7,7 +7,6 @@ Licensed under the Apache License, Version 2.0 (the "License") """ - import functools from typing import Any, Tuple diff --git a/qdax/core/neuroevolution/normalization_utils.py b/qdax/core/neuroevolution/normalization_utils.py index 63820921..0c98b29d 100644 --- a/qdax/core/neuroevolution/normalization_utils.py +++ b/qdax/core/neuroevolution/normalization_utils.py @@ -1,11 +1,10 @@ """Utilities functions to perform normalization (generally on observations in RL).""" - from typing import NamedTuple import jax.numpy as jnp -from qdax.types import Observation +from qdax.custom_types import Observation class RunningMeanStdState(NamedTuple): diff --git a/qdax/core/neuroevolution/sac_td3_utils.py b/qdax/core/neuroevolution/sac_td3_utils.py index 1c54511a..32bbe7a4 100644 --- a/qdax/core/neuroevolution/sac_td3_utils.py +++ b/qdax/core/neuroevolution/sac_td3_utils.py @@ -5,6 +5,7 @@ We are currently thinking about elegant ways to unify both in order to avoid code repetition. """ + # TODO: Uniformize with the functions in mdp_utils from functools import partial from typing import Any, Callable, Tuple @@ -14,7 +15,7 @@ from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.mdp_utils import TrainingState -from qdax.types import Metrics +from qdax.custom_types import Metrics @partial( @@ -75,7 +76,8 @@ def generate_unroll( ], ], ) -> Tuple[EnvState, TrainingState, Transition]: - """Generates an episode according to the agent's policy, returns the final state of the + """ + Generates an episode according to the agent's policy, returns the final state of the episode and the transitions of the episode. 
""" diff --git a/qdax/types.py b/qdax/custom_types.py similarity index 100% rename from qdax/types.py rename to qdax/custom_types.py diff --git a/qdax/environments/base_wrappers.py b/qdax/environments/base_wrappers.py index 6f317e7f..3f709fa7 100644 --- a/qdax/environments/base_wrappers.py +++ b/qdax/environments/base_wrappers.py @@ -1,6 +1,7 @@ from abc import abstractmethod -from typing import Any, List, Tuple +from typing import Any, Tuple +import jax from brax.v1 import jumpy as jp from brax.v1.envs import Env, State @@ -22,7 +23,7 @@ def state_descriptor_name(self) -> str: @property @abstractmethod - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -32,7 +33,7 @@ def behavior_descriptor_length(self) -> int: @property @abstractmethod - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: pass @property @@ -71,7 +72,7 @@ def state_descriptor_name(self) -> str: return self.env.state_descriptor_name @property - def state_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def state_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.state_descriptor_limits @property @@ -79,7 +80,7 @@ def behavior_descriptor_length(self) -> int: return self.env.behavior_descriptor_length @property - def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]: + def behavior_descriptor_limits(self) -> Tuple[jax.Array, jax.Array]: return self.env.behavior_descriptor_limits @property diff --git a/qdax/environments/bd_extractors.py b/qdax/environments/bd_extractors.py index af1d51ba..918fbbfb 100644 --- a/qdax/environments/bd_extractors.py +++ b/qdax/environments/bd_extractors.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.types import Descriptor, Params +from qdax.custom_types import Descriptor, Params def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor: diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py index ec32e7a2..c784b045 100644 --- a/qdax/environments/exploration_wrappers.py +++ b/qdax/environments/exploration_wrappers.py @@ -436,10 +436,8 @@ def step(self, state: State, action: jp.ndarray) -> State: # this line avoid this by increasing the threshold done = jp.where( state.qp.pos[0, 2] < 0.2, - x=jp.array(1, dtype=jp.float32), - y=jp.array(0, dtype=jp.float32), - ) - done = jp.where( - state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done + jp.array(1, dtype=jp.float32), + jp.array(0, dtype=jp.float32), ) + done = jp.where(state.qp.pos[0, 2] > 5.0, jp.array(1, dtype=jp.float32), done) return state.replace(obs=new_obs, reward=new_reward, done=done) # type: ignore diff --git a/qdax/environments/locomotion_wrappers.py b/qdax/environments/locomotion_wrappers.py index a727479e..982f5b69 100644 --- a/qdax/environments/locomotion_wrappers.py +++ b/qdax/environments/locomotion_wrappers.py @@ -260,7 +260,7 @@ def name(self) -> str: def reset(self, rng: jp.ndarray) -> State: state = self.env.reset(rng) state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state @@ -268,7 +268,7 @@ def step(self, state: State, action: jp.ndarray) -> State: state = self.env.step(state, action) # get 
xy position of the center of gravity state.info["state_descriptor"] = jnp.clip( - state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval + state.qp.pos[self._cog_idx][:2], min=self._minval, max=self._maxval ) return state diff --git a/qdax/environments/pointmaze.py b/qdax/environments/pointmaze.py index b5f86ef5..78f7c575 100644 --- a/qdax/environments/pointmaze.py +++ b/qdax/environments/pointmaze.py @@ -150,8 +150,8 @@ def step(self, state: State, action: jp.ndarray) -> State: done = jp.where( jp.array(in_zone), - x=jp.array(1.0), - y=jp.array(0.0), + jp.array(1.0), + jp.array(0.0), ) new_obs = jp.array([x_pos, y_pos]) @@ -199,8 +199,8 @@ def _collision_lower_wall( y_axis_down_contact_condition_1 & y_axis_down_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset), - y=y_pos, + jp.array(self.lower_wall_height_offset), + y_pos, ) # From up - boolean style @@ -217,8 +217,8 @@ def _collision_lower_wall( & y_axis_up_contact_condition_2 & y_axis_up_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.lower_wall_height_offset + self.wallheight), - y=new_y_pos, + jp.array(self.lower_wall_height_offset + self.wallheight), + new_y_pos, ) return new_y_pos @@ -250,8 +250,8 @@ def _collision_upper_wall( y_axis_up_contact_condition_1 & y_axis_up_contact_condition_2 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset + self.wallheight), - y=y_pos, + jp.array(self.upper_wall_height_offset + self.wallheight), + y_pos, ) # From down - boolean style @@ -264,8 +264,8 @@ def _collision_upper_wall( & y_axis_down_contact_condition_2 & y_axis_down_contact_condition_3 & x_axis_contact_condition, - x=jp.array(self.upper_wall_height_offset), - y=new_y_pos, + jp.array(self.upper_wall_height_offset), + new_y_pos, ) return new_y_pos diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 720f662a..babedaed 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -1,9 +1,9 @@ -from typing import Dict +from typing import Dict, Optional import flax.struct import jax from brax.v1 import jumpy as jp -from brax.v1.envs import State, Wrapper +from brax.v1.envs import Env, State, Wrapper class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -69,3 +69,55 @@ def step(self, state: State, action: jp.ndarray) -> State: ) nstate.info[self.STATE_INFO_KEY] = eval_metrics return nstate + + +class ClipRewardWrapper(Wrapper): + """Wraps gym environments to clip the reward to be greater than 0. + + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__( + self, + env: Env, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + ) -> None: + super().__init__(env) + self._clip_min = clip_min + self._clip_max = clip_max + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace( + reward=jp.clip(state.reward, a_min=self._clip_min, a_max=self._clip_max) + ) + + +class OffsetRewardWrapper(Wrapper): + """Wraps gym environments to offset the reward to be greater than 0. 
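The recurring mechanical change in the environment files above is dropping the `x=`/`y=` keywords of `where` and the `a_min`/`a_max` keywords of `clip` in favour of positional (respectively `min=`/`max=`) arguments, matching the pinned JAX version; the exact deprecation timeline is JAX-version dependent. For reference:

```python
import jax.numpy as jnp

cond = jnp.array([True, False])
selected = jnp.where(cond, 1.0, 0.0)                           # positional, no x=/y=
clipped = jnp.clip(jnp.array([-2.0, 2.0]), min=-1.0, max=1.0)  # min/max, not a_min/a_max
```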
+ + Utilisation is simple: create an environment with Brax, pass + it to the wrapper with the name of the environment, and it will + work like before and will simply clip the reward to be greater than 0. + """ + + def __init__(self, env: Env, offset: float = 0.0) -> None: + super().__init__(env) + self._offset = offset + + def reset(self, rng: jp.ndarray) -> State: + state = self.env.reset(rng) + return state.replace(reward=state.reward + self._offset) + + def step(self, state: State, action: jp.ndarray) -> State: + state = self.env.step(state, action) + return state.replace(reward=state.reward + self._offset) diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py index 7122ed63..27782cf3 100644 --- a/qdax/tasks/arm.py +++ b/qdax/tasks/arm.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py index 931ee9d3..07d37d59 100644 --- a/qdax/tasks/brax_envs.py +++ b/qdax/tasks/brax_envs.py @@ -10,9 +10,9 @@ import qdax.environments from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.core.neuroevolution.mdp_utils import generate_unroll +from qdax.core.neuroevolution.mdp_utils import generate_unroll, generate_unroll_actor_dc from qdax.core.neuroevolution.networks.networks import MLP -from qdax.types import ( +from qdax.custom_types import ( Descriptor, EnvState, ExtraScores, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_brax( Returns: default_play_step_fn: A function that plays a step of the environment. """ + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: EnvState, @@ -160,6 +161,81 @@ def scoring_function_brax_envs( ) +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + init_states: EnvState, + episode_length: int, + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel in + deterministic or pseudo-deterministic environments. + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, as a policy is not necessarily + evaluated with the same environment everytime, this won't be determinist. + When the init states are different, this is not purely stochastic. + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. 
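A hedged usage sketch for the two new reward wrappers, e.g. to make per-step rewards non-negative before accumulating QD-score-style metrics; the environment name and the offset are illustrative.

```python
import jax
import jax.numpy as jnp

import qdax.environments
from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper

env = qdax.environments.create("pointmaze", episode_length=100)
env = OffsetRewardWrapper(env, offset=5.0)   # shift rewards up by a constant
env = ClipRewardWrapper(env, clip_min=0.0)   # then clip anything still negative

state = env.reset(jax.random.PRNGKey(0))
state = env.step(state, jnp.zeros((env.action_size,)))
assert state.reward >= 0.0
```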
+ + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from evaluation + random_key: The updated random key. + """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll_actor_dc, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, actors_dc_params, descs) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + @partial( jax.jit, static_argnames=( @@ -225,6 +301,83 @@ def reset_based_scoring_function_brax_envs( return fitnesses, descriptors, extra_scores, random_key +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_actor_dc_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_actor_dc_function_brax_envs( + actors_dc_params: Genotype, + descs: Descriptor, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_actor_dc_fn: Callable[ + [EnvState, Descriptor, Params, RNGKey], + Tuple[EnvState, Descriptor, Params, RNGKey, QDTransition], + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policy_dc_params in parallel. + The play_reset_fn function allows for a more general scoring_function that can be + called with different batch-size and not only with a batch-size of the same + dimension as init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in "scoring_function", generate + a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policy_dc_params: The parameters of closed-loop + descriptor-conditioned policy to evaluate. + descriptors: The descriptors the + descriptor-conditioned policy attempts to achieve. + random_key: A jax random key + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment + and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. 
+ """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split( + subkey, jax.tree_util.tree_leaves(actors_dc_params)[0].shape[0] + ) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + ( + fitnesses, + descriptors, + extra_scores, + random_key, + ) = scoring_actor_dc_function_brax_envs( + actors_dc_params=actors_dc_params, + descs=descs, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_actor_dc_fn=play_step_actor_dc_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + def create_brax_scoring_fn( env: brax.envs.Env, policy_network: nn.Module, diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py index f4936574..340581ab 100644 --- a/qdax/tasks/hypervolume_functions.py +++ b/qdax/tasks/hypervolume_functions.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def square(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/tasks/jumanji_envs.py b/qdax/tasks/jumanji_envs.py index 14455d66..5f861f0e 100644 --- a/qdax/tasks/jumanji_envs.py +++ b/qdax/tasks/jumanji_envs.py @@ -7,7 +7,7 @@ import jumanji from qdax.core.neuroevolution.buffers.buffer import QDTransition, Transition -from qdax.types import ( +from qdax.custom_types import ( Descriptor, ExtraScores, Fitness, @@ -41,6 +41,7 @@ def make_policy_network_play_step_fn_jumanji( Returns: default_play_step_fn: A function that plays a step of the environment. """ + # Define the function to play a step with the policy in the environment def default_play_step_fn( env_state: JumanjiState, @@ -67,7 +68,7 @@ def default_play_step_fn( obs=timestep.observation, next_obs=next_timestep.observation, rewards=next_timestep.reward, - dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)), + dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)), actions=action, truncations=jnp.array(0), state_desc=state_desc, diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py index 5784f596..59108ae5 100644 --- a/qdax/tasks/qd_suite/archimedean_spiral.py +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -4,8 +4,8 @@ import jax.lax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class ParameterizationGenotype(Enum): diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py index d5be0688..830ad523 100644 --- a/qdax/tasks/qd_suite/deceptive_evolvability.py +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype def multivariate_normal( diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py index 6f1af76f..0d79317f 100644 --- a/qdax/tasks/qd_suite/qd_suite_task.py +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -4,7 +4,7 @@ import jax from jax import numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, 
ExtraScores, Fitness, Genotype, RNGKey class QDSuiteTask(abc.ABC): diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py index 547bee8d..601aa6ad 100644 --- a/qdax/tasks/qd_suite/ssf.py +++ b/qdax/tasks/qd_suite/ssf.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp +from qdax.custom_types import Descriptor, Fitness, Genotype from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask -from qdax.types import Descriptor, Fitness, Genotype class SsfV0(QDSuiteTask): diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py index 53d5b492..82b2f875 100644 --- a/qdax/tasks/standard_functions.py +++ b/qdax/tasks/standard_functions.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: diff --git a/qdax/utils/metrics.py b/qdax/utils/metrics.py index 2b8355af..509c6d91 100644 --- a/qdax/utils/metrics.py +++ b/qdax/utils/metrics.py @@ -12,7 +12,7 @@ from qdax.core.containers.ga_repertoire import GARepertoire from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire from qdax.core.containers.mome_repertoire import MOMERepertoire -from qdax.types import Metrics +from qdax.custom_types import Metrics from qdax.utils.pareto_front import compute_hypervolume diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 692a2bde..54fad3e6 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from qdax.types import Mask, ParetoFront +from qdax.custom_types import Mask, ParetoFront def compute_pareto_dominance( @@ -24,7 +24,10 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. """ diff = jnp.subtract(batch_of_criteria, criteria_point) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray: @@ -67,7 +70,10 @@ def compute_masked_pareto_dominance( diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( neutral_values, diff ) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_masked_pareto_front( diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 9b107c7e..7f0f086d 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -544,7 +544,7 @@ def _get_projection_in_1d( for all index i: x[i] < bases_tuple[i]. The vector and tuple of bases must have the same length. - For example if x=jnp.array([3, 1, 2]) and the bases are (5, 7, 3). + For example if jnp.array([3, 1, 2]) and the bases are (5, 7, 3). then the projection is 3*(7*3) + 1*(3) + 2 = 47. Args: @@ -574,7 +574,7 @@ def _get_projection_in_2d( """Projects an integer vector into a pair of integers, (given tuple of bases to consider for conversion). - For example if x=jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7). + For example if jnp.array([3, 1, 2, 5]) and the bases are (5, 2, 3, 7). 
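The sharpened dominance test above now flags weakly dominated points (equal on some objectives, strictly worse on at least one), which the previous `all(diff > 0)` rule missed. A tiny check, assuming the existing `(criteria_point, batch_of_criteria)` argument order:

```python
import jax.numpy as jnp

from qdax.utils.pareto_front import compute_pareto_dominance

point = jnp.array([1.0, 1.0])
batch = jnp.array([[1.0, 2.0], [0.5, 0.5]])

# [1, 1] ties with [1, 2] on the first objective and is strictly worse on the
# second, so it is now reported as dominated (the old rule returned False here).
assert bool(compute_pareto_dominance(point, batch))
```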
then the projection is obtained by: - projecting in 1D the point jnp.array([3, 2]) with the bases (5, 3) - projecting in 1D the point jnp.array([1, 5]) with the bases (2, 7) diff --git a/qdax/utils/sampling.py b/qdax/utils/sampling.py index bf5c1ae4..be1d336d 100644 --- a/qdax/utils/sampling.py +++ b/qdax/utils/sampling.py @@ -1,11 +1,12 @@ """Core components of the MAP-Elites-sampling algorithm.""" + from functools import partial from typing import Callable, Tuple import jax import jax.numpy as jnp -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey @jax.jit diff --git a/qdax/utils/train_seq2seq.py b/qdax/utils/train_seq2seq.py index acb14a9b..bd9570a9 100644 --- a/qdax/utils/train_seq2seq.py +++ b/qdax/utils/train_seq2seq.py @@ -16,8 +16,8 @@ from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq +from qdax.custom_types import Params, RNGKey from qdax.environments.bd_extractors import AuroraExtraInfoNormalization -from qdax.types import Params, RNGKey Array = Any PRNGKey = Any @@ -132,7 +132,7 @@ def lstm_ae_train( std_obs = jnp.nanstd(repertoire.observations, axis=(0, 1)) # the std where they were NaNs was set to zero. But here we divide by the # std, so we replace the zeros by inf here. - std_obs = jnp.where(std_obs == 0, x=jnp.inf, y=std_obs) + std_obs = jnp.where(std_obs == 0, jnp.inf, std_obs) # TODO: maybe we could just compute this data on the valid dataset diff --git a/requirements.txt b/requirements.txt index 978a1c87..f6dea29a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,17 @@ absl-py==1.0.0 -brax==0.9.2 -chex==0.1.83 -dm-haiku==0.0.10 -flax==0.7.4 +brax==0.10.4 +chex==0.1.86 +flax==0.8.5 gym==0.26.2 ipython -jax==0.4.16 -jaxlib==0.4.16 +jax==0.4.28 +jaxlib==0.4.28 jumanji==0.3.1 jupyter -numpy==1.24.1 -optax==0.1.7 -protobuf==3.19.4 -scikit-learn==1.0.2 -scipy==1.8.0 -seaborn==0.11.2 -tensorflow-probability==0.19.0 -typing-extensions==4.3.0 +numpy==1.26.4 +optax==0.1.9 +protobuf==3.19.5 +scikit-learn==1.5.1 +scipy==1.10.1 +tensorflow-probability==0.24.0 +typing-extensions==4.12.2 diff --git a/setup.py b/setup.py index 0065bf18..cd7d2b13 100644 --- a/setup.py +++ b/setup.py @@ -22,19 +22,23 @@ long_description_content_type="text/markdown", install_requires=[ "absl-py>=1.0.0", - "jax>=0.4.16", - "jaxlib>=0.4.16", # necessary to build the doc atm - "jinja2<3.1.0", + "brax>=0.10.4", + "chex>=0.1.86", + "flax>=0.8.5", + "gym>=0.26.2", + "jax>=0.4.28", + "jaxlib>=0.4.28", # necessary to build the doc atm + "jinja2>=3.1.4", "jumanji>=0.3.1", - "flax>=0.7.4", - "chex>=0.1.83", - "brax>=0.9.2", - "gym>=0.23.1", - "numpy>=1.22.3", - "optax>=0.1.7", - "scikit-learn>=1.0.2", - "scipy>=1.8.0", + "numpy>=1.26.4", + "optax>=0.1.9", + "scikit-learn>=1.5.1", + "scipy>=1.10.1", + "tensorflow-probability>=0.24.0", ], + extras_require={ + "cuda12": ["jax[cuda12]>=0.4.28"], + }, dependency_links=[ "https://storage.googleapis.com/jax-releases/jax_releases.html", ], @@ -46,7 +50,9 @@ "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git 
a/tests/baselines_test/cmame_test.py b/tests/baselines_test/cmame_test.py index c86bd622..2dc6fa10 100644 --- a/tests/baselines_test/cmame_test.py +++ b/tests/baselines_test/cmame_test.py @@ -16,7 +16,7 @@ from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey @pytest.mark.parametrize( @@ -25,7 +25,7 @@ ) def test_cma_me(emitter_type: Type[CMAEmitter]) -> None: - num_iterations = 1000 + num_iterations = 2000 num_dimensions = 20 grid_shape = (50, 50) batch_size = 36 @@ -43,7 +43,7 @@ def sphere_scoring(x: jnp.ndarray) -> jnp.ndarray: def clip(x: jnp.ndarray) -> jnp.ndarray: in_bound = (x <= maxval) * (x >= minval) - return jnp.where(condition=in_bound, x=x, y=(maxval / x)) + return jnp.where(in_bound, x, (maxval / x)) def _behavior_descriptor_1(x: jnp.ndarray) -> jnp.ndarray: return jnp.sum(clip(x[: x.shape[-1] // 2])) @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/cmamega_test.py b/tests/baselines_test/cmamega_test.py index fdd9330b..5bfdfd58 100644 --- a/tests/baselines_test/cmamega_test.py +++ b/tests/baselines_test/cmamega_test.py @@ -12,7 +12,7 @@ ) from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey def test_cma_mega() -> None: @@ -125,7 +125,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/dads_smerl_test.py b/tests/baselines_test/dads_smerl_test.py index 1e782f2a..2a8d3d1f 100644 --- a/tests/baselines_test/dads_smerl_test.py +++ b/tests/baselines_test/dads_smerl_test.py @@ -1,4 +1,5 @@ """Testing script for the algorithm DADS""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/dads_test.py b/tests/baselines_test/dads_test.py index 0b9af46e..77094ffd 100644 --- a/tests/baselines_test/dads_test.py +++ b/tests/baselines_test/dads_test.py @@ -1,5 +1,6 @@ """Training script for the algorithm DADS, should be launched with hydra. e.g. 
python train_dads.py config=dads_ant""" + from functools import partial from typing import Any, Tuple diff --git a/tests/baselines_test/dcrlme_test.py b/tests/baselines_test/dcrlme_test.py new file mode 100644 index 00000000..05304944 --- /dev/null +++ b/tests/baselines_test/dcrlme_test.py @@ -0,0 +1,221 @@ +import functools +from typing import Any, Tuple + +import jax +import jax.numpy as jnp +import pytest + +from qdax import environments +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids +from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.map_elites import MAPElites +from qdax.core.neuroevolution.buffers.buffer import DCRLTransition +from qdax.core.neuroevolution.networks.networks import MLP, MLPDC +from qdax.custom_types import EnvState, Params, RNGKey +from qdax.environments import behavior_descriptor_extractor +from qdax.environments.wrappers import ClipRewardWrapper, OffsetRewardWrapper +from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs +from qdax.utils.metrics import default_qd_metrics + + +def test_dcrlme() -> None: + seed = 42 + + env_name = "ant_omni" + episode_length = 100 + min_bd = -30.0 + max_bd = 30.0 + + num_iterations = 5 + batch_size = 256 + + # Archive + num_init_cvt_samples = 50000 + num_centroids = 1024 + policy_hidden_layer_sizes = (128, 128) + + # DCRL-ME + ga_batch_size = 128 + dcrl_batch_size = 64 + ai_batch_size = 64 + lengthscale = 0.1 + + # GA emitter + iso_sigma = 0.005 + line_sigma = 0.05 + + # DCRL emitter + critic_hidden_layer_size = (256, 256) + num_critic_training_steps = 3000 + num_pg_training_steps = 150 + replay_buffer_size = 1_000_000 + discount = 0.99 + reward_scaling = 1.0 + critic_learning_rate = 3e-4 + actor_learning_rate = 3e-4 + policy_learning_rate = 5e-3 + noise_clip = 0.5 + policy_noise = 0.2 + soft_tau_update = 0.005 + policy_delay = 2 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init environment + env = environments.create(env_name, episode_length=episode_length) + env = OffsetRewardWrapper( + env, offset=environments.reward_offset[env_name] + ) # apply reward offset as DCRL needs positive rewards + env = ClipRewardWrapper( + env, + clip_min=0.0, + ) # apply reward clip as DCRL needs positive rewards + + reset_fn = jax.jit(env.reset) + + # Compute the centroids + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + + # Init policy network + policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + actor_dc_network = MLPDC( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, num=batch_size) + fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size)) + init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs) + + # Define the function to play a step with the policy in the environment + def play_step_fn( + env_state: EnvState, policy_params: Params, random_key: RNGKey + ) -> Tuple[EnvState, Params, RNGKey, DCRLTransition]: 
actions = policy_network.apply(policy_params, env_state.obs) + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = DCRLTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + truncations=next_state.info["truncation"], + actions=actions, + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + desc=jnp.zeros( + env.behavior_descriptor_length, + ) + * jnp.nan, + desc_prime=jnp.zeros( + env.behavior_descriptor_length, + ) + * jnp.nan, + ) + + return next_state, policy_params, random_key, transition + + # Prepare the scoring function + bd_extraction_fn = behavior_descriptor_extractor[env_name] + scoring_fn = functools.partial( + reset_based_scoring_function_brax_envs, + episode_length=episode_length, + play_reset_fn=reset_fn, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + # Get minimum reward value to make sure qd_score are positive + reward_offset = environments.reward_offset[env_name] + + # Define a metrics function + metrics_function = functools.partial( + default_qd_metrics, + qd_offset=reward_offset * episode_length, + ) + + # Define the DCRL-emitter config + dcrl_emitter_config = DCRLMEConfig( + ga_batch_size=ga_batch_size, + dcrl_batch_size=dcrl_batch_size, + ai_batch_size=ai_batch_size, + lengthscale=lengthscale, + critic_hidden_layer_size=critic_hidden_layer_size, + num_critic_training_steps=num_critic_training_steps, + num_pg_training_steps=num_pg_training_steps, + batch_size=batch_size, + replay_buffer_size=replay_buffer_size, + discount=discount, + reward_scaling=reward_scaling, + critic_learning_rate=critic_learning_rate, + actor_learning_rate=actor_learning_rate, + policy_learning_rate=policy_learning_rate, + noise_clip=noise_clip, + policy_noise=policy_noise, + soft_tau_update=soft_tau_update, + policy_delay=policy_delay, + ) + + # Get the emitter + variation_fn = functools.partial( + isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma + ) + + dcrl_emitter = DCRLMEEmitter( + config=dcrl_emitter_config, + policy_network=policy_network, + actor_network=actor_dc_network, + env=env, + variation_fn=variation_fn, + ) + + # Instantiate MAP Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=dcrl_emitter, + metrics_function=metrics_function, + ) + + # compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_params, centroids, random_key + ) + + @jax.jit + def update_scan_fn(carry: Any, unused: Any) -> Any: + # iterate over grid + repertoire, emitter_state, metrics, random_key = map_elites.update(*carry) + + return (repertoire, emitter_state, random_key), metrics + + # Run the algorithm + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( + update_scan_fn, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) diff --git a/tests/baselines_test/ga_test.py b/tests/baselines_test/ga_test.py index 4e11370b..a1eb1b51 100644 --- a/tests/baselines_test/ga_test.py +++ b/tests/baselines_test/ga_test.py @@ -15,7 +15,7 @@ polynomial_mutation, ) from qdax.core.emitters.standard_emitters import MixingEmitter -from qdax.types import ExtraScores, Fitness, RNGKey +from qdax.custom_types import ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_ga_metrics @@ -32,11 +32,11 @@ def test_ga(algorithm_class: Type[GeneticAlgorithm]) -> None: batch_size = 100 
genotype_dim = 6 lag = 2.2 - base_lag = 0 + base_lag = 0.0 num_neighbours = 1 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -73,7 +73,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, genotype_dim), minval=minval, @@ -111,15 +111,19 @@ def scoring_fn( if isinstance(algo_instance, SPEA2): repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, num_neighbours, random_key + genotypes, population_size, num_neighbours, random_key ) else: repertoire, emitter_state, random_key = algo_instance.init( - init_genotypes, population_size, random_key + genotypes, population_size, random_key ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( algo_instance.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/me_pbt_sac_test.py b/tests/baselines_test/me_pbt_sac_test.py index 079fde45..5058bad6 100644 --- a/tests/baselines_test/me_pbt_sac_test.py +++ b/tests/baselines_test/me_pbt_sac_test.py @@ -119,14 +119,14 @@ def test_me_pbt_sac() -> None: def scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -178,7 +178,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) @@ -186,7 +186,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -196,7 +196,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/me_pbt_td3_test.py b/tests/baselines_test/me_pbt_td3_test.py index 5c6fbb0a..39c3e942 100644 --- a/tests/baselines_test/me_pbt_td3_test.py +++ b/tests/baselines_test/me_pbt_td3_test.py @@ -117,14 +117,14 @@ def test_me_pbt_td3() -> None: def 
scoring_function(genotypes, random_key): # type: ignore population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0] - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states ) - first_states = jax.tree_map( + first_states = jax.tree_util.tree_map( lambda x: jnp.repeat(x, population_size, axis=0), first_states ) population_returns, population_bds, _, _ = eval_policy(genotypes, first_states) - return population_returns, population_bds, None, random_key + return population_returns, population_bds, {}, random_key # Get minimum reward value to make sure qd_score are positive reward_offset = environments.reward_offset[env_name] @@ -176,7 +176,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys = map_elites.get_distributed_init_fn( devices=devices, centroids=centroids )( - init_genotypes=training_states, random_key=keys + genotypes=training_states, random_key=keys ) # type: ignore update_fn = map_elites.get_distributed_update_fn(num_iterations=1, devices=devices) @@ -184,7 +184,7 @@ def scoring_function(genotypes, random_key): # type: ignore initial_metrics = jax.pmap(metrics_function, axis_name="p", devices=devices)( repertoire ) - initial_metrics_cpu = jax.tree_map( + initial_metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], initial_metrics ) initial_qd_score = initial_metrics_cpu["qd_score"] @@ -194,7 +194,7 @@ def scoring_function(genotypes, random_key): # type: ignore repertoire, emitter_state, keys, metrics = update_fn( repertoire, emitter_state, keys ) - metrics_cpu = jax.tree_map( + metrics_cpu = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics ) diff --git a/tests/baselines_test/mees_test.py b/tests/baselines_test/mees_test.py index 3f3314fd..d1913b02 100644 --- a/tests/baselines_test/mees_test.py +++ b/tests/baselines_test/mees_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_mees() -> None: @@ -185,7 +185,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/omgmega_test.py b/tests/baselines_test/omgmega_test.py index 7b0f0639..ad51c7ae 100644 --- a/tests/baselines_test/omgmega_test.py +++ b/tests/baselines_test/omgmega_test.py @@ -11,7 +11,7 @@ ) from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter from qdax.core.map_elites import MAPElites -from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey def test_omg_mega() -> None: @@ -113,7 +113,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]: initial_population, centroids, random_key ) - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, 
emitter_state, random_key), (), diff --git a/tests/baselines_test/pbt_sac_test.py b/tests/baselines_test/pbt_sac_test.py index c83f277c..db7dc69e 100644 --- a/tests/baselines_test/pbt_sac_test.py +++ b/tests/baselines_test/pbt_sac_test.py @@ -59,7 +59,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pbt_td3_test.py b/tests/baselines_test/pbt_td3_test.py index 9e6134c9..0be68277 100644 --- a/tests/baselines_test/pbt_td3_test.py +++ b/tests/baselines_test/pbt_td3_test.py @@ -57,7 +57,7 @@ def init_environments(random_key): # type: ignore eval_env_first_states = jax.jit(eval_env.reset)(rng=random_key) reshape_fn = jax.jit( - lambda tree: jax.tree_map( + lambda tree: jax.tree_util.tree_map( lambda x: jnp.reshape( x, ( diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 9cb1b3fb..0490a481 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -15,8 +15,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_pgame() -> None: @@ -189,7 +189,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/qdpg_test.py b/tests/baselines_test/qdpg_test.py index 1889f197..704416a4 100644 --- a/tests/baselines_test/qdpg_test.py +++ b/tests/baselines_test/qdpg_test.py @@ -17,8 +17,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey def test_qdpg() -> None: @@ -239,7 +239,11 @@ def update_scan_fn(carry: Any, unused: Any) -> Any: return (repertoire, emitter_state, random_key), metrics # Run the algorithm - (repertoire, emitter_state, random_key,), _metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), _metrics = jax.lax.scan( update_scan_fn, (repertoire, emitter_state, random_key), (), diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index c667aa66..8c26b510 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -10,7 +10,7 @@ from qdax.baselines.sac import SAC, SacConfig, TrainingState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer -from qdax.types import EnvState +from qdax.custom_types import EnvState def test_sac() -> None: diff --git a/tests/core_test/aurora_test.py b/tests/core_test/aurora_test.py index 2b238237..4bbb9d82 100644 --- a/tests/core_test/aurora_test.py +++ b/tests/core_test/aurora_test.py @@ -11,6 +11,7 @@ from qdax import 
environments from qdax.core.aurora import AURORA from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.custom_types import Observation from qdax.environments.bd_extractors import ( AuroraExtraInfoNormalization, get_aurora_encoding, @@ -19,7 +20,6 @@ create_default_brax_task_components, get_aurora_scoring_fn, ) -from qdax.types import Observation from qdax.utils import train_seq2seq from qdax.utils.metrics import default_qd_metrics from tests.core_test.map_elites_test import get_mixing_emitter diff --git a/tests/core_test/containers_test/mapelites_repertoire_test.py b/tests/core_test/containers_test/mapelites_repertoire_test.py index 55e6ed11..5c0d9d75 100644 --- a/tests/core_test/containers_test/mapelites_repertoire_test.py +++ b/tests/core_test/containers_test/mapelites_repertoire_test.py @@ -5,7 +5,7 @@ MapElitesRepertoire, compute_euclidean_centroids, ) -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_mapelites_repertoire() -> None: diff --git a/tests/core_test/containers_test/mels_repertoire_test.py b/tests/core_test/containers_test/mels_repertoire_test.py index 2fb1bd76..0b854b32 100644 --- a/tests/core_test/containers_test/mels_repertoire_test.py +++ b/tests/core_test/containers_test/mels_repertoire_test.py @@ -2,7 +2,7 @@ import pytest from qdax.core.containers.mels_repertoire import MELSRepertoire -from qdax.types import ExtraScores +from qdax.custom_types import ExtraScores def test_add_to_mels_repertoire() -> None: diff --git a/tests/core_test/emitters_test/multi_emitter_test.py b/tests/core_test/emitters_test/multi_emitter_test.py index 93b3e081..ebf712d5 100644 --- a/tests/core_test/emitters_test/multi_emitter_test.py +++ b/tests/core_test/emitters_test/multi_emitter_test.py @@ -96,7 +96,11 @@ def test_multi_emitter() -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index b532aa65..c89ce04f 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -14,8 +14,8 @@ from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.metrics import default_qd_metrics @@ -143,7 +143,11 @@ def play_step_fn( ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mels_test.py b/tests/core_test/mels_test.py index 21f90517..66bcc05f 100644 --- a/tests/core_test/mels_test.py +++ b/tests/core_test/mels_test.py @@ -15,8 +15,8 @@ from qdax.core.mels import MELS from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey @pytest.mark.parametrize( @@ -142,7 +142,11 @@ def metrics_fn(repertoire: 
MELSRepertoire) -> Dict: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mels.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/mome_test.py b/tests/core_test/mome_test.py index c70683ef..746b94a0 100644 --- a/tests/core_test/mome_test.py +++ b/tests/core_test/mome_test.py @@ -14,7 +14,7 @@ ) from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.mome import MOME -from qdax.types import Descriptor, ExtraScores, Fitness, RNGKey +from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey from qdax.utils.metrics import default_moqd_metrics @@ -36,10 +36,10 @@ def test_mome(num_descriptors: int) -> None: crossover_percentage = 1.0 batch_size = 80 lag = 2.2 - base_lag = 0 + base_lag = 0.0 def rastrigin_scorer( - genotypes: jnp.ndarray, base_lag: int, lag: int + genotypes: jnp.ndarray, base_lag: float, lag: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Rastrigin Scorer with first two dimensions as descriptors @@ -81,7 +81,7 @@ def scoring_fn( # initial population random_key = jax.random.PRNGKey(42) random_key, subkey = jax.random.split(random_key) - init_genotypes = jax.random.uniform( + genotypes = jax.random.uniform( subkey, (batch_size, num_variables), minval=minval, @@ -127,11 +127,15 @@ def scoring_fn( ) repertoire, emitter_state, random_key = mome.init( - init_genotypes, centroids, pareto_front_max_length, random_key + genotypes, centroids, pareto_front_max_length, random_key ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( mome.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py index e0e298c1..06e25fcd 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/buffer_test.py @@ -42,7 +42,9 @@ def test_insert_batch() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) data = QDTransition.from_flatten(replay_buffer.data, dummy_transition) pytest.assume( @@ -83,7 +85,9 @@ def test_sample() -> None: buffer_size=buffer_size, transition=dummy_transition ) - simple_transition = jax.tree_map(lambda x: x.repeat(3, axis=0), dummy_transition) + simple_transition = jax.tree_util.tree_map( + lambda x: x.repeat(3, axis=0), dummy_transition + ) simple_transition = simple_transition.replace(rewards=jnp.arange(3)) replay_buffer = replay_buffer.insert(simple_transition) @@ -91,6 +95,6 @@ def test_sample() -> None: samples, random_key = replay_buffer.sample(random_key, 3) - samples_shapes = jax.tree_map(lambda x: x.shape, samples) - transition_shapes = jax.tree_map(lambda x: x.shape, simple_transition) + samples_shapes = jax.tree_util.tree_map(lambda x: x.shape, samples) + transition_shapes = jax.tree_util.tree_map(lambda x: x.shape, simple_transition) pytest.assume((samples_shapes == transition_shapes)) diff --git a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py 
b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py index 75f68b40..12ea0874 100644 --- a/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py +++ b/tests/core_test/neuroevolution_test/buffers_test/trajectory_buffer_test.py @@ -202,8 +202,8 @@ def test_trajectory_buffer_insert() -> None: multy_step_episodic_data, equal_nan=True, ), - "Episodic data when transitions are added sequentially is not consistent to when\ - theya are added as batch.", + "Episodic data when transitions are added sequentially is not consistent with \ + when they are added as a batch.", ) pytest.assume( diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py index e71e761c..98361b23 100644 --- a/tests/default_tasks_test/arm_test.py +++ b/tests/default_tasks_test/arm_test.py @@ -96,7 +96,11 @@ def test_arm(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py index f8c63259..c12518fb 100644 --- a/tests/default_tasks_test/brax_task_test.py +++ b/tests/default_tasks_test/brax_task_test.py @@ -84,7 +84,11 @@ def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) - ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py index a390f709..3d619353 100644 --- a/tests/default_tasks_test/hypervolume_functions_test.py +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -102,7 +102,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index eed90127..636a02cf 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -11,11 +11,11 @@ from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import Descriptor, Observation from qdax.tasks.jumanji_envs import ( jumanji_scoring_function, make_policy_network_play_step_fn_jumanji, ) -from qdax.types import Descriptor, Observation def test_jumanji_utils() -> None: @@ -53,7 +53,13 @@ def test_jumanji_utils() -> None: def observation_processing( observation: jumanji.environments.routing.snake.types.Observation, ) -> Observation: - network_input = jnp.ravel(observation.grid) + network_input = jnp.concatenate( + [ + jnp.ravel(observation.grid), + jnp.array([observation.step_count]), + observation.action_mask.ravel(), + ] + ) return network_input play_step_fn = make_policy_network_play_step_fn_jumanji( @@ -67,7 +73,12 @@ def observation_processing( keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec - 
observation_size = np.prod(np.array(env.observation_spec().grid.shape)) + obs_spec = env.observation_spec() + observation_size = int( + np.prod(obs_spec.grid.shape) + + np.prod(obs_spec.step_count.shape) + + np.prod(obs_spec.action_mask.shape) + ) fake_batch = jnp.zeros(shape=(batch_size, observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py index a0542e9b..46f6ce9b 100644 --- a/tests/default_tasks_test/qd_suite_test.py +++ b/tests/default_tasks_test/qd_suite_test.py @@ -117,7 +117,11 @@ def test_qd_suite(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py index 7b310389..87913364 100644 --- a/tests/default_tasks_test/standard_functions_test.py +++ b/tests/default_tasks_test/standard_functions_test.py @@ -92,7 +92,11 @@ def test_standard_functions(task_name: str, batch_size: int) -> None: ) # Run the algorithm - (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + ( + repertoire, + emitter_state, + random_key, + ), metrics = jax.lax.scan( map_elites.scan_update, (repertoire, emitter_state, random_key), (), diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index a13f41cc..ecc97864 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -6,8 +6,8 @@ from brax.v1.envs import Env import qdax +from qdax.custom_types import EnvState from qdax.environments.pointmaze import PointMaze -from qdax.types import EnvState def test_pointmaze() -> None: diff --git a/tests/utils_test/sampling_test.py b/tests/utils_test/sampling_test.py index 6ce6cbe9..8d19379e 100644 --- a/tests/utils_test/sampling_test.py +++ b/tests/utils_test/sampling_test.py @@ -8,8 +8,8 @@ from qdax import environments from qdax.core.neuroevolution.buffers.buffer import QDTransition from qdax.core.neuroevolution.networks.networks import MLP +from qdax.custom_types import EnvState, Params, RNGKey from qdax.tasks.brax_envs import scoring_function_brax_envs -from qdax.types import EnvState, Params, RNGKey from qdax.utils.sampling import ( average, closest, diff --git a/tool.Dockerfile b/tool.Dockerfile index 10b15b02..26a68236 100644 --- a/tool.Dockerfile +++ b/tool.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9.18-slim +FROM python:3.10.14-slim ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1 ENV PIPENV_VENV_IN_PROJECT=true PIP_NO_CACHE_DIR=false PIP_DISABLE_PIP_VERSION_CHECK=1