Skip to content

Commit

Permalink
refactor structure variation operators and selectors
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 8, 2022
1 parent 67757df commit a4e0f53
Show file tree
Hide file tree
Showing 16 changed files with 602 additions and 569 deletions.
573 changes: 4 additions & 569 deletions qdax/core/emitters/mutation_operators.py

Large diffs are not rendered by default.

Empty file.
18 changes: 18 additions & 0 deletions qdax/core/emitters/selectors/abstract_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import abc
from typing import Tuple

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import EmitterState
from qdax.types import Genotype, RNGKey


class Selector(metaclass=abc.ABCMeta):
@abc.abstractmethod
def select(
self,
number_parents_to_select: int,
repertoire: Repertoire,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, EmitterState, RNGKey]:
...
23 changes: 23 additions & 0 deletions qdax/core/emitters/selectors/uniform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Tuple

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.selectors.abstract_selector import Selector
from qdax.types import Genotype, RNGKey


class UniformSelector(Selector):
def select(
self,
number_parents_to_select: int,
repertoire: Repertoire,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, EmitterState, RNGKey]:
"""
Uniform selection of parents
"""
selected_parents, random_key = repertoire.sample(
random_key, number_parents_to_select
)
return selected_parents, emitter_state, random_key
Empty file.
153 changes: 153 additions & 0 deletions qdax/core/emitters/variation_operators/abstract_variation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import abc
from typing import Optional, Tuple

import jax
from chex import ArrayTree
from jax import numpy as jnp

from qdax.core.emitters.emitter import EmitterState
from qdax.types import Genotype, RNGKey


class VariationOperator(metaclass=abc.ABCMeta):
def __init__(self, minval: Optional[float] = None, maxval: Optional[float] = None):
if minval is not None and maxval is not None:
assert minval < maxval, "minval must be smaller than maxval"
self._minval = minval
self._maxval = maxval

@property
@abc.abstractmethod
def number_parents_to_select(self) -> int:
...

@property
@abc.abstractmethod
def number_genotypes_returned(self) -> int:
...

def calculate_number_parents_to_select(self, batch_size: int) -> int:
assert batch_size % self.number_genotypes_returned == 0, (
"The batch size should be a multiple of the "
"number of genotypes returned after each variation"
)
return (
self.number_parents_to_select * batch_size // self.number_genotypes_returned
)

