From da4483f6fb50245a3685c844da53471cd2585908 Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Thu, 3 Oct 2024 10:21:16 -0700
Subject: [PATCH] Upgrade axlearn to jax 0.4.33

---
 .../common/flash_attention/gpu_attention.py | 49 +++++--------
 pyproject.toml                              |  8 +--
 2 files changed, 17 insertions(+), 40 deletions(-)

diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py
index 7d8bb3a5..b8bcf592 100644
--- a/axlearn/common/flash_attention/gpu_attention.py
+++ b/axlearn/common/flash_attention/gpu_attention.py
@@ -34,12 +34,7 @@
 # pytype: disable=import-error
 # pylint: disable=import-error
 from jax import lax
-from jax._src.cudnn.fused_attention_stablehlo import (
-    MaskType,
-    _dot_product_attention,
-    _normalize_layout,
-    check_cudnn_version,
-)
+from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention
 from jax._src.lib import cuda_versions
 from jax.experimental import pallas as pl
 
@@ -730,38 +725,20 @@ def cudnn_dot_product_attention(
     if qkv_layout != "BTNH":
         raise NotImplementedError(f"Unsupported qkv_layout: {qkv_layout}")
 
-    # Check if cuDNN is installed.
-    cudnn_version = check_cudnn_version()
     # Support Ampere and Hopper only for now.
     _check_local_compute_capability((80, 90))
     mask_type = MaskType.NO_MASK if not causal else MaskType.CAUSAL
-    layout = _normalize_layout(qkv_layout)
-
-    has_bias = bias is not None
-    has_mask = mask is not None
-    has_dbias = False
-    variadic_args = (has_bias, has_mask, has_dbias)
-    if bias is None:
-        bias = jnp.zeros(0, dtype=query.dtype)
-    if mask is None:
-        mask = jnp.zeros(0, dtype=query.dtype)
-    q_seqlen = jnp.zeros(0, dtype=query.dtype)
-    kv_seqlen = jnp.zeros(0, dtype=query.dtype)
-    # pylint: disable-next=too-many-function-args
-    output = _dot_product_attention(
-        query,
-        key,
-        value,
-        bias,
-        mask,
-        q_seqlen,
-        kv_seqlen,
-        softmax_scale,
-        seed,
-        dropout_rate,
-        variadic_args,
-        mask_type,
-        layout.value,
-        cudnn_version,
+
+    output = dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        bias=bias,
+        mask=mask,
+        scale=softmax_scale,
+        seed=seed,
+        dropout_rate=dropout_rate,
+        mask_type=mask_type,
+        qkv_layout=qkv_layout,
     )
     return output
diff --git a/pyproject.toml b/pyproject.toml
index eb96b334..763d03be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.86",  # chex 0.1.86 is required for jax 0.4.25.
     "importlab==0.7",  # breaks pytype on 0.8
-    "jax==0.4.30",
-    "jaxlib==0.4.30",
+    "jax==0.4.33",
+    "jaxlib==0.4.33",
     "nltk==3.7",  # for text preprocessing
     "optax==0.1.7",  # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -100,7 +100,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.30",  # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.33",  # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard.
 vertexai_tensorboard = [
@@ -124,7 +124,7 @@ dataflow = [
 # GPU custom kernel dependency.
 gpu = [
     "triton==2.1.0",
-    "jax[cuda12_pip]==0.4.30",
+    "jax[cuda12_pip]==0.4.33",
 ]
 # Open API inference.
 open_api = [
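
Post-patch note (not part of the commit): a quick way to sanity-check the new cuDNN call path after applying this patch is a standalone smoke test. The sketch below is a hypothetical example, assuming a CUDA-enabled jax/jaxlib 0.4.33 install on an Ampere or Hopper GPU and the wrapper signature visible in the hunk context above (query/key/value plus the causal, bias, mask, softmax_scale, seed, dropout_rate, and qkv_layout parameters); the shapes and dtype are illustrative.

# Hypothetical smoke test for the upgraded cuDNN attention path.
# Requires an sm80/sm90 GPU, per the _check_local_compute_capability((80, 90))
# guard in the patched code; cuDNN fused attention expects fp16/bf16 inputs.
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention

# BTNH layout: (batch, time, num_heads, head_dim) -- the only layout the wrapper accepts.
shape = (2, 128, 4, 64)
k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k_q, shape, dtype=jnp.bfloat16)
key = jax.random.normal(k_k, shape, dtype=jnp.bfloat16)
value = jax.random.normal(k_v, shape, dtype=jnp.bfloat16)

# Exercise the causal branch (mask_type=MaskType.CAUSAL in the patched code).
out = cudnn_dot_product_attention(query=query, key=key, value=value, causal=True)
print(out.shape)  # Expected: (2, 128, 4, 64)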