From 6857b287798842c69ab082295e0dd15fb48a7182 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 24 Dec 2025 19:19:31 +0000 Subject: [PATCH 1/6] Pick a leaner set of combinations for TE JAX CP attn tests such that only those cp,dp,tp combinations are picked where cp*dp*tp is equal to num gpus Signed-off-by: Kshitij Lakhani --- tests/jax/distributed_test_base.py | 7 +++++-- tests/jax/utils.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 137fa480d..f86f81ec4 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -12,7 +12,7 @@ from transformer_engine.jax.sharding import MeshResource -from utils import assert_allclose, is_devices_enough +from utils import assert_allclose, is_devices_enough, is_devices_equal def generate_configs(): @@ -49,7 +49,10 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - if is_devices_enough(ndev): + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. + # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. + if is_devices_equal(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: configsL1.append( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 7194e387c..c3311395a 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -46,6 +46,12 @@ def is_devices_enough(required): """ return len(jax.devices()) >= required +def is_devices_equal(required): + """ + Check if the available GPUs is exactly equal + """ + return len(jax.devices()) == required + def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. From d6a29518a2c5e21a537a42dab22f6f229f3884bd Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 8 Jan 2026 01:29:51 +0000 Subject: [PATCH 2/6] Consolidate the test cases run for different B,S,H,D and QKV layout Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 73 ++++++++++++------------------------ 1 file changed, 24 insertions(+), 49 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 49372fda1..85b38e5f5 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1056,56 +1056,31 @@ def check_dqkv(primitive, reference, pad, idx): ], ) @pytest.mark.parametrize( - "qkv_layout", + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ - pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"), - pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), - pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"), - pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"), - pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"), - ], -) -@pytest.mark.parametrize( - "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype", - [ - pytest.param( - 2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF" - ), - pytest.param( - 2, - 512, - 1024, - 12, - 12, - 64, - 64, - jnp.bfloat16, - id="2-512-1024-12-12-64-64-BF16-CROSS", - ), - pytest.param( - 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" - ), - pytest.param( - 4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF" - ), - pytest.param( - 4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF" - ), - pytest.param( - 2, - 2048, - 1024, - 12, - 12, - 64, - 32, - jnp.bfloat16, - id="2-2048-1024-12-12-64-32-BF16-CROSS", - ), - pytest.param( - 2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA" - ), + # large + QKV_PACKED and RAGGED_QKV_PACKED? + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + # mid + cross + KV_PACKED and RAGGED_KV_PACKED + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), + + # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? + # TODO: Consider making this a CROSS case ? + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), ], ) @pytest.mark.parametrize( From 2dd5068a90b86d3651098cc80cf2db98ba7776c1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:32:09 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/distributed_test_base.py | 2 +- tests/jax/test_fused_attn.py | 183 ++++++++++++++++++++++++++--- tests/jax/utils.py | 1 + 3 files changed, 170 insertions(+), 16 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index f86f81ec4..1593f50f0 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -49,7 +49,7 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. if is_devices_equal(ndev): diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 85b38e5f5..ebf054369 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1059,28 +1059,181 @@ def check_dqkv(primitive, reference, pad, idx): "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ # large + QKV_PACKED and RAGGED_QKV_PACKED? - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BS3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.T3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED", + ), # mid + cross + KV_PACKED and RAGGED_KV_PACKED - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED", + ), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", + ), # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE", + ), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.THD_THD_THD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE", + ), # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), - + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED", + ), # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.BS3HD, + id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.T3HD, + id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED", + ), # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.THD_T2HD, + id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", + ), # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? # TODO: Consider making this a CROSS case ? - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.THD_THD_THD, + id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", + ), ], ) @pytest.mark.parametrize( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c3311395a..393070750 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -46,6 +46,7 @@ def is_devices_enough(required): """ return len(jax.devices()) >= required + def is_devices_equal(required): """ Check if the available GPUs is exactly equal From 1f13b288ae00f88a500424b09492b0340b0589f6 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 8 Jan 2026 21:39:54 +0000 Subject: [PATCH 4/6] Code and comments clean up Signed-off-by: Kshitij Lakhani --- tests/jax/distributed_test_base.py | 7 +- tests/jax/test_distributed_fused_attn.py | 2 +- tests/jax/test_fused_attn.py | 198 +++-------------------- 3 files changed, 27 insertions(+), 180 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index aea03decc..7b599d593 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -49,9 +49,10 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - # Run only those dp,cp,tp combinations which require exactly ndev GPUs. - # For e.g., if ndev=8 and num GPUs is 8, thent hose combinations will be picked. - # However, if ndev=4, but num GPUs is 8, then those combinations will not be picked. To pick such a combination, one can set CUDA_VISIBLE_DEVICES=0,1,2,3. + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # For e.g., if num_GPUs is 8 and ndev=8 , all the dp,cp,tp combinations fulfilling ndev = cp * tp * dp are picked. + # However, if num_GPUs is 8 and ndev=4, then all the dp,cp,tp combinations fulfilling ndev = cp * tp * dp are ignored. + # To explicitly pick combinations associated with ndev=4, one can set CUDA_VISIBLE_DEVICES=0,1,2,3, thereby forcing num_GPUs to 4 instead of 8. if is_devices_equal(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) if cp != 1: diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d0018543d..eef5d4ec4 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -334,7 +334,7 @@ def test_cross_attn( class TestDistributedContextParallelSelfAttn: - + #TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests def impl_test_context_parallel_attn( self, device_count, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 6fb40a499..7d15580a5 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1070,182 +1070,28 @@ def check_dqkv(primitive, reference, pad, idx): @pytest.mark.parametrize( "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ - # large + QKV_PACKED and RAGGED_QKV_PACKED? - pytest.param( - 2, - 2048, - 2048, - 12, - 12, - 64, - 64, - jnp.bfloat16, - QKVLayout.BS3HD, - id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED", - ), - pytest.param( - 2, - 2048, - 2048, - 12, - 12, - 64, - 64, - jnp.bfloat16, - QKVLayout.T3HD, - id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED", - ), - # mid + cross + KV_PACKED and RAGGED_KV_PACKED - pytest.param( - 2, - 512, - 1024, - 12, - 12, - 64, - 64, - jnp.bfloat16, - QKVLayout.BSHD_BS2HD, - id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED", - ), - pytest.param( - 2, - 512, - 1024, - 12, - 12, - 64, - 64, - jnp.bfloat16, - QKVLayout.THD_T2HD, - id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", - ), - # large + cross + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? - pytest.param( - 2, - 2048, - 1024, - 12, - 12, - 64, - 32, - jnp.bfloat16, - QKVLayout.BSHD_BSHD_BSHD, - id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE", - ), - pytest.param( - 2, - 2048, - 1024, - 12, - 12, - 64, - 32, - jnp.bfloat16, - QKVLayout.THD_THD_THD, - id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE", - ), - # large + gqa + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param( - 2, - 2048, - 2048, - 12, - 6, - 64, - 64, - jnp.bfloat16, - QKVLayout.BSHD_BS2HD, - id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED", - ), - pytest.param( - 2, - 2048, - 2048, - 12, - 6, - 64, - 64, - jnp.bfloat16, - QKVLayout.THD_T2HD, - id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED", - ), - # small + fp16 + diff hidden v dim + QKV_PACKED and RAGGED_QKV_PACKED ? - pytest.param( - 4, - 128, - 128, - 16, - 16, - 64, - 32, - jnp.float16, - QKVLayout.BS3HD, - id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED", - ), - pytest.param( - 4, - 128, - 128, - 16, - 16, - 64, - 32, - jnp.float16, - QKVLayout.T3HD, - id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED", - ), - # small + fp16 + KV_PACKED and RAGGED_KV_PACKED ? - pytest.param( - 4, - 128, - 128, - 16, - 16, - 64, - 64, - jnp.float16, - QKVLayout.BSHD_BS2HD, - id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED", - ), - pytest.param( - 4, - 128, - 128, - 16, - 16, - 64, - 64, - jnp.float16, - QKVLayout.THD_T2HD, - id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", - ), - # large + fp16 + gqa + diff hidden v dim + SEPARATE and RAGGED_SEPARATE ? - # TODO: Consider making this a CROSS case ? - pytest.param( - 2, - 2048, - 2048, - 12, - 6, - 128, - 64, - jnp.float16, - QKVLayout.BSHD_BSHD_BSHD, - id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", - ), - pytest.param( - 2, - 2048, - 2048, - 12, - 6, - 128, - 64, - jnp.float16, - QKVLayout.THD_THD_THD, - id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", - ), + # large data size + bf16 + qkv packed + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), + pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + # mid data size + bf16 + cross attn + kv packed + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), + pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + # large data size + bf16 + cross attn + diff hidden v dim + qkv separate + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), + pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + # large data size + bf16 + gqa + kv packed + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), + pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), + + # small data size + fp16 + diff hidden v dim + qkv packed + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + # small data size + fp16 + kv packed + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), + pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + # large data size + fp16 + gqa + diff hidden v dim + qkv separate + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), + pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), ], ) @pytest.mark.parametrize( From 36fbfaa09e41f11037d699059e3a2e0ffcb74947 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 21:43:33 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/distributed_test_base.py | 4 +- tests/jax/test_distributed_fused_attn.py | 2 +- tests/jax/test_fused_attn.py | 183 +++++++++++++++++++++-- 3 files changed, 171 insertions(+), 18 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 7b599d593..6d963f5c7 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -49,9 +49,9 @@ def generate_context_parallel_configs_for_attn(): TP_sizes = (1, 2) for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp - # Run only those dp,cp,tp combinations which require exactly ndev GPUs. + # Run only those dp,cp,tp combinations which require exactly ndev GPUs. # For e.g., if num_GPUs is 8 and ndev=8 , all the dp,cp,tp combinations fulfilling ndev = cp * tp * dp are picked. - # However, if num_GPUs is 8 and ndev=4, then all the dp,cp,tp combinations fulfilling ndev = cp * tp * dp are ignored. + # However, if num_GPUs is 8 and ndev=4, then all the dp,cp,tp combinations fulfilling ndev = cp * tp * dp are ignored. # To explicitly pick combinations associated with ndev=4, one can set CUDA_VISIBLE_DEVICES=0,1,2,3, thereby forcing num_GPUs to 4 instead of 8. if is_devices_equal(ndev): # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index eef5d4ec4..d5ebe9f26 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -334,7 +334,7 @@ def test_cross_attn( class TestDistributedContextParallelSelfAttn: - #TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests + # TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests def impl_test_context_parallel_attn( self, device_count, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 7d15580a5..572198c20 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1071,27 +1071,180 @@ def check_dqkv(primitive, reference, pad, idx): "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout", [ # large data size + bf16 + qkv packed - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BS3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"), - pytest.param(2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.T3HD, id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BS3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.T3HD, + id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED", + ), # mid data size + bf16 + cross attn + kv packed - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"), - pytest.param(2, 512, 1024, 12, 12, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED", + ), + pytest.param( + 2, + 512, + 1024, + 12, + 12, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED", + ), # large data size + bf16 + cross attn + diff hidden v dim + qkv separate - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"), - pytest.param(2, 2048, 1024, 12, 12, 64, 32, jnp.bfloat16, QKVLayout.THD_THD_THD, id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE", + ), + pytest.param( + 2, + 2048, + 1024, + 12, + 12, + 64, + 32, + jnp.bfloat16, + QKVLayout.THD_THD_THD, + id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE", + ), # large data size + bf16 + gqa + kv packed - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.BSHD_BS2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"), - pytest.param(2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, QKVLayout.THD_T2HD, id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"), - + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.BSHD_BS2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 64, + 64, + jnp.bfloat16, + QKVLayout.THD_T2HD, + id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED", + ), # small data size + fp16 + diff hidden v dim + qkv packed - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.BS3HD, id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 32, jnp.float16, QKVLayout.T3HD, id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.BS3HD, + id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 32, + jnp.float16, + QKVLayout.T3HD, + id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED", + ), # small data size + fp16 + kv packed - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.BSHD_BS2HD, id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"), - pytest.param(4, 128, 128, 16, 16, 64, 64, jnp.float16, QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.BSHD_BS2HD, + id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED", + ), + pytest.param( + 4, + 128, + 128, + 16, + 16, + 64, + 64, + jnp.float16, + QKVLayout.THD_T2HD, + id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", + ), # large data size + fp16 + gqa + diff hidden v dim + qkv separate - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE"), - pytest.param(2, 2048, 2048, 12, 6, 128, 64, jnp.float16, QKVLayout.THD_THD_THD, id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE"), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.BSHD_BSHD_BSHD, + id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", + ), + pytest.param( + 2, + 2048, + 2048, + 12, + 6, + 128, + 64, + jnp.float16, + QKVLayout.THD_THD_THD, + id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", + ), ], ) @pytest.mark.parametrize( From 4dbd60062c092c9b5732c9e12f210844cf8b183f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Thu, 8 Jan 2026 21:46:54 +0000 Subject: [PATCH 6/6] Make FP16 + GQA test cross attn instead of self attn to generalize the test Signed-off-by: Kshitij Lakhani --- tests/jax/test_fused_attn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 572198c20..a0aee5043 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1220,10 +1220,10 @@ def check_dqkv(primitive, reference, pad, idx): QKVLayout.THD_T2HD, id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED", ), - # large data size + fp16 + gqa + diff hidden v dim + qkv separate + # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate pytest.param( 2, - 2048, + 1024, 2048, 12, 6, @@ -1231,11 +1231,11 @@ def check_dqkv(primitive, reference, pad, idx): 64, jnp.float16, QKVLayout.BSHD_BSHD_BSHD, - id="2-2048-2048-12-6-128-64-FP16-GQA-SEPARATE", + id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE", ), pytest.param( 2, - 2048, + 1024, 2048, 12, 6, @@ -1243,7 +1243,7 @@ def check_dqkv(primitive, reference, pad, idx): 64, jnp.float16, QKVLayout.THD_THD_THD, - id="2-2048-2048-12-6-128-64-FP16-GQA-RAGGED_SEPARATE", + id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE", ), ], )