Skip to content

Commit

Permalink
fix key management in examples using pmap
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Sep 20, 2024
1 parent 2040636 commit 5cac247
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)"
]
},
Expand Down
4 changes: 4 additions & 0 deletions examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)"
]
},
Expand Down
4 changes: 4 additions & 0 deletions examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, replay_buffers = jax.pmap(\n",
" agent_init_fn, axis_name=\"p\", devices=devices\n",
")(keys)"
Expand Down
4 changes: 4 additions & 0 deletions examples/td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, replay_buffers = jax.pmap(\n",
" agent_init_fn, axis_name=\"p\", devices=devices\n",
")(keys)"
Expand Down

0 comments on commit 5cac247

Please sign in to comment.