Add Packing Support for Context Parallelism (Ring Attention) #2906
base: main
Conversation
richjames0 left a comment:
a couple of nits but lgtm
src/MaxText/layers/attention_op.py (Outdated)

    # Handle packing configurations
    if self.config.packing and self.config.dataset_type != "synthetic":
      if using_context_parallelism and not using_load_balanced_ring_cp:
        raise AssertionError("Packing is only supported for load balanced ring attention with context parallelism.")
Nit: AssertionError feels weird to me here. Maybe an ArgumentError?
converted to ValueError
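For concreteness, here is a minimal sketch of the check after that change, written as a standalone helper (the helper name and free-function form are hypothetical; in the PR the check lives directly in attention_op.py and reads its inputs from self.config):

```python
def check_packing_with_context_parallelism(packing, dataset_type, using_context_parallelism, using_load_balanced_ring_cp):
  """Hypothetical helper mirroring the attention_op.py check, with ValueError instead of AssertionError."""
  if packing and dataset_type != "synthetic":
    if using_context_parallelism and not using_load_balanced_ring_cp:
      raise ValueError("Packing is only supported for load balanced ring attention with context parallelism.")
```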
src/MaxText/maxtext_utils.py (Outdated)

    -def get_reorder_callable(cp_size, shard_mode):
    +def get_reorder_callable(cp_size, shard_mode, reorder_strategy=0):  # 0=DualChunkSwap, 1=Striped
As I read this late at night I imagine you're using an integer here so it's comprehensible by JAX, but could this be made into an enum without breaking things (at worst using .value)?
changed to enum @richjames0
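A sketch of what such an enum could look like, based on the member names in the commit message below and the string options in base.yml. The string values and the per-member comments are assumptions for illustration; the real definition lives in common_types.py:

```python
import enum


class ReorderStrategy(enum.Enum):
  """Load-balancing reorder strategies for context parallelism (illustrative sketch)."""

  AUTO = "auto"  # let MaxText pick a strategy based on other configs (packing, CP type, ...)
  DUAL_CHUNK_SWAP = "dual_chunk_swap"  # pair first/last sequence chunks on each CP shard
  STRIPED = "striped"  # interleave tokens/chunks round-robin across CP shards


# Using string values keeps the base.yml options and the enum members in sync,
# e.g. ReorderStrategy("striped") is ReorderStrategy.STRIPED.
```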
src/MaxText/maxtext_utils.py (Outdated)

    -def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode):
    +def shard_reorder_causal_load_balanced(batch, cp_size, shard_mode, reorder_strategy=0):
Can we make reorder_strategy configurable via base.yml?
done, ptal @gobbleturk
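A rough sketch of how the new config option could be threaded into the reorder callable, building on the signatures shown in the diffs above and the ReorderStrategy enum sketched earlier; the exact wiring through pyconfig and train_utils.py is not reproduced here:

```python
import functools


def get_reorder_callable(cp_size, shard_mode, reorder_strategy=ReorderStrategy.DUAL_CHUNK_SWAP):
  """Return a callable that applies load-balanced reordering to a batch (sketch)."""
  return functools.partial(
      shard_reorder_causal_load_balanced,
      cp_size=cp_size,
      shard_mode=shard_mode,
      reorder_strategy=reorder_strategy,
  )


# Hypothetical call site, mirroring the train_utils.py snippet further down in this thread:
# data_iterator = map(
#     get_reorder_callable(context_parallel_size, config.shard_mode, config.context_parallel_reorder_strategy),
#     data_iterator,
# )
```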
Comment: terminology/naming is hard; folks have been using the name "striped" to refer to DUAL_CHUNK_SWAP. I guess the string "striped" has to be passed to Transformer Engine for the other strategy (I would prefer the name "interleaved")...
In any case, I highly appreciate your comment with examples of the two strategies that clearly show what they mean in our codebase!
Hi @gobbleturk, so do you want to change the MaxText strategy name to interleaved?
…llelism
- Add ReorderStrategy enum to common_types.py (AUTO, DUAL_CHUNK_SWAP, STRIPED)
- Add context_parallel_reorder_strategy config option
- Update pyconfig, types.py, and train_utils.py to use enum
- Map MaxText enum to TE ReorderStrategy in max_utils.py
gobbleturk left a comment:
Thanks for the tests and great comments illustrating the two reorder strategies!
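For readers who don't want to dig into the in-code comments in max_utils.py, here is a toy, standalone illustration of how the two names are commonly understood (this reflects the intuition discussed in the thread, not a copy of the MaxText or Transformer Engine implementation): dual chunk swap splits the sequence into 2*cp_size contiguous chunks and pairs chunk i with chunk 2*cp_size-1-i on shard i, while striped interleaves tokens round-robin across shards. Both layouts aim to give every shard a similar mix of early and late positions, so causal masking does not leave later shards with disproportionately more attention work.

```python
import numpy as np


def dual_chunk_swap_order(seq_len, cp_size):
  """Token order after dual-chunk-swap reordering: shard i holds chunks (i, 2*cp_size-1-i)."""
  chunks = np.arange(seq_len).reshape(2 * cp_size, -1)  # 2*cp_size contiguous chunks
  order = []
  for i in range(cp_size):
    order.append(chunks[i])
    order.append(chunks[2 * cp_size - 1 - i])
  return np.concatenate(order)


def striped_order(seq_len, cp_size):
  """Token order after striped (interleaved) reordering: token t goes to shard t % cp_size."""
  tokens = np.arange(seq_len)
  return np.concatenate([tokens[r::cp_size] for r in range(cp_size)])


# Example with 8 tokens and cp_size=2:
#   dual_chunk_swap_order(8, 2) -> [0 1 6 7 2 3 4 5]  (shard 0: tokens 0,1,6,7; shard 1: tokens 2,3,4,5)
#   striped_order(8, 2)         -> [0 2 4 6 1 3 5 7]  (shard 0: even tokens;   shard 1: odd tokens)
```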
RissyRan left a comment:
LGTM, just minor comments.
    ### Determine if we want to use load balance for context parallelism
    context_parallel_load_balance: True
    context_parallel_strategy: "all_gather" # "all_gather" or "ring"
    context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped"
Thanks for adding this strategy! Could you add a short explanation here for each option, based on reorder_causal_load_balanced?
Hi @RissyRan, I was OOO and could not follow up. The explanation for each of the reorder strategies, with some examples, is already in max_utils.py. Are you asking to include it in base.yml as well?
Yes, could you add a short explanation here? Users usually refer to the base config instead of checking the source code; a concise explanation would be great!
    data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)

    # Determine load balancing reorder strategy based on whether packing is enabled
    if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
Wondering if AUTO usually gives the best result? Or will there be a compatibility issue if a user selects ReorderStrategy.STRIPED without packing?
Trying to understand whether providing just these two strategies is good enough.
Hi @RissyRan, you can think of AUTO as a heuristic for users who may not want to deal with the intricacies of the different reordering and CP strategy combinations to begin with. We plan to add more support for reordering and packing with other CP strategies in the future, which is why we're proposing this. If users have high-level confidence and understanding, they can choose a strategy themselves; otherwise AUTO will try to pick one based on the CP type, packing, and other configs. Let me know if that makes sense.
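To make the AUTO behaviour concrete, here is a minimal sketch of the kind of resolution described above, reusing the ReorderStrategy sketch from earlier. The diff only says the choice is made "based on whether packing is enabled", so which concrete strategy each branch picks here is an assumption for illustration, not the actual rule in train_utils.py:

```python
def resolve_reorder_strategy(configured, packing_enabled):
  """Resolve ReorderStrategy.AUTO to a concrete strategy (illustrative sketch only)."""
  if configured != ReorderStrategy.AUTO:
    return configured  # the user chose explicitly; respect it
  # Assumption for illustration: packed sequences use STRIPED, otherwise DUAL_CHUNK_SWAP.
  return ReorderStrategy.STRIPED if packing_enabled else ReorderStrategy.DUAL_CHUNK_SWAP
```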
Description
Enables sequence packing for context parallelism with the ring strategy, using TransformerEngine's DotProductAttention and the reorder_causal_load_balancing API. Includes comprehensive GPU tests for ring attention with packing for sm90+.
Tests
Added a GPU integration test that works for sm90+.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.