From 5cac247b7b8b8ce52e22cf06e585e9c3e86eb485 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Fri, 20 Sep 2024 14:19:16 +0000 Subject: [PATCH] fix key management in examples using pmap --- examples/me_sac_pbt.ipynb | 4 ++++ examples/me_td3_pbt.ipynb | 4 ++++ examples/sac_pbt.ipynb | 4 ++++ examples/td3_pbt.ipynb | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 068acab3..42c46188 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -311,6 +311,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index 71289a96..8caca62f 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -314,6 +314,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)" ] }, diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index b1ab220b..53b526db 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -269,6 +269,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)" diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index 5de9e24b..3bbf237e 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -232,6 +232,10 @@ " observation_size=env.observation_size,\n", " buffer_size=buffer_size,\n", ")\n", + "\n", + "# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n", + "keys = jax.random.key_data(keys)\n", + "\n", "keys, training_states, replay_buffers = jax.pmap(\n", " agent_init_fn, axis_name=\"p\", devices=devices\n", ")(keys)"