-
Notifications
You must be signed in to change notification settings - Fork 456
Update tflops calc for qwen3 next & GatedDeltaNet #2999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added a unit test in flop_calculation_test.py. Also will run the sanity checks and update the pr description with results. |
parambole
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a comment regarding the attention calculation. PTAL.
parambole
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
Description
PR Title: Enable TFLOPs Calculation for Qwen3-Next
Description
This PR implements the TFLOPs calculation logic for the Qwen3-Next architecture. Qwen3-Next utilizes a hybrid design containing both standard full attention layers and linear attention (Gated Delta Net) layers, alongside a Mixture-of-Experts (MoE) structure. This update ensures training TFLOPs are accurately reported for this model family.
Key Changes
1. TFLOPs Calculation Logic (
src/MaxText/maxtext_utils.py)calculate_gated_delta_net_flops_per_device: Implemented logic to calculate FLOPs for the Linear Attention layers, breaking down operations into:calculate_tflops_training_per_device:DecoderBlockType.QWEN3_NEXT.inhomogeneous_layer_cycle_interval.calculate_routed_and_shared_ffn_tflops_per_deviceandget_dense_moe_layersexplicitly supportQWEN3_NEXT.2. Unit Tests (
tests/unit/flop_calculation_test.py)test_qwen3_next_flops: A new unit test that verifies the implementation against a "golden" manual calculation for the 80B model configuration.compute_qwen3_next_attention_flops_per_device: A helper function to compute the expected attention-specific FLOPs for the test assertion.3. Config Updates (
src/MaxText/configs/models/qwen3-next-80b-a3b.yml)shared_experts: 1to the model configuration to ensure correct parameter counting in the FFN FLOPs calculation.If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/477291633
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.