@abc.abstractmethod
def apply_without_clip(
self, genotypes: Genotype, emitter_state: EmitterState, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
...

def _clip(self, gen: Genotype) -> Genotype:
if (self._minval is not None) or (self._maxval is not None):
gen = jax.tree_map(
lambda _gen: jnp.clip(_gen, self._minval, self._maxval), gen
)
return gen

def apply_with_clip(
self,
genotypes: Genotype,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
new_genotypes, random_key = self.apply_without_clip(
genotypes, emitter_state, random_key
)
new_genotypes = self._clip(new_genotypes)
return new_genotypes, random_key

def _divide_genotypes(
self,
genotypes: Genotype,
) -> Tuple[Genotype, ...]:
tuple_genotypes = tuple(
jax.tree_map(
lambda x: x[index_start :: self.number_parents_to_select], genotypes
)
for index_start in range(self.number_parents_to_select)
)
return tuple_genotypes

@staticmethod
def get_tree_keys(
genotype: Genotype, random_key: RNGKey
) -> Tuple[ArrayTree, RNGKey]:
nb_leaves = len(jax.tree_leaves(genotype))
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, num=nb_leaves)
keys_tree = jax.tree_unflatten(jax.tree_structure(genotype), subkeys)
return keys_tree, random_key

@staticmethod
def _get_array_keys_for_each_gen(key: RNGKey, gen_tree: Genotype) -> jnp.ndarray:
subkeys = jax.random.split(key, num=gen_tree.shape[0])
return jnp.asarray(subkeys)

@staticmethod
def get_keys_arrays_tree(
gen_tree: Genotype, random_key: RNGKey
) -> Tuple[ArrayTree, RNGKey]:
keys_tree, random_key = VariationOperator.get_tree_keys(gen_tree, random_key)
keys_arrays_tree = jax.tree_map(
VariationOperator._get_array_keys_for_each_gen, keys_tree, gen_tree
)
return keys_arrays_tree, random_key

@staticmethod
def _get_random_positions_to_change(
genotypes_tree: Genotype,
variation_rate: float,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
def _get_indexes_positions_cross_over(
_gen: Genotype, _key: RNGKey
) -> jnp.ndarray:
num_positions = _gen.shape[0]
positions = jnp.arange(start=0, stop=num_positions)
num_positions_to_change = int(variation_rate * num_positions)
_key, subkey = jax.random.split(_key)
selected_positions = jax.random.choice(
key=subkey, a=positions, shape=(num_positions_to_change,), replace=False
)
return selected_positions

random_key, _subkey = jax.random.split(random_key)

keys_arrays_tree, random_key = VariationOperator.get_keys_arrays_tree(
genotypes_tree, random_key
)

return (
jax.tree_map(
jax.vmap(_get_indexes_positions_cross_over),
genotypes_tree,
keys_arrays_tree,
),
random_key,
)

@staticmethod
def _get_sub_genotypes(
genotypes_tree: Genotype,
selected_positions: jnp.ndarray,
) -> Genotype:
return jax.tree_map(
jax.vmap(lambda _x, _i: _x[_i]), genotypes_tree, selected_positions
)

@staticmethod
def _set_sub_genotypes(
genotypes_tree: Genotype,
selected_positions: jnp.ndarray,
new_genotypes: Genotype,
) -> Genotype:
return jax.tree_map(
jax.vmap(lambda _x, _i, _y: _x.at[_i].set(_y)),
genotypes_tree,
selected_positions,
new_genotypes,
)
40 changes: 40 additions & 0 deletions qdax/core/emitters/variation_operators/composer_variations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import math
from typing import List, Optional, Tuple

from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.variation_operators.abstract_variation import VariationOperator
from qdax.types import Genotype, RNGKey


class ComposerVariations(VariationOperator):
def __init__(
self,
variations_operators_list: List[VariationOperator],
minval: Optional[float] = None,
maxval: Optional[float] = None,
):
super().__init__(minval, maxval)
self.variations_list = variations_operators_list

@property
def number_parents_to_select(self) -> int:
numbers_to_select = map(
lambda x: x.number_parents_to_select, self.variations_list
)
return math.prod(numbers_to_select)

@property
def number_genotypes_returned(self) -> int:
numbers_to_return = map(
lambda x: x.number_genotypes_returned, self.variations_list
)
return math.prod(numbers_to_return)

def apply_without_clip(
self, genotypes: Genotype, emitter_state: EmitterState, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
for variation in self.variations_list:
genotypes, random_key = variation.apply_with_clip(
genotypes, emitter_state, random_key
)
return genotypes, random_key
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import abc
from typing import Optional, Tuple

import jax
from jax import numpy as jnp

from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.variation_operators.abstract_variation import VariationOperator
from qdax.types import Genotype, RNGKey


class CrossOver(VariationOperator, abc.ABC):
def __init__(
self,
cross_over_rate: float = 1.0,
returns_single_genotype: bool = True,
minval: Optional[float] = None,
maxval: Optional[float] = None,
):
super().__init__(minval, maxval)
self.cross_over_rate = cross_over_rate
self.returns_single_genotype = returns_single_genotype

@property
def number_parents_to_select(self) -> int:
return 2

@property
def number_genotypes_returned(self) -> int:
if self.returns_single_genotype:
return 1
else:
return 2

def apply_without_clip(
self,
genotypes: Genotype,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
gen_1, gen_2 = self._divide_genotypes(genotypes)
selected_indices, random_key = self._get_random_positions_to_change(
gen_1, self.cross_over_rate, random_key
)
subgen_1 = self._get_sub_genotypes(gen_1, selected_positions=selected_indices)
subgen_2 = self._get_sub_genotypes(gen_2, selected_positions=selected_indices)

if self.returns_single_genotype:

new_subgen, random_key = self._cross_over(subgen_1, subgen_2, random_key)
new_gen = self._set_sub_genotypes(gen_1, selected_indices, new_subgen)
return new_gen, random_key
else:
# Not changing random key here to keep same noise for gen_tilde_1 and
# gen_tilde_2 (as done in the literature)
new_subgen_1, _ = self._cross_over(subgen_1, subgen_2, random_key)
new_subgen_2, random_key = self._cross_over(subgen_2, subgen_1, random_key)

new_gen_1 = self._set_sub_genotypes(gen_1, selected_indices, new_subgen_1)
new_gen_2 = self._set_sub_genotypes(gen_2, selected_indices, new_subgen_2)

new_gen = jax.tree_util.tree_map(
lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
new_gen_1,
new_gen_2,
)
return new_gen, random_key

@abc.abstractmethod
def _cross_over(
self, gen_1: Genotype, gen_2: Genotype, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
...
60 changes: 60 additions & 0 deletions qdax/core/emitters/variation_operators/cross_overs/iso_line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Optional, Tuple

import jax
from jax import numpy as jnp

from qdax.core.emitters.variation_operators.cross_overs.abstract_cross_over import (
CrossOver,
)
from qdax.types import Genotype, RNGKey


class IsolineVariationOperator(CrossOver):
def __init__(
self,
iso_sigma: float,
line_sigma: float,
cross_over_rate: float = 1.0,
returns_single_genotype: bool = True,
minval: Optional[float] = None,
maxval: Optional[float] = None,
):
super().__init__(
cross_over_rate=cross_over_rate,
returns_single_genotype=returns_single_genotype,
minval=minval,
maxval=maxval,
)
self._iso_sigma = iso_sigma
self._line_sigma = line_sigma

def _cross_over(
self, gen_1: Genotype, gen_2: Genotype, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
# Computing line_noise
random_key, key_line_noise = jax.random.split(random_key)
batch_size = jax.tree_leaves(gen_1)[0].shape[0]
line_noise = (
jax.random.normal(key_line_noise, shape=(batch_size,)) * self._line_sigma
)

def _variation_fn(
_x1: jnp.ndarray, _x2: jnp.ndarray, _random_key: RNGKey
) -> jnp.ndarray:
iso_noise = (
jax.random.normal(_random_key, shape=_x1.shape) * self._iso_sigma
)
x = (_x1 + iso_noise) + jax.vmap(jnp.multiply)((_x2 - _x1), line_noise)

# Back in bounds if necessary (floating point issues)
if (self._minval is not None) or (self._maxval is not None):
x = jnp.clip(x, self._minval, self._maxval)
return x

# create a tree with random keys
keys_tree, random_key = self.get_tree_keys(gen_1, random_key)

# apply isolinedd to each branch of the tree
gen_new = jax.tree_map(_variation_fn, gen_1, gen_2, keys_tree)

return gen_new, random_key
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Tuple

from qdax.core.emitters.variation_operators.cross_overs.abstract_cross_over import (
CrossOver,
)
from qdax.types import Genotype, RNGKey


class RecombinationCrossOver(CrossOver):
def _cross_over(
self, gen_original: Genotype, gen_exchange: Genotype, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
# The exchange cross over is a simple exchange of the two genotypes
# the proportion of the two genotypes that are changed is the
# same as the cross-over rate the parts which are exchanged are
# randomly selected in CrossOver
return gen_exchange, random_key
Loading

0 comments on commit a4e0f53

Please sign in to comment.