
Conversation

@Rohan-Bierneni (Collaborator) commented Jan 22, 2026

Description

PR Title: Enable TFLOPs Calculation for Qwen3-Next


This PR implements the TFLOPs calculation logic for the Qwen3-Next architecture. Qwen3-Next utilizes a hybrid design containing both standard full attention layers and linear attention (Gated Delta Net) layers, alongside a Mixture-of-Experts (MoE) structure. This update ensures training TFLOPs are accurately reported for this model family.

Key Changes

1. TFLOPs Calculation Logic (src/MaxText/maxtext_utils.py)

  • Added calculate_gated_delta_net_flops_per_device: Implemented logic to calculate FLOPs for the Linear Attention layers, breaking down operations into:
    • Projections: QKVZ, BA, and Output projections.
    • Convolution: Depthwise convolutions on the key/value/gate states.
    • Core Attention: Intra-chunk and inter-chunk recurrent state operations.
  • Updated calculate_tflops_training_per_device:
    • Added a specific branch for DecoderBlockType.QWEN3_NEXT.
    • Logic now calculates the number of "full attention" layers vs. "linear attention" layers based on the inhomogeneous_layer_cycle_interval.
    • Combines FLOPs from embeddings, MoE FFNs (routed + shared), full causal attention, and linear attention.
  • Updated FFN Helpers: Verified that calculate_routed_and_shared_ffn_tflops_per_device and get_dense_moe_layers explicitly support QWEN3_NEXT.
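The layer-split arithmetic above can be sketched as follows. This is an illustrative sketch, not the actual maxtext_utils.py code: the function names, the assumption that each cycle of inhomogeneous_layer_cycle_interval layers contains exactly one full-attention layer, and the 3x forward-to-training multiplier are all assumptions here.

```python
def split_hybrid_layers(num_layers: int, cycle_interval: int) -> tuple[int, int]:
    """Split a hybrid stack into (full-attention, linear-attention) layer counts.

    Assumption: each cycle of `cycle_interval` layers contains exactly one
    full-attention layer; the remaining layers use linear attention
    (Gated Delta Net), as in the Qwen3-Next hybrid design.
    """
    num_full = num_layers // cycle_interval
    num_linear = num_layers - num_full
    return num_full, num_linear


def total_training_tflops(num_layers, cycle_interval, full_attn_flops,
                          linear_attn_flops, ffn_flops, embed_flops):
    """Combine per-layer forward FLOPs into total training TFLOPs (illustrative).

    Uses the common convention that a training step costs ~3x the forward
    pass (1x forward + 2x backward).
    """
    num_full, num_linear = split_hybrid_layers(num_layers, cycle_interval)
    forward_flops = (embed_flops
                     + num_layers * ffn_flops
                     + num_full * full_attn_flops
                     + num_linear * linear_attn_flops)
    return 3 * forward_flops / 1e12
```

For example, with 48 layers and a cycle interval of 4, this split yields 12 full-attention and 36 linear-attention layers.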

2. Unit Tests (tests/unit/flop_calculation_test.py)

  • Added test_qwen3_next_flops: A new unit test that verifies the implementation against a "golden" manual calculation for the 80B model configuration.
  • Added compute_qwen3_next_attention_flops_per_device: A helper function to compute the expected attention-specific FLOPs for the test assertion.
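A golden-value unit test of this shape typically looks like the sketch below. The helper and its FLOPs formula are hypothetical stand-ins for the real code in flop_calculation_test.py, with a tiny hand-checkable configuration in place of the 80B model:

```python
import unittest


def attention_flops(batch, seq_len, num_heads, head_dim):
    """Forward FLOPs of causal full attention (illustrative, not the real helper).

    Counts 2*N multiply-add FLOPs for the QK^T score matrix and the same
    again for the attention-weighted value sum; the 0.5 factor reflects
    causal masking roughly halving the work.
    """
    qk = 2 * batch * num_heads * seq_len * seq_len * head_dim
    av = 2 * batch * num_heads * seq_len * seq_len * head_dim
    return 0.5 * (qk + av)


class FlopCalculationTest(unittest.TestCase):
    def test_golden_attention_flops(self):
        # "Golden" value worked out by hand for a tiny configuration:
        # 0.5 * 2 * (2 * 2 * 4 * 128 * 128 * 64) = 16,777,216
        got = attention_flops(batch=2, seq_len=128, num_heads=4, head_dim=64)
        self.assertAlmostEqual(got, 16_777_216, delta=1)


if __name__ == "__main__":
    unittest.main()
```

The value the assertion compares against is computed manually from the model configuration, so the test catches regressions in the formula rather than just re-running it.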

3. Config Updates (src/MaxText/configs/models/qwen3-next-80b-a3b.yml)

  • Added shared_experts: 1 to the model configuration to ensure correct parameter counting in the FFN FLOPs calculation.

FIXES: b/477291633

Tests

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 22, 2026

Codecov Report

❌ Patch coverage is 97.05882% with 1 line in your changes missing coverage. Please review.

Files with missing lines        Patch %   Lines
src/MaxText/maxtext_utils.py    97.05%    1 Missing ⚠️


@Rohan-Bierneni Rohan-Bierneni self-assigned this Jan 22, 2026
@RissyRan (Collaborator) left a comment

Thanks for the update!

We have some unit tests for those TFLOPs; if you could come up with a test, that would be great, but it is not mandatory for this PR. Also, could you run a sanity check with another MoE model (to confirm no impact from this change)? Similar to #2979.

@Rohan-Bierneni (Collaborator, Author) commented

> Thanks for the update!
>
> We have some unit tests for those TFLOPs; if you could come up with a test, that would be great, but it is not mandatory for this PR. Also, could you run a sanity check with another MoE model (to confirm no impact from this change)? Similar to #2979.

I have added a unit test in flop_calculation_test.py. I will also run the sanity checks and update the PR description with the results.

@parambole (Collaborator) left a comment

Left a comment regarding the attention calculation. PTAL.

@parambole (Collaborator) left a comment

LGTM. Thank you!

