Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-enable environment variables to guard against hangs with Triton gemms #624

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_async_all_gather=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_async_reduce_scatter=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0
ENV CUDA_MODULE_LOADING=EAGER
ENV JAX_SHARE_BINARY_BETWEEN_HOSTS=True
ENV JAX_SHARE_AUTOTUNE_CONFIG_BETWEEN_HOSTS=True

ADD --chmod=777 create-distribution.sh ${DEST_MANIFEST_DIR}/

Expand Down
36 changes: 18 additions & 18 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
jax:
url: https://github.com/google/jax.git
tracking_ref: main
latest_verified_commit: 75cdef7626b92b8b6563ea68ae4747fd6994db2e
latest_verified_commit: f0afc1b43df7b4a469f447c4244c4b1c45165b70
mode: git-clone
xla:
url: https://github.com/openxla/xla.git
tracking_ref: main
latest_verified_commit: 831e9cef85493ff7ee2e24fd4cc64377d682aecc
latest_verified_commit: 0cfae60de663e2b227f8605a457c4072c4f4d6d8
mode: git-clone
flax:
url: https://github.com/google/flax.git
mirror_url: https://github.com/nvjax-svc-0/flax.git
tracking_ref: main
latest_verified_commit: aaf130c90eb46160a3234c258a48bf1b932d7829
latest_verified_commit: e4282ee187efbefbde268c6873c592a352f56313
mode: git-clone
patches:
pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
transformer-engine:
url: https://github.com/NVIDIA/TransformerEngine.git
tracking_ref: main
latest_verified_commit: 9b2fed514ea419141146f843ab2c84b22b86bfd7
latest_verified_commit: 8255f87f3ee8076db21777795ce15b6ddf8754c0
mode: git-clone
t5x:
url: https://github.com/google-research/t5x.git
mirror_url: https://github.com/nvjax-svc-0/t5x.git
tracking_ref: main
latest_verified_commit: ecb126e1f5c2aea648f39869d4e69fb4374a4868
latest_verified_commit: ecc6369c95b7b3066a46af5050c8ff9113eb219b
mode: git-clone
patches:
mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
Expand All @@ -43,7 +43,7 @@ praxis:
url: https://github.com/google/praxis.git
mirror_url: https://github.com/nvjax-svc-0/praxis.git
tracking_ref: main
latest_verified_commit: c4271181833d540ea22b1e3875e2bd54951763e9
latest_verified_commit: 493eb5c7c67b0c8da91fff423a7a1cc7bdb614dc
mode: git-clone
patches:
pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas.
Expand All @@ -52,7 +52,7 @@ lingvo:
# Used only in ARM pax builds
url: https://github.com/tensorflow/lingvo.git
tracking_ref: master
latest_verified_commit: 5bbe38c046519b86fa5c0488f813ffbf8b467d7e
latest_verified_commit: 05a076b0783a8bbf4a770095966c472bb37bbf65
mode: git-clone
tensorflow-text:
# Used only in ARM pax and t5x builds
Expand All @@ -73,12 +73,12 @@ fiddle:
airio:
url: https://github.com/google/airio.git
tracking_ref: main
latest_verified_commit: e4c682e691354d75a6bea521cd61709b1ab81d34
latest_verified_commit: 9051cd3375de479df41e0385d200da5101285e55
mode: pip-vcs
clu:
url: https://github.com/google/CommonLoopUtils.git
tracking_ref: main
latest_verified_commit: eed40a1facd526df0e0faa192525f357a3321dca
latest_verified_commit: 0b961e5390da283448b5fc4632b61641980e0a0e
mode: pip-vcs
dllogger:
url: https://github.com/NVIDIA/dllogger.git
Expand All @@ -93,12 +93,12 @@ jestimator:
optax:
url: https://github.com/google-deepmind/optax.git
tracking_ref: main
latest_verified_commit: 623609c7a77a19d48b021cbc300262308846317e
latest_verified_commit: 18110a40153097d15015af2d5d2ee73a86c2a9df
mode: pip-vcs
seqio:
url: https://github.com/google/seqio.git
tracking_ref: main
latest_verified_commit: e31af8c1a11f749edeac512f34d148b9933f863f
latest_verified_commit: 11706e4a1e01a81ea6b3e02c5ad147028d5b94bb
mode: pip-vcs
# used by Pallas
openxla-triton:
Expand All @@ -109,38 +109,38 @@ openxla-triton:
jax-triton:
url: https://github.com/jax-ml/jax-triton.git
tracking_ref: main
latest_verified_commit: 708d3e8afe13b52e4191ad3b677c6f1238677c9e
latest_verified_commit: 08f10633740924de30d85a65386e226b5cbbe564
mode: git-clone
maxtext:
url: https://github.com/google/maxtext.git
tracking_ref: main
latest_verified_commit: 5420bc5753fec4b3a811664cdb58f3c9e98d35fb
latest_verified_commit: fac91938e2f23670ced1f6569557927d31185c9e
mode: git-clone
levanter:
url: https://github.com/stanford-crfm/levanter.git
tracking_ref: main
latest_verified_commit: 94a432e7999ae016645bc72e9dda55e724d0f834
latest_verified_commit: 26577417d6c48144e30d97b6aceaf5da917c4f79
mode: git-clone
haliax:
url: https://github.com/stanford-crfm/haliax.git
tracking_ref: main
latest_verified_commit: 690623131e107972ec2ec67d6183c77649d4b7e0
latest_verified_commit: b94cef72e8bafec57e170fceb9ee86bbe2b1bcfa
mode: git-clone
mujoco:
url: https://github.com/google-deepmind/mujoco.git
tracking_ref: main
latest_verified_commit: c6a41fbfe64ee7b2680a6bde90200ca660d08c2a
latest_verified_commit: ef3cb7ec2e75da93e3307b88b53923ee8f8c7f66
mode: git-clone
grain:
# Used only in ARM t5x builds
url: https://github.com/google/grain.git
tracking_ref: main
latest_verified_commit: f58031724ff06bcc84943c9a8ec501c8941dd660
latest_verified_commit: 98531edd5332c3f212853f88513ac55f29be96b7
mode: git-clone
mujoco-mpc:
url: https://github.com/google-deepmind/mujoco_mpc.git
tracking_ref: main
latest_verified_commit: 50a0159cbc70b38a7fee425b8bf5edbc04a1b62e
latest_verified_commit: 4700f4a13be18398f5aaf6a33ed42e531967e3ae
mode: git-clone
language-to-reward-2023:
url: https://github.com/google-deepmind/language_to_reward_2023.git
Expand Down
6 changes: 3 additions & 3 deletions .github/container/patches/flax/PR-3340.patch
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ index abfbfb5a..bab40243 100644


