Skip to content

Commit

Permalink
tests: include scan_reeval and remove brax environment
Browse files Browse the repository at this point in the history
  • Loading branch information
manon-but-yes committed Sep 2, 2024
1 parent d8ee819 commit 2e14e2e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 88 deletions.
5 changes: 5 additions & 0 deletions qdax/utils/uncertainty_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ def _perform_reevaluation(

# If a scan is needed, call the sampling function multiple times
else:

# Ensure that num_reevals is a multiple of scan_size
assert (
num_reevals % scan_size == 0
), "num_reevals should be a multiple of scan_size to be able to scan."
num_loops = num_reevals // scan_size

def _sampling_scan(
Expand Down
127 changes: 39 additions & 88 deletions tests/utils_test/uncertainty_metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import functools
from typing import Callable, Tuple

import jax
import jax.numpy as jnp
import pytest

from qdax import environments
from qdax.core.containers.mapelites_repertoire import (
MapElitesRepertoire,
compute_cvt_centroids,
)
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.arm import arm_scoring_function
from qdax.tasks.brax_envs import scoring_function_brax_envs
from qdax.types import EnvState, Params, RNGKey
from qdax.tasks.arm import arm_scoring_function, noisy_arm_scoring_function
from qdax.utils.uncertainty_metrics import (
reevaluation_function,
reevaluation_reproducibility_function,
Expand All @@ -24,22 +18,19 @@
def test_uncertainty_metrics() -> None:
seed = 42
num_reevals = 512
scan_size = 128
batch_size = 512
num_init_cvt_samples = 50000
num_centroids = 1024
genotype_dim = 8

# Init a random key
random_key = jax.random.PRNGKey(seed)

# First, init a deterministic environment
genotype_dim = 8

# Init policies
init_policies = jax.random.uniform(
random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1
)

# Evaluate in the deterministic environment
fitnesses, descriptors, extra_scores, random_key = arm_scoring_function(
init_policies, random_key
)
Expand Down Expand Up @@ -85,6 +76,21 @@ def test_uncertainty_metrics() -> None:
)
)

# Test that scanned reevaluation_function accurately predicts no change
corrected_repertoire, random_key = reevaluation_function(
repertoire=repertoire,
empty_corrected_repertoire=empty_corrected_repertoire,
scoring_fn=arm_scoring_function,
num_reevals=num_reevals,
random_key=random_key,
scan_size=scan_size,
)
pytest.assume(
jnp.allclose(
corrected_repertoire.fitnesses, repertoire.fitnesses, rtol=1e-05, atol=1e-05
)
)

# Test that reevaluation_reproducibility_function accurately predicts no change
(
corrected_repertoire,
Expand Down Expand Up @@ -125,80 +131,27 @@ def test_uncertainty_metrics() -> None:
)
)

# Second, init a Brax environment
env_name = "walker2d_uni"
episode_length = 100
policy_hidden_layer_sizes = (64, 64)
env = environments.create(env_name, episode_length=episode_length)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
layer_sizes=policy_layer_sizes,
kernel_init=jax.nn.initializers.lecun_uniform(),
final_activation=jnp.tanh,
# Second, init a stochastic environment
init_policies = jax.random.uniform(
random_key, shape=(batch_size, genotype_dim), minval=0, maxval=1
)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
init_policies = jax.vmap(policy_network.init)(keys, fake_batch)

# Define the function to play a step with the policy in the environment
def play_step_fn(
env_state: EnvState,
policy_params: Params,
random_key: RNGKey,
) -> Tuple[EnvState, Params, RNGKey, QDTransition]:

actions = policy_network.apply(policy_params, env_state.obs)

state_desc = env_state.info["state_descriptor"]
next_state = env.step(env_state, actions)

transition = QDTransition(
obs=env_state.obs,
next_obs=next_state.obs,
rewards=next_state.reward,
dones=next_state.done,
actions=actions,
truncations=next_state.info["truncation"],
state_desc=state_desc,
next_state_desc=next_state.info["state_descriptor"],
)

return next_state, policy_params, random_key, transition

# Create the initial environment states for samples and final indivs
reset_fn = jax.jit(jax.vmap(env.reset))
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
init_states = reset_fn(keys)

# Create the scoring function
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
brax_scoring_fn: Callable = functools.partial(
scoring_function_brax_envs,
init_states=init_states,
episode_length=episode_length,
play_step_fn=play_step_fn,
behavior_descriptor_extractor=bd_extraction_fn,
noisy_scoring_function = functools.partial(
noisy_arm_scoring_function,
fit_variance=0.01,
desc_variance=0.01,
params_variance=0.0,
)

# Evaluate in the Brax environment
fitnesses, descriptors, extra_scores, random_key = brax_scoring_fn(
fitnesses, descriptors, extra_scores, random_key = noisy_scoring_function(
init_policies, random_key
)

# Initialise a container
min_bd, max_bd = env.behavior_descriptor_limits
centroids, random_key = compute_cvt_centroids(
num_descriptors=env.behavior_descriptor_length,
num_descriptors=2,
num_init_cvt_samples=num_init_cvt_samples,
num_centroids=num_centroids,
minval=min_bd,
maxval=max_bd,
minval=jnp.array([0.0, 0.0]),
maxval=jnp.array([1.0, 1.0]),
random_key=random_key,
)
repertoire = MapElitesRepertoire.init(
Expand All @@ -220,20 +173,18 @@ def play_step_fn(
)

# Test that the reevaluation runs and keeps at least one solution
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=num_centroids, axis=0)
init_states = reset_fn(keys)
reeval_brax_scoring_fn: Callable = functools.partial(
scoring_function_brax_envs,
init_states=init_states,
episode_length=episode_length,
play_step_fn=play_step_fn,
behavior_descriptor_extractor=bd_extraction_fn,
)
corrected_repertoire, random_key = reevaluation_function(
(
corrected_repertoire,
fit_reproducibility_repertoire,
desc_reproducibility_repertoire,
random_key,
) = reevaluation_reproducibility_function(
repertoire=repertoire,
empty_corrected_repertoire=empty_corrected_repertoire,
scoring_fn=reeval_brax_scoring_fn,
scoring_fn=noisy_scoring_function,
num_reevals=num_reevals,
random_key=random_key,
)
pytest.assume(jnp.any(corrected_repertoire.fitnesses > -jnp.inf))
pytest.assume(jnp.any(fit_reproducibility_repertoire.fitnesses > -jnp.inf))
pytest.assume(jnp.any(desc_reproducibility_repertoire.fitnesses > -jnp.inf))

0 comments on commit 2e14e2e

Please sign in to comment.