From 8ac79ba3f6d34b867abafea14b84b5e06a86ad06 Mon Sep 17 00:00:00 2001 From: Yeongjae Jang <54150647+Liberatedwinner@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:33:37 +0900 Subject: [PATCH 1/3] [Relax] Refactor position embedding calculations Refactor frequency calculations for RoPE (YaRN) scaling and correct range finding --- .../relax/frontend/nn/llm/position_embedding.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 35eeb4f5f32f..36531c596fcf 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -197,8 +197,8 @@ def yarn_find_correction_range( max_position_embeddings: int, ): """Find the correction range based on the number of rotations""" - low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)) - high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)) + low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)) + high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)) return tir.max(low, 0), tir.min(high, d - 1) @@ -214,16 +214,15 @@ def rope_freq_yarn( beta_slow: int, ): # pylint: disable=too-many-arguments, too-many-locals """Compute the inverse frequency of RoPE for yarn RoPE scaling.""" - freq_extra = tir.const(1, "float32") / tir.power( - theta, d * 2 % d_range / tir.const(d_range, "float32") - ) - freq_inter = tir.const(1, "float32") / tir.power( - scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32") - ) + exponent = d * 2 % d_range / tir.const(d_range, "float32") + freq_power = tir.power(theta, exponent) + + freq_extra = tir.const(1, "float32") / freq_power + freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power) low, high = yarn_find_correction_range( - beta_fast, beta_slow, d, theta, original_max_position_embeddings + beta_fast, beta_slow, d_range, theta, original_max_position_embeddings ) high = tir.if_then_else(low == high, high + 0.001, high) inv_freq_mask = tir.const(1, "float32") - tir.max( From 8b07e0f255526a9f41ce1dbabb6703636ffa284d Mon Sep 17 00:00:00 2001 From: Yeongjae Jang <54150647+Liberatedwinner@users.noreply.github.com> Date: Wed, 10 Dec 2025 20:11:00 +0900 Subject: [PATCH 2/3] Update python/tvm/relax/frontend/nn/llm/position_embedding.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/relax/frontend/nn/llm/position_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 36531c596fcf..40cb181ca851 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -197,8 +197,8 @@ def yarn_find_correction_range( max_position_embeddings: int, ): """Find the correction range based on the number of rotations""" - low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)) - high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)) + low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings) + high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings) return tir.max(low, 0), tir.min(high, d - 1) From 8109546d663787ebe5be11c0eb4da2d0d00b11a1 Mon Sep 17 00:00:00 2001 From: Yeongjae Jang <54150647+Liberatedwinner@users.noreply.github.com> Date: Wed, 31 Dec 2025 10:17:09 +0900 Subject: [PATCH 3/3] [chore] remove unnecessary line; retrigger CI Remove unnecessary line in inverse frequency computation. --- python/tvm/relax/frontend/nn/llm/position_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 40cb181ca851..60808a6b35fd 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -217,7 +217,6 @@ def rope_freq_yarn( exponent = d * 2 % d_range / tir.const(d_range, "float32") freq_power = tir.power(theta, exponent) - freq_extra = tir.const(1, "float32") / freq_power freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)