--
2.25.1
2.34.1


From c945c2ff513282b4af2e956c9c09c784e6d48c44 Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -444,7 +444,7 @@ index 4656abf9..187ab6f5 100644
else:
bias = None
--
2.25.1
2.34.1


From 8b184f603e31feabb7580f1a969e101a7fe9e992 Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -495,5 +495,5 @@ index bab40243..1e1169a0 100644

field = dataclasses.field
--
2.25.1
2.34.1

18 changes: 9 additions & 9 deletions .github/container/patches/paxml/PR-46.patch
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ index 933784a..70247a7 100644

train_state_partition_specs = (
--
2.25.1
2.34.1


From f80c7b08e6eda1946821e779307a9388e714d57a Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -688,7 +688,7 @@ index d44ca67..2b9dba4 100644
assert self.packed_input == False
assert len(self.moe_layers) == 0
--
2.25.1
2.34.1


From 4aba0e9ff9622962cf586a02ca1f399b991c62ab Mon Sep 17 00:00:00 2001
Expand All @@ -715,7 +715,7 @@ index 2b9dba4..ef20305 100644
return x_out

--
2.25.1
2.34.1


From d1db96e4cb34f8bda22bc79b4180157cb60dc849 Mon Sep 17 00:00:00 2001
Expand All @@ -742,7 +742,7 @@ index 70247a7..0a31c30 100644
vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt(
mdl_vars, excluded_for_learner
--
2.25.1
2.34.1


From f70581b80d35979804fd4250a3dcc4ad508f618d Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -779,7 +779,7 @@ index ef20305..fed1601 100644
finally:
pass
--
2.25.1
2.34.1


From 9f33a95fbbf26d7453016dd58da5e5b29031a6d1 Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -946,7 +946,7 @@ index fed1601..5914e54 100644
def update_fp8_metas_if_needed(mdl_vars, grads):
return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads)
--
2.25.1
2.34.1


From d7126f9906e9bba110f482e46eb81e068153e8e8 Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -1107,7 +1107,7 @@ index 0a31c30..d3df86b 100644
grads, states.opt_states[0], vars_with_opt, wps_with_opt
)
--
2.25.1
2.34.1


From 9ee9830083cba5f9441af37c8af1c8f66a4d195b Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -1499,7 +1499,7 @@ index fd482df..b271258 100644
@contextmanager
def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"):
--
2.25.1
2.34.1


From ed1b33f85c9f4d2fa55a97ee0777f6ba67efefda Mon Sep 17 00:00:00 2001
Expand Down Expand Up @@ -1534,5 +1534,5 @@ index b271258..cbac7cf 100644

class TransformerEngineHelperBase:
--
2.25.1
2.34.1

2 changes: 1 addition & 1 deletion .github/container/patches/praxis/PR-27.patch
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ index a35ce8b..52886bc 100644
self.add_summary('attention_mask', atten_mask)
if self.attention_extra_logit is None:
--
2.25.1
2.34.1

Loading
Loading