
Conversation

@RissyRan (Collaborator) commented Jan 21, 2026

Description

Add FLOPs calculation for DeepSeek v3.2; this PR depends on this change.

  • Add an indexer FLOPs helper function
  • Add an option to include the indexer FLOPs inside the MLA calculation (see the sketch after this list)
  • Add a unit test that roughly estimates TFLOPs
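
A minimal sketch of the combination option, assuming placeholder names (only the use_sparse_indexer toggle appears in the test output below; the function and argument names here are illustrative, not this PR's actual identifiers):

# Illustrative sketch only: placeholder names, not this PR's actual code.
# The option adds the indexer FLOPs on top of the MLA attention FLOPs when
# the sparse indexer is enabled; with it disabled, the estimate reduces to
# the DeepSeek v3-style MLA number.
def combined_attention_tflops(mla_tflops: float, indexer_tflops: float,
                              use_sparse_indexer: bool) -> float:
    return mla_tflops + (indexer_tflops if use_sparse_indexer else 0.0)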

Tests

  • Tested DS v2-16b, no impact
# Before the change
Per train step:
 Total TFLOPs: 593.27 
 split as 81.24% learnable weight flops and 18.76% attention flops
before the change: 593.272422531072

# After the change
Per train step:
 Total TFLOPs: 593.27 
 split as 81.24% learnable weight flops and 18.76% attention flops
after change: 593.272422531072
  • Tested DS v3.2 with the indexer (the difference below is expected)
# Enable this feature
Per train step:
 Total TFLOPs: 4162.71 
 split as 88.50% learnable weight flops and 11.50% attention flops
enable use_sparse_indexer: 4162.71

# Disable this feature like v3
Per train step:
 Total TFLOPs: 4103.37 
 split as 87.74% learnable weight flops and 12.26% attention flops
disable use_sparse_indexer: 4103.370952409088

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 21, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@shuningjin (Collaborator) left a comment

Thanks for the careful calculation! Here is a summary of my understanding and questions; I would appreciate your confirmation and clarification (in-line replies are fine).

Let B=batch_size, T=max_target_length, K=index_topk

Indexer Flops

b=B, s=t=T, h=index_n_heads, d=index_head_dim

# indexer linear flop
I_1 = q_proj + k_proj + head_proj
# Q = RoPE(Wq @ q_lora),  [b, t, q_lora_rank] x [q_lora_rank, h * d] -> [b, t, h * d]
q_proj = 2 * B * T * index_n_heads * index_head_dim * q_lora_rank
# K = RoPE(Norm(Wk @ X)), [b, s, emb_dim] x [emb_dim, d] -> [b, s, d]
k_proj = 2 * B * T * index_head_dim * emb_dim
# Head_Weights = W_proj @ X, [b, t, emb_dim] x [emb_dim, h] -> [b, t, h]
head_proj = 2 * B * T * index_n_heads * emb_dim

# indexer quadratic flop (causal factor 0.5 applied)
I_2 = qk_product + index_score
# Logits = ReLU(Q @ K.T), "bthd, bsd -> btsh"
qk_product = B * T^2 * index_n_heads * index_head_dim
# Score = Sum_head(Logits * Head_Weights), "btsh, bth -> bts"
index_score = B * T^2 * index_n_heads

[Remark] Your code looks correct. Might be good to clarify in the comments, especially about causality.
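
For concreteness, a standalone sketch of the breakdown above (illustrative only, not the actual helper in this PR), assuming 2 FLOPs per multiply-add for the projections and the causal 0.5 already folded into the quadratic terms:

def indexer_flops(B, T, emb_dim, q_lora_rank, index_n_heads, index_head_dim):
    # Linear projections: cost scales linearly with T.
    q_proj = 2 * B * T * index_n_heads * index_head_dim * q_lora_rank
    k_proj = 2 * B * T * index_head_dim * emb_dim
    head_proj = 2 * B * T * index_n_heads * emb_dim
    # Quadratic terms: 2 FLOPs per multiply-add times the 0.5 causal factor.
    qk_product = B * T**2 * index_n_heads * index_head_dim
    index_score = B * T**2 * index_n_heads
    return q_proj + k_proj + head_proj + qk_product + index_score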

Regular causal attention: B * H * D * T^2, where H = num_query_heads and D = qk_head_dim_sum + v_head_dim (MLA specific)

Sparse causal attention ($T\le K$):

  • Current logic: regular_causal_attention_flop + indexer_flop
  • [Question 1] Shall we keep or skip indexer_flop? In our actual implementation, we skip the indexer computation in this case.

Sparse causal attention ($T > K$):

  • Current logic: 2 * B * H * D * (T * K) + indexer_flop.
  • [Question 2] This seems to be an approximation that overestimates the cost during the "warm-up" phase (the first K tokens). Shall we use the exact formula: 2 * B * H * D * (T * K - K^2 / 2) + indexer_flop? (A sketch of both variants follows Question 3.)

[Question 3] Should the Indexer cost be accounted inside MLA cost or outside?
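
To make Questions 1-2 concrete, a sketch of the attention-side term alone (illustrative, not this PR's code); the indexer cost would be added separately per the current logic:

def sparse_attention_flops(B, T, H, D, K, exact_warmup=False):
    # H = num_query_heads, D = qk_head_dim_sum + v_head_dim, K = index_topk.
    if T <= K:
        # Degenerates to regular causal attention (causal 0.5 folded in);
        # Question 1 asks whether the indexer term can be skipped here.
        return B * H * D * T**2
    if exact_warmup:
        # Exact count: the first K tokens attend causally (~K^2 / 2 pairs) and
        # the remaining T - K tokens attend to K selected tokens each:
        # (T - K) * K + K^2 / 2 = T * K - K^2 / 2  (Question 2).
        return 2 * B * H * D * (T * K - K**2 / 2)
    # Current approximation: every token is charged for K attended positions.
    return 2 * B * H * D * (T * K)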

@RissyRan force-pushed the flops_clean branch 3 times, most recently from dcbe7e6 to 86b5b57 on January 27, 2026 at 02:22
@RissyRan requested a review from parambole as a code owner on January 27, 2026 at 02:22
@RissyRan force-pushed the flops_clean branch 3 times, most recently from 67f8d7d to c8897a5 on January 27, 2026 at 06:33
@shuningjin (Collaborator) left a comment

Thanks for the clarification and visualization! LGTM.
