Skip to content

Commit

Permalink
Merge branch 'main' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 20, 2024
2 parents b2ef13e + 96163f2 commit 5ee7823
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions qdax/core/containers/mome_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def _update_masked_pareto_front(

pareto_front_len = pareto_front_fitnesses.shape[0] # type: ignore

first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0]
genotypes_dim = first_leaf.shape[1]

descriptors_dim = new_batch_of_descriptors.shape[1]

# gather all data
Expand Down Expand Up @@ -235,14 +232,11 @@ def _update_masked_pareto_front(
front_size = len(pareto_front_fitnesses) # type: ignore
new_front_fitness = new_front_fitness[:front_size, :]

genotypes_mask = jnp.repeat(
jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1
)
new_front_genotypes = jax.tree_util.tree_map(
lambda x: x * genotypes_mask, new_front_genotypes
lambda x: x * new_mask_indices[0], new_front_genotypes
)
new_front_genotypes = jax.tree_util.tree_map(
lambda x: x[:front_size, :], new_front_genotypes
lambda x: x[:front_size], new_front_genotypes
)

descriptors_mask = jnp.repeat(
Expand Down Expand Up @@ -297,25 +291,31 @@ def _add_one(

index = index.astype(jnp.int32)

# get cell data
cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
cell_fitness = carry.fitnesses[index]
cell_descriptor = carry.descriptors[index]
# get current repertoire cell data
cell_genotype = jax.tree_util.tree_map(
lambda x: x[index][0], carry.genotypes
)
cell_fitness = carry.fitnesses[index][0]
cell_descriptor = carry.descriptors[index][0]
cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

new_genotypes = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x, axis=0), genotype
)

# update pareto front
(
cell_fitness,
cell_genotype,
cell_genotype, # new pf for cell
cell_descriptor,
cell_mask,
) = self._update_masked_pareto_front(
pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
pareto_front_genotypes=cell_genotype.squeeze(axis=0),
pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
mask=cell_mask.squeeze(axis=0),
pareto_front_fitnesses=cell_fitness,
pareto_front_genotypes=cell_genotype,
pareto_front_descriptors=cell_descriptor,
mask=cell_mask,
new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0),
new_batch_of_genotypes=new_genotypes,
new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
new_mask=jnp.zeros(shape=(1,), dtype=bool),
)
Expand Down

0 comments on commit 5ee7823

Please sign in to comment.