diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 35eeb4f5f32f..60808a6b35fd 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -197,8 +197,8 @@ def yarn_find_correction_range( max_position_embeddings: int, ): """Find the correction range based on the number of rotations""" - low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)) - high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)) + low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings) + high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings) return tir.max(low, 0), tir.min(high, d - 1) @@ -214,16 +214,14 @@ def rope_freq_yarn( beta_slow: int, ): # pylint: disable=too-many-arguments, too-many-locals """Compute the inverse frequency of RoPE for yarn RoPE scaling.""" - freq_extra = tir.const(1, "float32") / tir.power( - theta, d * 2 % d_range / tir.const(d_range, "float32") - ) - freq_inter = tir.const(1, "float32") / tir.power( - scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32") - ) + exponent = d * 2 % d_range / tir.const(d_range, "float32") + freq_power = tir.power(theta, exponent) + freq_extra = tir.const(1, "float32") / freq_power + freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power) low, high = yarn_find_correction_range( - beta_fast, beta_slow, d, theta, original_max_position_embeddings + beta_fast, beta_slow, d_range, theta, original_max_position_embeddings ) high = tir.if_then_else(low == high, high + 0.001, high) inv_freq_mask = tir.const(1, "float32") - tir.max(