Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2d2e33a
Fix incorrect calculation of segment pos from segment ids for thd cas…
KshitijLakhani Dec 16, 2025
65e6b4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
9857577
Correct the assert condition
KshitijLakhani Dec 17, 2025
b20ac22
Modify fused attn tests to pass new args to from_segment_ids_and_pos()
KshitijLakhani Dec 17, 2025
03398a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
0a47eb6
Calculate seg ids before pos
KshitijLakhani Dec 17, 2025
217ea58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
ca9d3bc
1. Change the signature for from_segment_ids_and_pos()
Dec 23, 2025
0ee40a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
ceec1ea
Pass keyword-only args by name
Dec 23, 2025
ab00bb0
nit: Fix typo to use seg_ids instead of segment_ids
Dec 23, 2025
059c48d
nit: Fix comments
Dec 23, 2025
d524ad6
Modify the function call to differentiate between load balancing and …
Dec 23, 2025
d419f98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
3efa504
Fix the is_segment_ids_reordered to be set only when CP and load bala…
KshitijLakhani Dec 24, 2025
e65062c
Fix comments for from_segment_ids_and_pos()
KshitijLakhani Dec 24, 2025
74a352e
Code clean up
pre-commit-ci[bot] Dec 24, 2025
e5381fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 24, 2025
879afb4
Merge branch 'main' into klakhani/fix/incorrect-sequence-descr-from-s…
ksivaman Dec 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,14 +668,24 @@ def generate_random_segment_ids(
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
# Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
# if no CP and load balancing, else explicitly pass the segment_pos
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
)
if self.cp_size > 1 and self.cp_load_balanced
else None
),
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=(
True if self.cp_size > 1 and self.cp_load_balanced else False
),
)
case _:
Expand Down Expand Up @@ -704,6 +714,8 @@ def generate_random_segment_ids(
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=False,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
Expand Down
85 changes: 76 additions & 9 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class SequenceDescriptor:
- SequenceDescriptor.from_seqlens_and_offsets
For THD (packed) cases, where each batch may have not only 1 sequence.
- SequenceDescriptor.from_segment_ids_and_pos
Experimental feature for THD (packed) cases with context parallelism.
Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering
"""

seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
Expand Down Expand Up @@ -796,9 +796,14 @@ def from_segment_ids_and_pos(
cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_segment_ids_reordered: bool,
) -> SequenceDescriptor:
"""
Experimental factory method for inputs with segment IDs and optional positions. (THD)
Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None to be used only for: BSHD with or without load balancing and,
THD without load balancing
Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray):
Expand All @@ -812,22 +817,84 @@ def from_segment_ids_and_pos(
The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
Only THD with load balancing is expected to have this flag set to True
Return:
A SequenceDescriptor with segment_ids/segment_pos initialized.
"""
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)

if segment_pos is not None:
segment_pos = cls._expand_to_pair(segment_pos)
else:

def generate_default_pos(segment_ids):
seqlen = segment_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)
# Using defaults : segment pos has to be generated.
if segment_pos is None:
# THD + load balanced segment_ids are not supported in this function
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
if is_segment_ids_reordered:
assert not is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered"
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
" using convenience function reorder_causal_load_balancing()"
)
assert is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered (Dual"
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
" balanced reordered. The reordering for these is performed within the"
" primitive"
)

# Generate the default pos for THD and BSHD non-reordered segment_ids
def generate_default_pos(seg_ids):
if is_thd:
batch_size, seq_size = seg_ids.shape
# Assume that the first token belongs to a segment and is not a padded token
first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
# Get segment start positions
segment_start = jnp.concatenate(
[
first_is_segment,
(seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
],
axis=-1,
)
# Get offset for location where new segment starts
segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(
segment_start
)
segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx)

# Get the last non-zero index - after this everything is padding
# (B,)
last_nonzero_idx = jax.vmap(
lambda segids_row: jnp.max(
jnp.where(segids_row != 0, jnp.arange(seq_size), -1)
)
)(seg_ids)
seg_pos_no_thd = jnp.arange(seq_size)
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
mask = seg_pos_no_thd <= last_nonzero_idx[:, None]

# Get the unmasked seg_pos for the THD sequence
seg_pos = (
jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape)
- segment_start_offsets
)

# Use the mask to zero out the padding at the end (after the non-zero index)
segment_pos = jax.vmap(
lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0)
)(seg_pos, mask)
return segment_pos

seqlen = seg_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)

q_seg_pos = generate_default_pos(q_seg_ids)
kv_seg_pos = generate_default_pos(kv_seg_ids)
segment_pos = (q_seg_pos, kv_seg_pos)
# Explicitly passed segment_pos
else:
segment_pos = cls._expand_to_pair(segment_pos)

return cls(
segment_ids=(q_seg_ids, kv_seg_ids),
Expand Down
Loading