diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index 0449adadf..61fa07d2e 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -59,10 +59,11 @@ ENV XLA_FLAGS="" ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true" ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_async_all_gather=true" ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_async_reduce_scatter=true" -ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false" ENV CUDA_DEVICE_MAX_CONNECTIONS=1 ENV NCCL_NVLS_ENABLE=0 ENV CUDA_MODULE_LOADING=EAGER +ENV JAX_SHARE_BINARY_BETWEEN_HOSTS=True +ENV JAX_SHARE_AUTOTUNE_CONFIG_BETWEEN_HOSTS=True ADD --chmod=777 create-distribution.sh ${DEST_MANIFEST_DIR}/ diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 5b84c91aa..40ca19eac 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -1,31 +1,31 @@ jax: url: https://github.com/google/jax.git tracking_ref: main - latest_verified_commit: 75cdef7626b92b8b6563ea68ae4747fd6994db2e + latest_verified_commit: f0afc1b43df7b4a469f447c4244c4b1c45165b70 mode: git-clone xla: url: https://github.com/openxla/xla.git tracking_ref: main - latest_verified_commit: 831e9cef85493ff7ee2e24fd4cc64377d682aecc + latest_verified_commit: 0cfae60de663e2b227f8605a457c4072c4f4d6d8 mode: git-clone flax: url: https://github.com/google/flax.git mirror_url: https://github.com/nvjax-svc-0/flax.git tracking_ref: main - latest_verified_commit: aaf130c90eb46160a3234c258a48bf1b932d7829 + latest_verified_commit: e4282ee187efbefbde268c6873c592a352f56313 mode: git-clone patches: pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules transformer-engine: url: https://github.com/NVIDIA/TransformerEngine.git tracking_ref: main - latest_verified_commit: 9b2fed514ea419141146f843ab2c84b22b86bfd7 + latest_verified_commit: 8255f87f3ee8076db21777795ce15b6ddf8754c0 mode: git-clone t5x: url: https://github.com/google-research/t5x.git mirror_url: https://github.com/nvjax-svc-0/t5x.git tracking_ref: main - latest_verified_commit: ecb126e1f5c2aea648f39869d4e69fb4374a4868 + latest_verified_commit: ecc6369c95b7b3066a46af5050c8ff9113eb219b mode: git-clone patches: mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore @@ -43,7 +43,7 @@ praxis: url: https://github.com/google/praxis.git mirror_url: https://github.com/nvjax-svc-0/praxis.git tracking_ref: main - latest_verified_commit: c4271181833d540ea22b1e3875e2bd54951763e9 + latest_verified_commit: 493eb5c7c67b0c8da91fff423a7a1cc7bdb614dc mode: git-clone patches: pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. 
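Note: each manifest entry above pins a dependency to an exact `latest_verified_commit`, even though `tracking_ref` names a moving branch; those pins are what this change bumps for each package. Below is a minimal, hypothetical sketch of how one `git-clone` entry could be materialized into a reproducible checkout. The helper name and paths are invented for illustration; the containers actually build via the bundled scripts such as `create-distribution.sh`.

```python
# Hypothetical illustration only, not the repository's actual tooling.
# A git-clone manifest entry maps to: clone the URL, then check out the
# exact verified commit so the build does not float with tracking_ref.
import subprocess
import yaml  # assumes PyYAML is available

def clone_pinned(entry: dict, dest: str) -> None:
    if entry.get("mode") != "git-clone":
        raise ValueError("pip-vcs entries are installed with pip, not cloned")
    subprocess.run(["git", "clone", entry["url"], dest], check=True)
    subprocess.run(
        ["git", "-C", dest, "checkout", entry["latest_verified_commit"]],
        check=True,
    )

with open("manifest.yaml") as f:
    manifest = yaml.safe_load(f)
clone_pinned(manifest["jax"], "/opt/jax-src")  # destination path is illustrative
```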
@@ -52,7 +52,7 @@ lingvo: # Used only in ARM pax builds url: https://github.com/tensorflow/lingvo.git tracking_ref: master - latest_verified_commit: 5bbe38c046519b86fa5c0488f813ffbf8b467d7e + latest_verified_commit: 05a076b0783a8bbf4a770095966c472bb37bbf65 mode: git-clone tensorflow-text: # Used only in ARM pax and t5x builds @@ -73,12 +73,12 @@ fiddle: airio: url: https://github.com/google/airio.git tracking_ref: main - latest_verified_commit: e4c682e691354d75a6bea521cd61709b1ab81d34 + latest_verified_commit: 9051cd3375de479df41e0385d200da5101285e55 mode: pip-vcs clu: url: https://github.com/google/CommonLoopUtils.git tracking_ref: main - latest_verified_commit: eed40a1facd526df0e0faa192525f357a3321dca + latest_verified_commit: 0b961e5390da283448b5fc4632b61641980e0a0e mode: pip-vcs dllogger: url: https://github.com/NVIDIA/dllogger.git @@ -93,12 +93,12 @@ jestimator: optax: url: https://github.com/google-deepmind/optax.git tracking_ref: main - latest_verified_commit: 623609c7a77a19d48b021cbc300262308846317e + latest_verified_commit: 18110a40153097d15015af2d5d2ee73a86c2a9df mode: pip-vcs seqio: url: https://github.com/google/seqio.git tracking_ref: main - latest_verified_commit: e31af8c1a11f749edeac512f34d148b9933f863f + latest_verified_commit: 11706e4a1e01a81ea6b3e02c5ad147028d5b94bb mode: pip-vcs # used by Pallas openxla-triton: @@ -109,38 +109,38 @@ openxla-triton: jax-triton: url: https://github.com/jax-ml/jax-triton.git tracking_ref: main - latest_verified_commit: 708d3e8afe13b52e4191ad3b677c6f1238677c9e + latest_verified_commit: 08f10633740924de30d85a65386e226b5cbbe564 mode: git-clone maxtext: url: https://github.com/google/maxtext.git tracking_ref: main - latest_verified_commit: 5420bc5753fec4b3a811664cdb58f3c9e98d35fb + latest_verified_commit: fac91938e2f23670ced1f6569557927d31185c9e mode: git-clone levanter: url: https://github.com/stanford-crfm/levanter.git tracking_ref: main - latest_verified_commit: 94a432e7999ae016645bc72e9dda55e724d0f834 + latest_verified_commit: 26577417d6c48144e30d97b6aceaf5da917c4f79 mode: git-clone haliax: url: https://github.com/stanford-crfm/haliax.git tracking_ref: main - latest_verified_commit: 690623131e107972ec2ec67d6183c77649d4b7e0 + latest_verified_commit: b94cef72e8bafec57e170fceb9ee86bbe2b1bcfa mode: git-clone mujoco: url: https://github.com/google-deepmind/mujoco.git tracking_ref: main - latest_verified_commit: c6a41fbfe64ee7b2680a6bde90200ca660d08c2a + latest_verified_commit: ef3cb7ec2e75da93e3307b88b53923ee8f8c7f66 mode: git-clone grain: # Used only in ARM t5x builds url: https://github.com/google/grain.git tracking_ref: main - latest_verified_commit: f58031724ff06bcc84943c9a8ec501c8941dd660 + latest_verified_commit: 98531edd5332c3f212853f88513ac55f29be96b7 mode: git-clone mujoco-mpc: url: https://github.com/google-deepmind/mujoco_mpc.git tracking_ref: main - latest_verified_commit: 50a0159cbc70b38a7fee425b8bf5edbc04a1b62e + latest_verified_commit: 4700f4a13be18398f5aaf6a33ed42e531967e3ae mode: git-clone language-to-reward-2023: url: https://github.com/google-deepmind/language_to_reward_2023.git diff --git a/.github/container/patches/flax/PR-3340.patch b/.github/container/patches/flax/PR-3340.patch index d19f134be..f215a0e25 100644 --- a/.github/container/patches/flax/PR-3340.patch +++ b/.github/container/patches/flax/PR-3340.patch @@ -381,7 +381,7 @@ index abfbfb5a..bab40243 100644 -- -2.25.1 +2.34.1 From c945c2ff513282b4af2e956c9c09c784e6d48c44 Mon Sep 17 00:00:00 2001 @@ -444,7 +444,7 @@ index 4656abf9..187ab6f5 100644 else: bias = None -- 
-2.25.1 +2.34.1 From 8b184f603e31feabb7580f1a969e101a7fe9e992 Mon Sep 17 00:00:00 2001 @@ -495,5 +495,5 @@ index bab40243..1e1169a0 100644 field = dataclasses.field -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/paxml/PR-46.patch b/.github/container/patches/paxml/PR-46.patch index f15905537..697a2a82e 100644 --- a/.github/container/patches/paxml/PR-46.patch +++ b/.github/container/patches/paxml/PR-46.patch @@ -654,7 +654,7 @@ index 933784a..70247a7 100644 train_state_partition_specs = ( -- -2.25.1 +2.34.1 From f80c7b08e6eda1946821e779307a9388e714d57a Mon Sep 17 00:00:00 2001 @@ -688,7 +688,7 @@ index d44ca67..2b9dba4 100644 assert self.packed_input == False assert len(self.moe_layers) == 0 -- -2.25.1 +2.34.1 From 4aba0e9ff9622962cf586a02ca1f399b991c62ab Mon Sep 17 00:00:00 2001 @@ -715,7 +715,7 @@ index 2b9dba4..ef20305 100644 return x_out -- -2.25.1 +2.34.1 From d1db96e4cb34f8bda22bc79b4180157cb60dc849 Mon Sep 17 00:00:00 2001 @@ -742,7 +742,7 @@ index 70247a7..0a31c30 100644 vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( mdl_vars, excluded_for_learner -- -2.25.1 +2.34.1 From f70581b80d35979804fd4250a3dcc4ad508f618d Mon Sep 17 00:00:00 2001 @@ -779,7 +779,7 @@ index ef20305..fed1601 100644 finally: pass -- -2.25.1 +2.34.1 From 9f33a95fbbf26d7453016dd58da5e5b29031a6d1 Mon Sep 17 00:00:00 2001 @@ -946,7 +946,7 @@ index fed1601..5914e54 100644 def update_fp8_metas_if_needed(mdl_vars, grads): return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) -- -2.25.1 +2.34.1 From d7126f9906e9bba110f482e46eb81e068153e8e8 Mon Sep 17 00:00:00 2001 @@ -1107,7 +1107,7 @@ index 0a31c30..d3df86b 100644 grads, states.opt_states[0], vars_with_opt, wps_with_opt ) -- -2.25.1 +2.34.1 From 9ee9830083cba5f9441af37c8af1c8f66a4d195b Mon Sep 17 00:00:00 2001 @@ -1499,7 +1499,7 @@ index fd482df..b271258 100644 @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): -- -2.25.1 +2.34.1 From ed1b33f85c9f4d2fa55a97ee0777f6ba67efefda Mon Sep 17 00:00:00 2001 @@ -1534,5 +1534,5 @@ index b271258..cbac7cf 100644 class TransformerEngineHelperBase: -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/praxis/PR-27.patch b/.github/container/patches/praxis/PR-27.patch index 516d2fe74..997d17d2e 100644 --- a/.github/container/patches/praxis/PR-27.patch +++ b/.github/container/patches/praxis/PR-27.patch @@ -34,5 +34,5 @@ index a35ce8b..52886bc 100644 self.add_summary('attention_mask', atten_mask) if self.attention_extra_logit is None: -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/praxis/PR-36.patch b/.github/container/patches/praxis/PR-36.patch index d73d8214d..1161a36fd 100644 --- a/.github/container/patches/praxis/PR-36.patch +++ b/.github/container/patches/praxis/PR-36.patch @@ -1,7 +1,7 @@ From 41488517eb6d95eb7943681e706c8804e6102c41 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 11:38:27 +0800 -Subject: [PATCH 1/3] Adding TE support +Subject: [PATCH 01/10] Adding TE support --- praxis/contrib/gpu/scripts_gpu/te_helper.py | 176 ++++++++++++++++++++ @@ -247,13 +247,13 @@ index ab6cff3..c79dac9 100644 # Annotate the inputs before the pipeline to prevent unexpected # propagation from earlier layers. -- -2.25.1 +2.34.1 From ff1745796009cf1ec59f463f8e776c66f1286938 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 17 Nov 2023 15:21:06 +0800 -Subject: [PATCH 2/3] Fix missing vars wiht PP. +Subject: [PATCH 02/10] Fix missing vars wiht PP. 
--- praxis/contrib/gpu/scripts_gpu/te_helper.py | 34 ++++++++++++--------- @@ -358,13 +358,13 @@ index e3b2f7c..b31526e 100644 trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert), trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert), -- -2.25.1 +2.34.1 From 99e26aaf14131ca73501f08162be628b55c86a88 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 17 Jan 2024 02:36:31 -0800 -Subject: [PATCH 3/3] Add checkpoint_policy checker for fused attn + dropout +Subject: [PATCH 03/10] Add checkpoint_policy checker for fused attn + dropout Signed-off-by: Reese Wang --- @@ -476,5 +476,436 @@ index c79dac9..e076530 100644 repeats.Repeat, sub_tpl=self.block, -- -2.25.1 +2.34.1 + + +From ab12f857404d84ed423e095d59e0bd336b94f151 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Mon, 15 Jan 2024 08:25:59 -0800 +Subject: [PATCH 04/10] Support more TE configurations + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 84 +++++++++++++++++++-- + 1 file changed, 76 insertions(+), 8 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 290b74c..9defcbd 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -5,6 +5,9 @@ from praxis import pax_fiddle + from praxis import pytypes + from praxis import layers + from praxis.layers.checkpoint_policy import AutodiffCheckpointType ++from praxis.layers import activations ++from praxis.layers import attentions, grouped_query_attention, multi_query_attention ++from praxis.layers import normalizations + + try: + import transformer_engine.jax as te +@@ -112,19 +115,85 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def set_layer_params_to_stack_transformer(stacked_transformer_obj, _, layer_id): ++ enable_sp = bool(int(os.environ.get('ENABLE_TE_SP', 0))) + te_transformer_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, + name=f'layer_{layer_id}', +- layernorm_type='layernorm', +- zero_centered_gamma = True, +- mlp_activations=('gelu',), +- use_bias=True, +- self_attn_mask_type='causal', + enable_relative_embedding=False, +- scaled_query_init=False, +- scale_attn_logits=True, ++ enable_sequence_parallel=enable_sp, + transpose_batch_sequence=False + ) + ++ def update_ln_te_tpl(te_tpl, transformer_layer_tpl): ++ # TE requires all normalization are the same ++ assert transformer_layer_tpl.ln_tpl == transformer_layer_tpl.tr_fflayer_tpl.ln_tpl ++ ln_tpl = transformer_layer_tpl.ln_tpl ++ if issubclass(ln_tpl.cls, normalizations.LayerNorm): ++ te_tpl.layernorm_type = 'layernorm' ++ assert ln_tpl.use_scale ++ assert ln_tpl.use_bias ++ elif issubclass(ln_tpl.cls, normalizations.RmsNorm): ++ te_tpl.layernorm_type = 'rmsnorm' ++ else: ++ raise ValueError(f'Unsupported {ln_tpl.cls=}, LayerNorm, RmsNorm are supported.') ++ te_tpl.zero_centered_gamma = not ln_tpl.direct_scale ++ te_tpl.layernorm_epsilon = ln_tpl.epsilon ++ return te_tpl ++ ++ def update_ff_te_tpl(te_tpl, ff_layer_tpl): ++ if issubclass(ff_layer_tpl.activation_tpl.cls, activations.Identity): ++ mlp_activations = ('linear',) ++ else: ++ mlp_activations = (ff_layer_tpl.activation_tpl.cls.__name__.lower(),) ++ if ff_layer_tpl.use_gated_activation: ++ mlp_activations += ('linear',) ++ te_tpl.mlp_activations = mlp_activations ++ return te_tpl ++ ++ def update_attn_te_tpl(te_tpl, attn_tpl): ++ # TODO(rewang): rope ++ if issubclass(attn_tpl.cls, attentions.DotProductAttention): ++ # Check the DotProductAttention parameters are aligned 
to TE's attention ++ assert attn_tpl.internal_enable_query_scale or attn_tpl.scale_logits_by_head_dims ++ assert not (attn_tpl.internal_enable_query_scale and attn_tpl.scale_logits_by_head_dims) ++ assert not attn_tpl.internal_enable_per_dim_scale ++ assert not attn_tpl.scale_query_by_dim_per_head ++ assert not attn_tpl.dconv_qkv ++ assert not attn_tpl.internal_gshard_gaussian_init ++ assert not attn_tpl.use_rotary_position_emb ++ assert attn_tpl.relative_bias_tpl is None ++ assert attn_tpl.attention_extra_logit is None ++ assert attn_tpl.ngrammer_tpl is None ++ te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb ++ elif issubclass(attn_tpl.cls, grouped_query_attention.GroupedQueryAttention): ++ te_tpl.num_gqa_groups = attn_tpl.num_kv_heads ++ if attn_tpl.rope_min_max_timescales is not None: ++ te_tpl.enable_rotary_pos_emb = True ++ te_tpl.rotary_pos_emb_windows = attn_tpl.rope_min_max_timescales ++ assert attn_tpl.atten_temp == 1. ++ elif issubclass(attn_tpl.cls, multi_query_attention.MultiQueryDotProductAttention): ++ te_tpl.num_gqa_groups = attn_tpl.num_kv_heads ++ te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb ++ else: ++ raise ValueError(f'Unsupported {attn_tpl.cls=}') ++ assert attn_tpl.atten_logit_cap <= 0., 'atten_logit_cap > 0. is not supported in TE' ++ te_tpl.scaled_query_init = False ++ te_tpl.scale_attn_logits = True ++ return te_tpl ++ ++ transformer_layer_tpl = stacked_transformer_obj.transformer_layer_params_tpl ++ # Update TE normalization layer configs ++ te_transformer_tpl = update_ln_te_tpl(te_transformer_tpl, transformer_layer_tpl) ++ # Update TE feed forward layer configs ++ te_transformer_tpl = update_ff_te_tpl(te_transformer_tpl, transformer_layer_tpl.tr_fflayer_tpl) ++ # Update TE attention layer configs ++ te_transformer_tpl = update_attn_te_tpl(te_transformer_tpl, transformer_layer_tpl.tr_atten_tpl) ++ # TE currently only allow the bias config to be same between feed forward, qkv proj, out proj ++ assert (transformer_layer_tpl.tr_fflayer_tpl.has_bias == ++ transformer_layer_tpl.tr_atten_tpl.use_bias), "TE only allows same bias settings." 
++ te_transformer_tpl.use_bias = transformer_layer_tpl.tr_fflayer_tpl.has_bias ++ te_transformer_tpl.self_attn_mask_type = 'causal' \ ++ if stacked_transformer_obj.mask_self_attention else 'padding' ++ + te_transformer_tpl.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) + te_transformer_tpl.params_init = stacked_transformer_obj.params_init + te_transformer_tpl.dtype = stacked_transformer_obj.fprop_dtype +@@ -133,7 +202,6 @@ class TEInstalledHelper(TransformerEngineHelperBase): + te_transformer_tpl.num_attention_heads = stacked_transformer_obj.num_heads + te_transformer_tpl.hidden_size = stacked_transformer_obj.model_dims + te_transformer_tpl.mlp_hidden_size = stacked_transformer_obj.hidden_dims +- te_transformer_tpl.layernorm_epsilon = stacked_transformer_obj.transformer_layer_params_tpl.ln_tpl.epsilon + te_transformer_tpl.dropout_rng_name = base_layer.RANDOM + te_transformer_tpl.attention_dropout = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob + te_transformer_tpl.hidden_dropout = stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob +-- +2.34.1 + + +From dc632f0b2bf8da8724e4959360da034b3a7b4075 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Sun, 18 Feb 2024 01:14:41 -0800 +Subject: [PATCH 05/10] Change the gated activations orders + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 11 +++++++---- + 1 file changed, 7 insertions(+), 4 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 9defcbd..1f8f6d6 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -140,12 +140,15 @@ class TEInstalledHelper(TransformerEngineHelperBase): + return te_tpl + + def update_ff_te_tpl(te_tpl, ff_layer_tpl): +- if issubclass(ff_layer_tpl.activation_tpl.cls, activations.Identity): +- mlp_activations = ('linear',) +- else: +- mlp_activations = (ff_layer_tpl.activation_tpl.cls.__name__.lower(),) ++ mlp_activations = () + if ff_layer_tpl.use_gated_activation: ++ mlp_activations += ('linear',) ++ ++ if issubclass(ff_layer_tpl.activation_tpl.cls, activations.Identity): + mlp_activations += ('linear',) ++ else: ++ mlp_activations += (ff_layer_tpl.activation_tpl.cls.__name__.lower(),) ++ + te_tpl.mlp_activations = mlp_activations + return te_tpl + +-- +2.34.1 + + +From f2e8560e6861a6dea981209b402b27dc6bc92022 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Wed, 28 Feb 2024 22:37:46 -0800 +Subject: [PATCH 06/10] Remove RoPE restriction from DPA module + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 2 -- + 1 file changed, 2 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 1f8f6d6..7d83c08 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -153,7 +153,6 @@ class TEInstalledHelper(TransformerEngineHelperBase): + return te_tpl + + def update_attn_te_tpl(te_tpl, attn_tpl): +- # TODO(rewang): rope + if issubclass(attn_tpl.cls, attentions.DotProductAttention): + # Check the DotProductAttention parameters are aligned to TE's attention + assert attn_tpl.internal_enable_query_scale or attn_tpl.scale_logits_by_head_dims +@@ -162,7 +161,6 @@ class TEInstalledHelper(TransformerEngineHelperBase): + assert not attn_tpl.scale_query_by_dim_per_head + assert not attn_tpl.dconv_qkv + assert not 
attn_tpl.internal_gshard_gaussian_init +- assert not attn_tpl.use_rotary_position_emb + assert attn_tpl.relative_bias_tpl is None + assert attn_tpl.attention_extra_logit is None + assert attn_tpl.ngrammer_tpl is None +-- +2.34.1 + + +From 3d6b5c34a64939dadcc751cf2e989cec1affc648 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Wed, 28 Feb 2024 22:56:44 -0800 +Subject: [PATCH 07/10] Add rotary_pos_emb_group dispatch + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 5 +++++ + 1 file changed, 5 insertions(+) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 7d83c08..d187ba1 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -7,6 +7,7 @@ from praxis import layers + from praxis.layers.checkpoint_policy import AutodiffCheckpointType + from praxis.layers import activations + from praxis.layers import attentions, grouped_query_attention, multi_query_attention ++from praxis.layers import embedding_softmax + from praxis.layers import normalizations + + try: +@@ -165,6 +166,8 @@ class TEInstalledHelper(TransformerEngineHelperBase): + assert attn_tpl.attention_extra_logit is None + assert attn_tpl.ngrammer_tpl is None + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb ++ if issubclass(attn_tpl.rotary_position_emb_tpl, embedding_softmax.RotaryPositionalEmbedding): ++ te_tpl.rotary_pos_emb_group_method = 'alternate' + elif issubclass(attn_tpl.cls, grouped_query_attention.GroupedQueryAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads + if attn_tpl.rope_min_max_timescales is not None: +@@ -174,6 +177,8 @@ class TEInstalledHelper(TransformerEngineHelperBase): + elif issubclass(attn_tpl.cls, multi_query_attention.MultiQueryDotProductAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb ++ if issubclass(attn_tpl.rotary_position_emb_tpl, embedding_softmax.RotaryPositionalEmbedding): ++ te_tpl.rotary_pos_emb_group_method = 'alternate' + else: + raise ValueError(f'Unsupported {attn_tpl.cls=}') + assert attn_tpl.atten_logit_cap <= 0., 'atten_logit_cap > 0. 
is not supported in TE' +-- +2.34.1 + + +From 454f760a095562d995e7e9102f97c05158415312 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Sun, 3 Mar 2024 01:16:46 -0800 +Subject: [PATCH 08/10] Fix the missing .cls + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index d187ba1..733e9bf 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -166,7 +166,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + assert attn_tpl.attention_extra_logit is None + assert attn_tpl.ngrammer_tpl is None + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb +- if issubclass(attn_tpl.rotary_position_emb_tpl, embedding_softmax.RotaryPositionalEmbedding): ++ if issubclass(attn_tpl.rotary_position_emb_tpl.cls, embedding_softmax.RotaryPositionalEmbedding): + te_tpl.rotary_pos_emb_group_method = 'alternate' + elif issubclass(attn_tpl.cls, grouped_query_attention.GroupedQueryAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads +@@ -177,7 +177,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + elif issubclass(attn_tpl.cls, multi_query_attention.MultiQueryDotProductAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb +- if issubclass(attn_tpl.rotary_position_emb_tpl, embedding_softmax.RotaryPositionalEmbedding): ++ if issubclass(attn_tpl.rotary_position_emb_tpl.cls, embedding_softmax.RotaryPositionalEmbedding): + te_tpl.rotary_pos_emb_group_method = 'alternate' + else: + raise ValueError(f'Unsupported {attn_tpl.cls=}') +-- +2.34.1 + + +From 0bd4a531a3d8e4ddafb5ed680092fd5636aa9f8e Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Fri, 2 Feb 2024 10:58:40 +0800 +Subject: [PATCH 09/10] Fixed the unexpected input sharding pattern when TE + enabled. 
+ +Signed-off-by: Ming-Xu Huang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 25 +++++++++++++++++++-- + praxis/layers/transformer_models.py | 3 +++ + 2 files changed, 26 insertions(+), 2 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index 733e9bf..ee0fc84 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -26,10 +26,13 @@ try: + + TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = [te.fp8.FP8Helper.FP8_COLLECTION_NAME] + ++ ENABLE_TE_SP = bool(int(os.environ.get('ENABLE_TE_SP', 0))) ++ + except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False + TE_PIPELINE_EXTRA_VMAP_VAR_AXES = {} + TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = [] ++ ENABLE_TE_SP = False + + LayerTpl = pax_fiddle.Config[base_layer.BaseLayer] + JTensor = pytypes.JTensor +@@ -45,6 +48,10 @@ class TransformerEngineHelperBase: + def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id): + raise NotImplementedError + ++ @staticmethod ++ def get_input_bld(original_bld, batch_axes, mdl_axis): ++ raise NotImplementedError ++ + @staticmethod + def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): + raise NotImplementedError +@@ -81,6 +88,10 @@ class TENotInstalledHelper(TransformerEngineHelperBase): + ) + return layer_p + ++ @staticmethod ++ def get_input_bld(original_bld, _, _): ++ return original_bld ++ + @staticmethod + def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): + return xformer_layer_p.tr_atten_tpl.activation_split_dims_mapping.bld +@@ -116,11 +127,10 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def set_layer_params_to_stack_transformer(stacked_transformer_obj, _, layer_id): +- enable_sp = bool(int(os.environ.get('ENABLE_TE_SP', 0))) + te_transformer_tpl = pax_fiddle.Config(te_praxis.TransformerLayer, + name=f'layer_{layer_id}', + enable_relative_embedding=False, +- enable_sequence_parallel=enable_sp, ++ enable_sequence_parallel=ENABLE_TE_SP, + transpose_batch_sequence=False + ) + +@@ -224,6 +234,12 @@ class TEInstalledHelper(TransformerEngineHelperBase): + + return te_transformer_tpl + ++ @staticmethod ++ def get_input_bld(_, batch_axes, mdl_axis): ++ if ENABLE_TE_SP: ++ return [batch_axes, mdl_axis, None] ++ return [batch_axes, None, None] ++ + @staticmethod + def get_bld_mapping_for_pipelined_transformer(_): + rules = te_flax.extend_logical_axis_rules(tuple()) +@@ -289,6 +305,11 @@ class TransformerEngineHelper(TransformerEngineHelperBase): + return TransformerEngineHelper.get_helper().set_layer_params_to_stack_transformer( + stacked_transformer_obj, layer_p, layer_id) + ++ @staticmethod ++ def get_input_bld(original_bld, batch_axes, mdl_axis): ++ return TransformerEngineHelper.get_helper().get_input_bld( ++ original_bld, batch_axes, mdl_axis) ++ + @staticmethod + def get_bld_mapping_for_pipelined_transformer(xformer_layer_p): + return TransformerEngineHelper.get_helper().get_bld_mapping_for_pipelined_transformer( +diff --git a/praxis/layers/transformer_models.py b/praxis/layers/transformer_models.py +index d8720ae..235d63e 100644 +--- a/praxis/layers/transformer_models.py ++++ b/praxis/layers/transformer_models.py +@@ -33,6 +33,7 @@ from praxis.layers import embedding_softmax + from praxis.layers import multi_query_attention + from praxis.layers import normalizations + from praxis.layers import transformers ++from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper + + 
NestedMap = py_utils.NestedMap + JTensor = pytypes.JTensor +@@ -323,6 +324,8 @@ class TransformerLm(base_layer.BaseLayer): + if training_optimized + else [batch_axes, None, None] + ) ++ bld = TransformerEngineHelper.get_input_bld(bld, batch_axes, mdl_axis) ++ + egcm = ( + [data_axis, None, None, mdl_axis] + if training_optimized +-- +2.34.1 + + +From e3e785cfedb4f350cfcb2c43b093f94288dbc846 Mon Sep 17 00:00:00 2001 +From: Ming-Xu Huang +Date: Wed, 7 Feb 2024 12:34:01 +0800 +Subject: [PATCH 10/10] Addind a comment to get_input_bld + +Signed-off-by: Ming-Xu Huang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 3 ++- + 1 file changed, 2 insertions(+), 1 deletion(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index ee0fc84..d0afc1a 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -50,6 +50,7 @@ class TransformerEngineHelperBase: + + @staticmethod + def get_input_bld(original_bld, batch_axes, mdl_axis): ++ # This is used to specify the sharding pattern of inputs to TransformerLayers. + raise NotImplementedError + + @staticmethod +@@ -89,7 +90,7 @@ class TENotInstalledHelper(TransformerEngineHelperBase): + return layer_p + + @staticmethod +- def get_input_bld(original_bld, _, _): ++ def get_input_bld(original_bld, *_): + return original_bld + + @staticmethod +-- +2.34.1 diff --git a/.github/container/patches/t5x/mirror-patch-dali-support.patch b/.github/container/patches/t5x/mirror-patch-dali-support.patch index b1c4e7ad1..28207adc8 100644 --- a/.github/container/patches/t5x/mirror-patch-dali-support.patch +++ b/.github/container/patches/t5x/mirror-patch-dali-support.patch @@ -355,7 +355,7 @@ index 3d592d8..7a321cd 100644 def _warn_action_not_run(action, task, metric): logging.warning( -- -2.25.1 +2.34.1 From 79d36a39921b83271ff75748a211185884744f8b Mon Sep 17 00:00:00 2001 @@ -381,5 +381,5 @@ index e6027ce..ec1c2fa 100644 checkpoint_cfg=checkpoint_cfg, partitioner=partitioner, -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch index 9250d85e0..4b893d84a 100644 --- a/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch +++ b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch @@ -25,5 +25,5 @@ index 61682ed..77e0860 100644 ] # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch index 1b8b1f846..2d4887380 100644 --- a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch +++ b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch @@ -2465,7 +2465,7 @@ index 965bd09..cebb815 100644 metrics["learning_rate"] = clu.metrics.Average.from_model_output( jnp.asarray([learning_rate]) -- -2.25.1 +2.34.1 From 5b760642c8f41579baef87c0d03c47f283a3c17a Mon Sep 17 00:00:00 2001 @@ -2500,7 +2500,7 @@ index 388d2ec..135ecf6 100755 # Global batch size BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) -- -2.25.1 +2.34.1 From 511e69e39448f36a1bfaf6aeeb39a86432458346 Mon Sep 17 00:00:00 2001 @@ -2528,7 +2528,7 @@ index 8dfb094..b104128 100644 # Start warming up the input pipeline in the background. 
This must happen # after input pipeline checkpoints were restored. -- -2.25.1 +2.34.1 From 1af79c4f0aa399132f2785a0a7b637914e3b3b6a Mon Sep 17 00:00:00 2001 @@ -2565,7 +2565,7 @@ index def1a1a..0d12f30 100755 + 2>&1 | tee \ ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log -- -2.25.1 +2.34.1 From 9be7aca9e50f0030348319e18bf6e3217765c9de Mon Sep 17 00:00:00 2001 @@ -2606,7 +2606,7 @@ index fb5f48f..f3750ca 100644 class TransformerEngineHelper(TransformerEngineHelperBase): -- -2.25.1 +2.34.1 From b150ec1b3b425eb20b196c5121aaf0d5d8fc1735 Mon Sep 17 00:00:00 2001 @@ -2633,7 +2633,7 @@ index f3750ca..f585752 100644 "Transformer Engine does not support dataset.packing, please turn it off." -- -2.25.1 +2.34.1 From 28b7cad906c14096ff350c3732c09b89da8d2b2d Mon Sep 17 00:00:00 2001 @@ -2660,7 +2660,7 @@ index a9974e1..660df3a 100644 | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) -- -2.25.1 +2.34.1 From 9a94ef9d79a7c376cc433dd14be4d27a86c109ae Mon Sep 17 00:00:00 2001 @@ -2686,7 +2686,7 @@ index 660df3a..c31094d 100644 | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) -- -2.25.1 +2.34.1 From 6812d6dce1c417999b3d3167fb58572f49f08780 Mon Sep 17 00:00:00 2001 @@ -2712,7 +2712,7 @@ index 5891c08..e13b810 100644 enable_dropout=False, method=self.module.encode, -- -2.25.1 +2.34.1 From bb53dabb9aeff2f9a68b420f071ba24f403edeff Mon Sep 17 00:00:00 2001 @@ -2753,7 +2753,7 @@ index e13b810..cc3348f 100644 # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop # after the prompt by matching to `output_vocabulary.eos_id`. 
-- -2.25.1 +2.34.1 From f3d56a1db5ade8a6fae94bbdeb40c54564828fa3 Mon Sep 17 00:00:00 2001 @@ -2805,7 +2805,7 @@ index f585752..568f596 100644 name=name) -- -2.25.1 +2.34.1 From 6e48146736b418e12279c11c34c68dad69bb46b7 Mon Sep 17 00:00:00 2001 @@ -2833,7 +2833,7 @@ index 568f596..05c5f6b 100644 @staticmethod def update_fp8_metas(grad_accum, flax_mutables): -- -2.25.1 +2.34.1 From afec601e706f3b4f9d5c7521a950f144327533a8 Mon Sep 17 00:00:00 2001 @@ -2996,7 +2996,7 @@ index 05c5f6b..7657c52 100644 return TENotInstalledHelper -- -2.25.1 +2.34.1 From 681aa5a2e2cc1607901fe584dd57b1e36f5484db Mon Sep 17 00:00:00 2001 @@ -3032,7 +3032,7 @@ index 7657c52..b064d2b 100644 @staticmethod -- -2.25.1 +2.34.1 From 24efb64b639695a1fb34fb6fe31d68ad6842e685 Mon Sep 17 00:00:00 2001 @@ -3072,7 +3072,7 @@ index 847dc24..63873cf 100644 class PjitPartitioner(BasePjitPartitioner): -- -2.25.1 +2.34.1 From dfc8e08ff38e4013e10a43718f5216d7db20a1bc Mon Sep 17 00:00:00 2001 @@ -3549,7 +3549,7 @@ index d083540..56919a5 100755 -set +x +echo Finished -- -2.25.1 +2.34.1 From 872719719603a77fe30186fbc58805aa41ea127f Mon Sep 17 00:00:00 2001 @@ -3575,7 +3575,7 @@ index cebb815..8910ce8 100644 if num_microbatches is None or num_microbatches <= 1: -- -2.25.1 +2.34.1 From 3ca8e3453ec706b8ea0f237ad34cb8da0afa78f3 Mon Sep 17 00:00:00 2001 @@ -3601,5 +3601,5 @@ index cc3348f..eb7bd37 100644 """Predicts a batch of outputs from the model. -- -2.25.1 +2.34.1 diff --git a/README.md b/README.md index 3d96c3269..061950cbb 100644 --- a/README.md +++ b/README.md @@ -163,13 +163,14 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb | `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels | | `--xla_gpu_enable_async_all_gather` | `true` | allows XLA to run NCCL [AllGather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#allgather) kernels on a separate CUDA stream to allow overlap with compute kernels | | `--xla_gpu_enable_async_reduce_scatter` | `true` | allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels | -| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Trition GeMM kernels | | Environment Variable | Value | Explanation | | -------------------- | ----- | ----------- | | `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches | | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. | | `CUDA_MODULE_LOADING` | `EAGER` | Disables lazy-loading ([1](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables)) which uses slightly more GPU memory. | +| `JAX_SHARE_BINARY_BETWEEN_HOSTS` | `True` | Forces the coordinator process to share the optimized XLA module with other processes. Helps prevent hangs resulting from different processes having different optimized modules. | +| `JAX_SHARE_AUTOTUNE_CONFIG_BETWEEN_HOSTS` | `True` | Forces the coordinator process to share the autotune config with other participants. Helps prevent hangs, but can increase compilation time by ~1.5x. 
|

## Profiling JAX programs on GPU
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
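A usage note for the two `JAX_SHARE_*` variables in the environment-variable table above: they take effect when JAX starts up, so in multi-process runs they should already be set before `jax` is imported (here they are baked into the container `ENV`). The sketch below shows the intended multi-host setting; `COORD_ADDR`, `NPROCS`, and `PROC_ID` are placeholder variables invented for this example, not anything JAX defines, and the only JAX API used is `jax.distributed.initialize`.

```python
# Minimal sketch, assuming a launcher exports the placeholder variables
# below. The JAX_SHARE_* settings are read at JAX start-up, so set them
# (or inherit them from the container ENV) before importing jax.
import os

os.environ.setdefault("JAX_SHARE_BINARY_BETWEEN_HOSTS", "True")
os.environ.setdefault("JAX_SHARE_AUTOTUNE_CONFIG_BETWEEN_HOSTS", "True")

import jax

# With the flags above, process 0 (the coordinator) shares its optimized
# XLA module and autotune choices with the other processes, so every host
# executes an identical compiled program instead of diverging.
jax.distributed.initialize(
    coordinator_address=os.environ["COORD_ADDR"],  # e.g. "10.0.0.1:1234" (placeholder)
    num_processes=int(os.environ["NPROCS"]),       # placeholder
    process_id=int(os.environ["PROC_ID"]),         # placeholder
)
print(jax.process_index(), jax.device_count())
```

Without these flags, slight per-host differences in autotuning can yield different optimized modules whose collectives no longer line up, which is the hang mode the table entries describe; the trade-off noted above is roughly 1.5x longer compilation when sharing autotune configs.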