-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor structure variation operators and selectors
- Loading branch information
1 parent
67757df
commit a4e0f53
Showing
16 changed files
with
602 additions
and
569 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import abc | ||
from typing import Tuple | ||
|
||
from qdax.core.containers.repertoire import Repertoire | ||
from qdax.core.emitters.emitter import EmitterState | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class Selector(metaclass=abc.ABCMeta): | ||
@abc.abstractmethod | ||
def select( | ||
self, | ||
number_parents_to_select: int, | ||
repertoire: Repertoire, | ||
emitter_state: EmitterState, | ||
random_key: RNGKey, | ||
) -> Tuple[Genotype, EmitterState, RNGKey]: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Tuple | ||
|
||
from qdax.core.containers.repertoire import Repertoire | ||
from qdax.core.emitters.emitter import EmitterState | ||
from qdax.core.emitters.selectors.abstract_selector import Selector | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class UniformSelector(Selector): | ||
def select( | ||
self, | ||
number_parents_to_select: int, | ||
repertoire: Repertoire, | ||
emitter_state: EmitterState, | ||
random_key: RNGKey, | ||
) -> Tuple[Genotype, EmitterState, RNGKey]: | ||
""" | ||
Uniform selection of parents | ||
""" | ||
selected_parents, random_key = repertoire.sample( | ||
random_key, number_parents_to_select | ||
) | ||
return selected_parents, emitter_state, random_key |
Empty file.
153 changes: 153 additions & 0 deletions
153
qdax/core/emitters/variation_operators/abstract_variation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import abc | ||
from typing import Optional, Tuple | ||
|
||
import jax | ||
from chex import ArrayTree | ||
from jax import numpy as jnp | ||
|
||
from qdax.core.emitters.emitter import EmitterState | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class VariationOperator(metaclass=abc.ABCMeta): | ||
def __init__(self, minval: Optional[float] = None, maxval: Optional[float] = None): | ||
if minval is not None and maxval is not None: | ||
assert minval < maxval, "minval must be smaller than maxval" | ||
self._minval = minval | ||
self._maxval = maxval | ||
|
||
@property | ||
@abc.abstractmethod | ||
def number_parents_to_select(self) -> int: | ||
... | ||
|
||
@property | ||
@abc.abstractmethod | ||
def number_genotypes_returned(self) -> int: | ||
... | ||
|
||
def calculate_number_parents_to_select(self, batch_size: int) -> int: | ||
assert batch_size % self.number_genotypes_returned == 0, ( | ||
"The batch size should be a multiple of the " | ||
"number of genotypes returned after each variation" | ||
) | ||
return ( | ||
self.number_parents_to_select * batch_size // self.number_genotypes_returned | ||
) | ||
|
||
@abc.abstractmethod | ||
def apply_without_clip( | ||
self, genotypes: Genotype, emitter_state: EmitterState, random_key: RNGKey | ||
) -> Tuple[Genotype, RNGKey]: | ||
... | ||
|
||
def _clip(self, gen: Genotype) -> Genotype: | ||
if (self._minval is not None) or (self._maxval is not None): | ||
gen = jax.tree_map( | ||
lambda _gen: jnp.clip(_gen, self._minval, self._maxval), gen | ||
) | ||
return gen | ||
|
||
def apply_with_clip( | ||
self, | ||
genotypes: Genotype, | ||
emitter_state: EmitterState, | ||
random_key: RNGKey, | ||
) -> Tuple[Genotype, RNGKey]: | ||
new_genotypes, random_key = self.apply_without_clip( | ||
genotypes, emitter_state, random_key | ||
) | ||
new_genotypes = self._clip(new_genotypes) | ||
return new_genotypes, random_key | ||
|
||
def _divide_genotypes( | ||
self, | ||
genotypes: Genotype, | ||
) -> Tuple[Genotype, ...]: | ||
tuple_genotypes = tuple( | ||
jax.tree_map( | ||
lambda x: x[index_start :: self.number_parents_to_select], genotypes | ||
) | ||
for index_start in range(self.number_parents_to_select) | ||
) | ||
return tuple_genotypes | ||
|
||
@staticmethod | ||
def get_tree_keys( | ||
genotype: Genotype, random_key: RNGKey | ||
) -> Tuple[ArrayTree, RNGKey]: | ||
nb_leaves = len(jax.tree_leaves(genotype)) | ||
random_key, subkey = jax.random.split(random_key) | ||
subkeys = jax.random.split(subkey, num=nb_leaves) | ||
keys_tree = jax.tree_unflatten(jax.tree_structure(genotype), subkeys) | ||
return keys_tree, random_key | ||
|
||
@staticmethod | ||
def _get_array_keys_for_each_gen(key: RNGKey, gen_tree: Genotype) -> jnp.ndarray: | ||
subkeys = jax.random.split(key, num=gen_tree.shape[0]) | ||
return jnp.asarray(subkeys) | ||
|
||
@staticmethod | ||
def get_keys_arrays_tree( | ||
gen_tree: Genotype, random_key: RNGKey | ||
) -> Tuple[ArrayTree, RNGKey]: | ||
keys_tree, random_key = VariationOperator.get_tree_keys(gen_tree, random_key) | ||
keys_arrays_tree = jax.tree_map( | ||
VariationOperator._get_array_keys_for_each_gen, keys_tree, gen_tree | ||
) | ||
return keys_arrays_tree, random_key | ||
|
||
@staticmethod | ||
def _get_random_positions_to_change( | ||
genotypes_tree: Genotype, | ||
variation_rate: float, | ||
random_key: RNGKey, | ||
) -> Tuple[Genotype, RNGKey]: | ||
def _get_indexes_positions_cross_over( | ||
_gen: Genotype, _key: RNGKey | ||
) -> jnp.ndarray: | ||
num_positions = _gen.shape[0] | ||
positions = jnp.arange(start=0, stop=num_positions) | ||
num_positions_to_change = int(variation_rate * num_positions) | ||
_key, subkey = jax.random.split(_key) | ||
selected_positions = jax.random.choice( | ||
key=subkey, a=positions, shape=(num_positions_to_change,), replace=False | ||
) | ||
return selected_positions | ||
|
||
random_key, _subkey = jax.random.split(random_key) | ||
|
||
keys_arrays_tree, random_key = VariationOperator.get_keys_arrays_tree( | ||
genotypes_tree, random_key | ||
) | ||
|
||
return ( | ||
jax.tree_map( | ||
jax.vmap(_get_indexes_positions_cross_over), | ||
genotypes_tree, | ||
keys_arrays_tree, | ||
), | ||
random_key, | ||
) | ||
|
||
@staticmethod | ||
def _get_sub_genotypes( | ||
genotypes_tree: Genotype, | ||
selected_positions: jnp.ndarray, | ||
) -> Genotype: | ||
return jax.tree_map( | ||
jax.vmap(lambda _x, _i: _x[_i]), genotypes_tree, selected_positions | ||
) | ||
|
||
@staticmethod | ||
def _set_sub_genotypes( | ||
genotypes_tree: Genotype, | ||
selected_positions: jnp.ndarray, | ||
new_genotypes: Genotype, | ||
) -> Genotype: | ||
return jax.tree_map( | ||
jax.vmap(lambda _x, _i, _y: _x.at[_i].set(_y)), | ||
genotypes_tree, | ||
selected_positions, | ||
new_genotypes, | ||
) |
40 changes: 40 additions & 0 deletions
40
qdax/core/emitters/variation_operators/composer_variations.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import math | ||
from typing import List, Optional, Tuple | ||
|
||
from qdax.core.emitters.emitter import EmitterState | ||
from qdax.core.emitters.variation_operators.abstract_variation import VariationOperator | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class ComposerVariations(VariationOperator): | ||
def __init__( | ||
self, | ||
variations_operators_list: List[VariationOperator], | ||
minval: Optional[float] = None, | ||
maxval: Optional[float] = None, | ||
): | ||
super().__init__(minval, maxval) | ||
self.variations_list = variations_operators_list | ||
|
||
@property | ||
def number_parents_to_select(self) -> int: | ||
numbers_to_select = map( | ||
lambda x: x.number_parents_to_select, self.variations_list | ||
) | ||
return math.prod(numbers_to_select) | ||
|
||
@property | ||
def number_genotypes_returned(self) -> int: | ||
numbers_to_return = map( | ||
lambda x: x.number_genotypes_returned, self.variations_list | ||
) | ||
return math.prod(numbers_to_return) | ||
|
||
def apply_without_clip( | ||
self, genotypes: Genotype, emitter_state: EmitterState, random_key: RNGKey | ||
) -> Tuple[Genotype, RNGKey]: | ||
for variation in self.variations_list: | ||
genotypes, random_key = variation.apply_with_clip( | ||
genotypes, emitter_state, random_key | ||
) | ||
return genotypes, random_key |
Empty file.
73 changes: 73 additions & 0 deletions
73
qdax/core/emitters/variation_operators/cross_overs/abstract_cross_over.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import abc | ||
from typing import Optional, Tuple | ||
|
||
import jax | ||
from jax import numpy as jnp | ||
|
||
from qdax.core.emitters.emitter import EmitterState | ||
from qdax.core.emitters.variation_operators.abstract_variation import VariationOperator | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class CrossOver(VariationOperator, abc.ABC): | ||
def __init__( | ||
self, | ||
cross_over_rate: float = 1.0, | ||
returns_single_genotype: bool = True, | ||
minval: Optional[float] = None, | ||
maxval: Optional[float] = None, | ||
): | ||
super().__init__(minval, maxval) | ||
self.cross_over_rate = cross_over_rate | ||
self.returns_single_genotype = returns_single_genotype | ||
|
||
@property | ||
def number_parents_to_select(self) -> int: | ||
return 2 | ||
|
||
@property | ||
def number_genotypes_returned(self) -> int: | ||
if self.returns_single_genotype: | ||
return 1 | ||
else: | ||
return 2 | ||
|
||
def apply_without_clip( | ||
self, | ||
genotypes: Genotype, | ||
emitter_state: EmitterState, | ||
random_key: RNGKey, | ||
) -> Tuple[Genotype, RNGKey]: | ||
gen_1, gen_2 = self._divide_genotypes(genotypes) | ||
selected_indices, random_key = self._get_random_positions_to_change( | ||
gen_1, self.cross_over_rate, random_key | ||
) | ||
subgen_1 = self._get_sub_genotypes(gen_1, selected_positions=selected_indices) | ||
subgen_2 = self._get_sub_genotypes(gen_2, selected_positions=selected_indices) | ||
|
||
if self.returns_single_genotype: | ||
|
||
new_subgen, random_key = self._cross_over(subgen_1, subgen_2, random_key) | ||
new_gen = self._set_sub_genotypes(gen_1, selected_indices, new_subgen) | ||
return new_gen, random_key | ||
else: | ||
# Not changing random key here to keep same noise for gen_tilde_1 and | ||
# gen_tilde_2 (as done in the literature) | ||
new_subgen_1, _ = self._cross_over(subgen_1, subgen_2, random_key) | ||
new_subgen_2, random_key = self._cross_over(subgen_2, subgen_1, random_key) | ||
|
||
new_gen_1 = self._set_sub_genotypes(gen_1, selected_indices, new_subgen_1) | ||
new_gen_2 = self._set_sub_genotypes(gen_2, selected_indices, new_subgen_2) | ||
|
||
new_gen = jax.tree_util.tree_map( | ||
lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0), | ||
new_gen_1, | ||
new_gen_2, | ||
) | ||
return new_gen, random_key | ||
|
||
@abc.abstractmethod | ||
def _cross_over( | ||
self, gen_1: Genotype, gen_2: Genotype, random_key: RNGKey | ||
) -> Tuple[Genotype, RNGKey]: | ||
... |
60 changes: 60 additions & 0 deletions
60
qdax/core/emitters/variation_operators/cross_overs/iso_line.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Optional, Tuple | ||
|
||
import jax | ||
from jax import numpy as jnp | ||
|
||
from qdax.core.emitters.variation_operators.cross_overs.abstract_cross_over import ( | ||
CrossOver, | ||
) | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class IsolineVariationOperator(CrossOver): | ||
def __init__( | ||
self, | ||
iso_sigma: float, | ||
line_sigma: float, | ||
cross_over_rate: float = 1.0, | ||
returns_single_genotype: bool = True, | ||
minval: Optional[float] = None, | ||
maxval: Optional[float] = None, | ||
): | ||
super().__init__( | ||
cross_over_rate=cross_over_rate, | ||
returns_single_genotype=returns_single_genotype, | ||
minval=minval, | ||
maxval=maxval, | ||
) | ||
self._iso_sigma = iso_sigma | ||
self._line_sigma = line_sigma | ||
|
||
def _cross_over( | ||
self, gen_1: Genotype, gen_2: Genotype, random_key: RNGKey | ||
) -> Tuple[Genotype, RNGKey]: | ||
# Computing line_noise | ||
random_key, key_line_noise = jax.random.split(random_key) | ||
batch_size = jax.tree_leaves(gen_1)[0].shape[0] | ||
line_noise = ( | ||
jax.random.normal(key_line_noise, shape=(batch_size,)) * self._line_sigma | ||
) | ||
|
||
def _variation_fn( | ||
_x1: jnp.ndarray, _x2: jnp.ndarray, _random_key: RNGKey | ||
) -> jnp.ndarray: | ||
iso_noise = ( | ||
jax.random.normal(_random_key, shape=_x1.shape) * self._iso_sigma | ||
) | ||
x = (_x1 + iso_noise) + jax.vmap(jnp.multiply)((_x2 - _x1), line_noise) | ||
|
||
# Back in bounds if necessary (floating point issues) | ||
if (self._minval is not None) or (self._maxval is not None): | ||
x = jnp.clip(x, self._minval, self._maxval) | ||
return x | ||
|
||
# create a tree with random keys | ||
keys_tree, random_key = self.get_tree_keys(gen_1, random_key) | ||
|
||
# apply isolinedd to each branch of the tree | ||
gen_new = jax.tree_map(_variation_fn, gen_1, gen_2, keys_tree) | ||
|
||
return gen_new, random_key |
17 changes: 17 additions & 0 deletions
17
qdax/core/emitters/variation_operators/cross_overs/recombination.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Tuple | ||
|
||
from qdax.core.emitters.variation_operators.cross_overs.abstract_cross_over import ( | ||
CrossOver, | ||
) | ||
from qdax.types import Genotype, RNGKey | ||
|
||
|
||
class RecombinationCrossOver(CrossOver): | ||
def _cross_over( | ||
self, gen_original: Genotype, gen_exchange: Genotype, random_key: RNGKey | ||
) -> Tuple[Genotype, RNGKey]: | ||
# The exchange cross over is a simple exchange of the two genotypes | ||
# the proportion of the two genotypes that are changed is the | ||
# same as the cross-over rate the parts which are exchanged are | ||
# randomly selected in CrossOver | ||
return gen_exchange, random_key |
Oops, something went wrong.