From 6c7f7617c1edf3d9f0ce4c640bd5b78a09db9603 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Mon, 11 Aug 2025 17:52:51 -0400 Subject: [PATCH 01/13] changes to support llama4; Note: cleanup needed --- python/tvm/relax/expr.py | 4 ++ .../frontend/nn/llm/position_embedding.py | 39 ++++++++++---- python/tvm/relax/frontend/nn/op.py | 51 +++++++++++++++++++ src/relax/analysis/struct_info_analysis.cc | 15 ++++-- src/relax/op/op.cc | 3 ++ tests/python/relax/test_frontend_nn_op.py | 7 ++- 6 files changed, 106 insertions(+), 13 deletions(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 1a7a5c224add..4467a56ead59 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,6 +22,10 @@ import numpy as _np # type: ignore import tvm_ffi +import ml_dtypes + +import tvm +import tvm.ffi import tvm.ir import tvm.relax from tvm import DataType diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 1a1659b29e18..cd75217eeba4 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -75,6 +75,7 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st return cos_freq, sin_freq, {freq_var: freq} + def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, d: tir.Var, @@ -91,14 +92,34 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals theta, d * 2 % d_range / tir.const(d_range, "float32") ) orig_freq_var = tir.Var("orig_freq", "float32") - inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + llama3_inv_scaling_factor = 1.0 / factor - llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor - llama3_beta = low_freq_factor * inv_diff_freq_factor - smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) - smoothed_freq = s * ( - (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var - ) + + if high_freq_factor == low_freq_factor: + # When factors are equal, use simple threshold-based scaling + # Check if wavelength > threshold using TIR conditional + wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var + threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") + + # Use tir.if_then_else for conditional logic + scaled_freq = tir.if_then_else( + wavelength > threshold_wavelen, + orig_freq_var * tir.const(llama3_inv_scaling_factor, "float32"), + orig_freq_var + ) + smoothed_freq = s * scaled_freq + + else: + # Original smooth interpolation logic + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama3_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + ) + smoothed_freq_var = tir.Var("smoothed_freq", "float32") cos_freq = tir.cos(smoothed_freq_var).astype(dtype) sin_freq = tir.sin(smoothed_freq_var).astype(dtype) @@ -444,14 +465,14 @@ def _rope( # pylint: disable=too-many-arguments expr = tir.Let(var, value, expr) return expr - @T.prim_func + @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: 
T.handle, - apply_rope: T.int32, + apply_rope: T.int64, ): T.func_attr( { diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 714ae9478250..93031bfae97b 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1173,6 +1173,54 @@ def exp(x: Tensor, name: str = "exp") -> Tensor: """ return wrap_nested(_op.exp(x._expr), name) +def log(x: Tensor, name: str = "log") -> Tensor: + r"""Applies the natural logarithm function. + + .. math:: + \text{Log}(x) = \log(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.log(x._expr), name) + +def floor(x: Tensor, name: str = "floor") -> Tensor: + r"""Computes the floor of the input tensor. + + .. math:: + \text{Floor}(x) = \floor(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.floor(x._expr), name) def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor: """Permutes the dimensions of the input tensor. @@ -2006,6 +2054,9 @@ def tensor_ir_op( if len(tir_vars) == 0: tir_vars = None + # if tir_vars: + # print(f"tir_vars {tir_vars} dtype: {[tir_var.dtype for tir_var in tir_vars]}") + return wrap_nested( bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)), name=name_hint, diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 3952b1ce4a6e..4587e0a590a0 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -321,7 +321,9 @@ class StructInfoBaseChecker BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()) { + return BaseCheckResult::kFailL1; + } return BaseCheckResult::kFailL0; } @@ -425,7 +427,11 @@ class StructInfoBaseChecker BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) return BaseCheckResult::kFailL1; + if (other.as()){ + LOG(INFO) << "1"; + return BaseCheckResult::kFailL1; + } + LOG(INFO) << "2"; return BaseCheckResult::kFailL0; } return ArrayCheck(lhs->fields, rhs->fields); @@ -586,7 +592,10 @@ class StructInfoBaseChecker for (size_t i = 0; i < lhs.size(); ++i) { auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); - if (ret == BaseCheckResult::kFailL0) return ret; + if (ret == BaseCheckResult::kFailL0){ + LOG(INFO) << "4"; + return ret; + } ret = CombineCheck(cmp_ret, ret); } return ret; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index e15d87472316..ec16840566a7 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -566,6 +566,9 @@ void ValidateCallTIR(Call call) { auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); if (inferred_sinfo.defined()) { + // if (!IsBaseOf(inferred_sinfo.value(), explicit_sinfo)){ + // LOG(INFO) << "inferred_sinfo" << inferred_sinfo.value() << "explicit_sinfo" << explicit_sinfo; + // } 
CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) << "TypeError: " << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index e827f643b33c..1d937017730d 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -398,6 +398,8 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor): rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1]) rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1]) group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias) + log_out = op.log(x) + floor_out = op.floor(x) return x @R.function @@ -409,6 +411,8 @@ def test( ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): + log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.log(x) + floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.floor(x) relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x) relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x) silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) @@ -1269,4 +1273,5 @@ def foo(x: R.Tensor(("seq_len", 64), dtype="float16")): if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_tensor_ir_op() From f18417dc8ff8e6756e2cda8aa19b387b905b9af4 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Mon, 8 Sep 2025 21:40:51 -0400 Subject: [PATCH 02/13] custom rope for llama4 --- .../frontend/nn/llm/position_embedding.py | 249 +++++++++++++++++- 1 file changed, 248 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index cd75217eeba4..8ad3832b6850 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -75,6 +75,59 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st return cos_freq, sin_freq, {freq_var: freq} +def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: float, +): + """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" + # # LLama3 impl + # orig_freq = tir.const(1, "float32") / tir.power( + # theta, d * 2 % d_range / tir.const(d_range, "float32") + # ) + # Modified impl + orig_freq = tir.const(1, "float32") / tir.power( + theta, 2 * (d // 2) / tir.const(d_range, "float32") + ) + orig_freq_var = tir.Var("orig_freq", "float32") + + llama3_inv_scaling_factor = 1.0 / factor + + if high_freq_factor == low_freq_factor: + # When factors are equal, use simple threshold-based scaling + # Check if wavelength > threshold using TIR conditional + wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var + threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") + + # Use tir.if_then_else for conditional logic + scaled_freq = tir.if_then_else( + wavelength > threshold_wavelen, + orig_freq_var / factor, #* tir.const(llama3_inv_scaling_factor, "float32"), + orig_freq_var + ) + smoothed_freq = s * scaled_freq + + else: + # Original smooth interpolation logic + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * 
inv_diff_freq_factor + llama3_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + ) + + smoothed_freq_var = tir.Var("smoothed_freq", "float32") + cos_freq = tir.cos(smoothed_freq_var).astype(dtype) + sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, @@ -104,7 +157,7 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals # Use tir.if_then_else for conditional logic scaled_freq = tir.if_then_else( wavelength > threshold_wavelen, - orig_freq_var * tir.const(llama3_inv_scaling_factor, "float32"), + orig_freq_var / factor, #* tir.const(llama3_inv_scaling_factor, "float32"), orig_freq_var ) smoothed_freq = s * scaled_freq @@ -229,6 +282,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: high_freq_factor=rope_scaling["high_freq_factor"], original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], ) + if rope_scaling["rope_type"] == "llama4": + return partial( + rope_freq_llama4, + factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) if rope_scaling["rope_type"] == "longrope": return partial( rope_freq_longrope, @@ -566,3 +627,189 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals if is_longrope_scaling: return fused_rope_longrope_scaling return fused_rope + + +def llama4_rope_with_position_map( # pylint: disable=too-many-arguments + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + rope_scaling: Dict[str, Any], + rotary_dim: Optional[int] = None, +): + """Return the TIR function that computes Llama-style RoPE with q position map. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + rope_scaling : Dict + The configuration of RoPE scaling. + + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. 
+ """ + fused_heads = num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + scale = tir.const(scale, "float32") + is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + pos: tir.Var, + ext_factors: Optional[T.Buffer] = None, + ): + kwargs = {} + if ext_factors: + kwargs["ext_factors"] = ext_factors + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + pos * scale, d, rotary_dim, theta, "float32", **kwargs + ) + cos = cos_freq * x[s, h, d].astype("float32") + if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": + sin = sin_freq * 0.0 + # sin = sin_freq * tir.if_then_else( + # d % 2 == 0, + # -x[s, h, d + 1], + # x[s, h, d - 1], + # ).astype("float32") + else: + # Data layout is different for llama4 vs llama3 + sin = sin_freq * 0.0 + # sin = sin_freq * tir.if_then_else( + # # d < rotary_dim // 2, + # # -x[s, h, d + rotary_dim // 2], + # # x[s, h, d - rotary_dim // 2], + # d % 2 == 0, + # -x[s, h, d + 1], + # x[s, h, d - 1], + # ).astype("float32") + expr = (cos + sin).astype(dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr + + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + apply_rope: T.int64, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int32() + position_map_elem_offset = T.int32() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + @T.prim_func + def fused_rope_longrope_scaling( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int64() + position_map_elem_offset = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < 
num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + if is_longrope_scaling: + return fused_rope_longrope_scaling + return fused_rope From 12c44911fb3ee30c74436b7f23118482ad767f3d Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Thu, 11 Sep 2025 22:55:23 -0400 Subject: [PATCH 03/13] custom rope for llama4 and ops --- .../frontend/nn/llm/position_embedding.py | 29 +++++++++---------- python/tvm/relax/frontend/nn/op.py | 28 ++++++++++++++++++ tests/python/relax/test_frontend_nn_op.py | 3 ++ 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 8ad3832b6850..7f914a98a220 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -690,23 +690,22 @@ def _rope( # pylint: disable=too-many-arguments ) cos = cos_freq * x[s, h, d].astype("float32") if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": - sin = sin_freq * 0.0 - # sin = sin_freq * tir.if_then_else( - # d % 2 == 0, - # -x[s, h, d + 1], - # x[s, h, d - 1], - # ).astype("float32") + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") else: # Data layout is different for llama4 vs llama3 - sin = sin_freq * 0.0 - # sin = sin_freq * tir.if_then_else( - # # d < rotary_dim // 2, - # # -x[s, h, d + rotary_dim // 2], - # # x[s, h, d - rotary_dim // 2], - # d % 2 == 0, - # -x[s, h, d + 1], - # x[s, h, d - 1], - # ).astype("float32") + # sin = sin_freq * 0.0 + sin = sin_freq * tir.if_then_else( + # d < rotary_dim // 2, + # -x[s, h, d + rotary_dim // 2], + # x[s, h, d - rotary_dim // 2], + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") expr = (cos + sin).astype(dtype) for var, value in var_map.items(): expr = tir.Let(var, value, expr) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 93031bfae97b..3ffbc8110440 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1222,6 +1222,34 @@ def floor(x: Tensor, name: str = "floor") -> Tensor: """ return wrap_nested(_op.floor(x._expr), name) +def arange(start: int, end: Optional[int] = None, step: int = 1, dtype: Optional[str] = "float32", name: str = "arange") -> Tensor: + r"""Construct a tensor with evenly spaced elements. + + Parameters + ---------- + start : int + The start of the interval. + + end : Optional[int] + The end of the interval. If not given, it will be set to start, + and start will be set to 0. + + step : int + The step size. + + dtype : Optional[str] + The data type of the created tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.arange(start, end, step, dtype), name) + def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor: """Permutes the dimensions of the input tensor. 
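For reference, a minimal NumPy sketch of the frequency scaling that the rope_freq_llama3 / rope_freq_llama4 helpers earlier in this series express in TIR (the rotation angle is then `position * freq`). The helper name and the NumPy formulation are illustrative only and are not part of the patch:

    import math
    import numpy as np

    def llama_scaled_rope_freqs(head_dim, theta, factor, low_freq_factor,
                                high_freq_factor, original_max_position_embeddings):
        # Hypothetical reference helper mirroring the TIR logic above in NumPy.
        d = np.arange(head_dim)
        # llama4 layout: adjacent dimension pairs share one frequency.
        orig_freq = 1.0 / theta ** (2 * (d // 2) / head_dim)
        if high_freq_factor == low_freq_factor:
            # Threshold-based scaling: only wavelengths longer than
            # original_max_position_embeddings / low_freq_factor are divided by `factor`.
            wavelength = 2 * math.pi / orig_freq
            threshold = original_max_position_embeddings / low_freq_factor
            return np.where(wavelength > threshold, orig_freq / factor, orig_freq)
        # Smooth interpolation between scaled and unscaled frequencies.
        inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
        alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff
        beta = low_freq_factor * inv_diff
        smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
        return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq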
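A short usage sketch for the frontend ops added in this series (op.log, op.floor, op.arange). The import path and Module pattern follow the existing nn frontend and the test changes below; the class and method names here are assumptions for illustration, not part of the patch:

    from tvm.relax.frontend.nn import Module, Tensor, op

    class LogFloorExample(Module):
        def forward(self, x: Tensor) -> Tensor:
            # Elementwise natural log followed by floor; x is expected to be float.
            return op.floor(op.log(x))

        def positions(self) -> Tensor:
            # 1-D range tensor [0, 1, ..., 9] with float32 dtype.
            return op.arange(0, 10, 1, dtype="float32")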
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 1d937017730d..035df84d453d 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -467,6 +467,8 @@ def test(self, x: Tensor): ) zeros_out = op.zeros([10, 10]) zeros_fp16_out = op.zeros([10, 10], dtype="float16") + + arange_out = op.arange(0, 10, 1, "float32") return x # fmt: off @@ -480,6 +482,7 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32") zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32") zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16") + arange: R.Tensor((10), dtype="float32") = R.arange(R.const(0, "int"), R.const(10, "int"), R.const(1, "int"), dtype="float32") gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,) R.output(gv1) return gv1 From c906f562b584577fcead85c257861b690c994e86 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Fri, 12 Sep 2025 21:57:53 -0400 Subject: [PATCH 04/13] cleanup rope dead code --- python/tvm/relax/frontend/nn/llm/position_embedding.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 7f914a98a220..ea64692f3679 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -697,11 +697,7 @@ def _rope( # pylint: disable=too-many-arguments ).astype("float32") else: # Data layout is different for llama4 vs llama3 - # sin = sin_freq * 0.0 - sin = sin_freq * tir.if_then_else( - # d < rotary_dim // 2, - # -x[s, h, d + rotary_dim // 2], - # x[s, h, d - rotary_dim // 2], + sin = sin_freq * tir.if_then_else( d % 2 == 0, -x[s, h, d + 1], x[s, h, d - 1], From c232fddff7c460db359e791e99ede9af7f877904 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Fri, 12 Sep 2025 22:36:54 -0400 Subject: [PATCH 05/13] remove debugging cases --- .../frontend/nn/llm/position_embedding.py | 46 ++++--------------- python/tvm/relax/frontend/nn/op.py | 2 - src/relax/analysis/struct_info_analysis.cc | 15 ++---- src/relax/op/op.cc | 3 -- tests/python/relax/test_frontend_nn_op.py | 3 +- 5 files changed, 13 insertions(+), 56 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index ea64692f3679..e09d60ecac15 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -87,11 +87,6 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals original_max_position_embeddings: float, ): """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" - # # LLama3 impl - # orig_freq = tir.const(1, "float32") / tir.power( - # theta, d * 2 % d_range / tir.const(d_range, "float32") - # ) - # Modified impl orig_freq = tir.const(1, "float32") / tir.power( theta, 2 * (d // 2) / tir.const(d_range, "float32") ) @@ -100,15 +95,12 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals llama3_inv_scaling_factor = 1.0 / factor if high_freq_factor == low_freq_factor: - # When factors are equal, use simple threshold-based scaling - # Check if wavelength > threshold using TIR conditional 
wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") - - # Use tir.if_then_else for conditional logic + scaled_freq = tir.if_then_else( wavelength > threshold_wavelen, - orig_freq_var / factor, #* tir.const(llama3_inv_scaling_factor, "float32"), + orig_freq_var / factor, orig_freq_var ) smoothed_freq = s * scaled_freq @@ -145,34 +137,14 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals theta, d * 2 % d_range / tir.const(d_range, "float32") ) orig_freq_var = tir.Var("orig_freq", "float32") - + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) llama3_inv_scaling_factor = 1.0 / factor - - if high_freq_factor == low_freq_factor: - # When factors are equal, use simple threshold-based scaling - # Check if wavelength > threshold using TIR conditional - wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var - threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") - - # Use tir.if_then_else for conditional logic - scaled_freq = tir.if_then_else( - wavelength > threshold_wavelen, - orig_freq_var / factor, #* tir.const(llama3_inv_scaling_factor, "float32"), - orig_freq_var - ) - smoothed_freq = s * scaled_freq - - else: - # Original smooth interpolation logic - inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) - - llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor - llama3_beta = low_freq_factor * inv_diff_freq_factor - smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) - smoothed_freq = s * ( - (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var - ) - + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama3_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + ) smoothed_freq_var = tir.Var("smoothed_freq", "float32") cos_freq = tir.cos(smoothed_freq_var).astype(dtype) sin_freq = tir.sin(smoothed_freq_var).astype(dtype) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 3ffbc8110440..853a85f512b5 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -2082,8 +2082,6 @@ def tensor_ir_op( if len(tir_vars) == 0: tir_vars = None - # if tir_vars: - # print(f"tir_vars {tir_vars} dtype: {[tir_var.dtype for tir_var in tir_vars]}") return wrap_nested( bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)), diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 4587e0a590a0..3952b1ce4a6e 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -321,9 +321,7 @@ class StructInfoBaseChecker BaseCheckResult VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()) { - return BaseCheckResult::kFailL1; - } + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } @@ -427,11 +425,7 @@ class StructInfoBaseChecker BaseCheckResult VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - if (other.as()){ - 
LOG(INFO) << "1"; - return BaseCheckResult::kFailL1; - } - LOG(INFO) << "2"; + if (other.as()) return BaseCheckResult::kFailL1; return BaseCheckResult::kFailL0; } return ArrayCheck(lhs->fields, rhs->fields); @@ -592,10 +586,7 @@ class StructInfoBaseChecker for (size_t i = 0; i < lhs.size(); ++i) { auto cmp_ret = this->VisitStructInfo(lhs[i], rhs[i]); - if (ret == BaseCheckResult::kFailL0){ - LOG(INFO) << "4"; - return ret; - } + if (ret == BaseCheckResult::kFailL0) return ret; ret = CombineCheck(cmp_ret, ret); } return ret; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ec16840566a7..e15d87472316 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -566,9 +566,6 @@ void ValidateCallTIR(Call call) { auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); if (inferred_sinfo.defined()) { - // if (!IsBaseOf(inferred_sinfo.value(), explicit_sinfo)){ - // LOG(INFO) << "inferred_sinfo" << inferred_sinfo.value() << "explicit_sinfo" << explicit_sinfo; - // } CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) << "TypeError: " << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 035df84d453d..9825a23294d0 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -1276,5 +1276,4 @@ def foo(x: R.Tensor(("seq_len", 64), dtype="float16")): if __name__ == "__main__": - # tvm.testing.main() - test_tensor_ir_op() + tvm.testing.main() From ec45ee69bb2a729664f51a60e4f06b93f687080e Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Sat, 13 Sep 2025 11:29:01 -0400 Subject: [PATCH 06/13] reformat --- .../tvm/relax/frontend/nn/llm/position_embedding.py | 13 ++++++------- python/tvm/relax/frontend/nn/op.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index e09d60ecac15..dfe2c83a3f4a 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -99,16 +99,14 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") scaled_freq = tir.if_then_else( - wavelength > threshold_wavelen, - orig_freq_var / factor, - orig_freq_var + wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var ) smoothed_freq = s * scaled_freq - + else: # Original smooth interpolation logic inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) - + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor llama3_beta = low_freq_factor * inv_diff_freq_factor smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) @@ -121,6 +119,7 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals sin_freq = tir.sin(smoothed_freq_var).astype(dtype) return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, d: tir.Var, @@ -668,8 +667,8 @@ def _rope( # pylint: disable=too-many-arguments x[s, h, d - 1], ).astype("float32") else: - # Data layout is different for llama4 vs llama3 - sin = sin_freq * tir.if_then_else( + # 
Data layout is different for llama4 vs llama3 + sin = sin_freq * tir.if_then_else( d % 2 == 0, -x[s, h, d + 1], x[s, h, d - 1], diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 853a85f512b5..50d4772d8ca1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1173,6 +1173,7 @@ def exp(x: Tensor, name: str = "exp") -> Tensor: """ return wrap_nested(_op.exp(x._expr), name) + def log(x: Tensor, name: str = "log") -> Tensor: r"""Applies the natural logarithm function. @@ -1197,6 +1198,7 @@ def log(x: Tensor, name: str = "log") -> Tensor: """ return wrap_nested(_op.log(x._expr), name) + def floor(x: Tensor, name: str = "floor") -> Tensor: r"""Computes the floor of the input tensor. @@ -1222,7 +1224,14 @@ def floor(x: Tensor, name: str = "floor") -> Tensor: """ return wrap_nested(_op.floor(x._expr), name) -def arange(start: int, end: Optional[int] = None, step: int = 1, dtype: Optional[str] = "float32", name: str = "arange") -> Tensor: + +def arange( + start: int, + end: Optional[int] = None, + step: int = 1, + dtype: Optional[str] = "float32", + name: str = "arange", +) -> Tensor: r"""Construct a tensor with evenly spaced elements. Parameters @@ -1250,6 +1259,7 @@ def arange(start: int, end: Optional[int] = None, step: int = 1, dtype: Optional """ return wrap_nested(_op.arange(start, end, step, dtype), name) + def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor: """Permutes the dimensions of the input tensor. @@ -2082,7 +2092,6 @@ def tensor_ir_op( if len(tir_vars) == 0: tir_vars = None - return wrap_nested( bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)), name=name_hint, From 7c18090877c84f82dc6e75616c39b54175f9c73e Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Wed, 17 Sep 2025 14:29:04 -0400 Subject: [PATCH 07/13] moved ml_dtypes import --- python/tvm/relax/expr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 4467a56ead59..becb931362f8 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,7 +22,6 @@ import numpy as _np # type: ignore import tvm_ffi -import ml_dtypes import tvm import tvm.ffi @@ -1157,6 +1156,9 @@ def const( - bool maps to "bool" - other using the same default rule as numpy. """ + # Needed for bf16 and fp8 support (does not come with numpy) + import ml_dtypes # pylint: disable=unused-import + if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) From e04181fe714954fa69df5aacc28dcb9e7137644e Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Wed, 17 Sep 2025 14:52:51 -0400 Subject: [PATCH 08/13] lint fix --- python/tvm/relax/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index becb931362f8..d18922e09258 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1157,7 +1157,7 @@ def const( - other using the same default rule as numpy. 
""" # Needed for bf16 and fp8 support (does not come with numpy) - import ml_dtypes # pylint: disable=unused-import + import ml_dtypes # pylint: disable=unused-import if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) From f3de2604a1f2090b0b646fdb5d97c2bd8993747f Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Wed, 17 Sep 2025 15:25:00 -0400 Subject: [PATCH 09/13] lint fix --- python/tvm/relax/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index d18922e09258..149c9269b75e 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1157,7 +1157,7 @@ def const( - other using the same default rule as numpy. """ # Needed for bf16 and fp8 support (does not come with numpy) - import ml_dtypes # pylint: disable=unused-import + import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) From ff7561df32f9ed64b1f08bd6779a3b3982276634 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Wed, 17 Sep 2025 15:25:50 -0400 Subject: [PATCH 10/13] remove imports --- python/tvm/relax/expr.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 149c9269b75e..8dd4eff5c703 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -23,8 +23,6 @@ import tvm_ffi -import tvm -import tvm.ffi import tvm.ir import tvm.relax from tvm import DataType From 4eb5a7eb39d1615604438519a47319969b7f694b Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Fri, 19 Sep 2025 11:57:45 -0400 Subject: [PATCH 11/13] fix tests --- tests/python/relax/test_frontend_nn_op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 9825a23294d0..e83b697cac1d 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -384,6 +384,8 @@ def test( def test_nn(): class Model(Module): def test(self, x: Tensor, weight: Tensor, bias: Tensor): + log_out = op.log(x) + floor_out = op.floor(x) relu_out = op.relu(x) relu6_out = op.relu6(x) silu_out = op.silu(x) @@ -398,8 +400,6 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor): rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1]) rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1]) group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias) - log_out = op.log(x) - floor_out = op.floor(x) return x @R.function @@ -411,8 +411,8 @@ def test( ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.log(x) - floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.floor(x) + log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x) + floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x) relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x) relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x) silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) @@ -482,7 +482,7 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32") zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32") zeros1: R.Tensor((10, 10), dtype="float16") 
= R.zeros(R.shape([10, 10]), dtype="float16") - arange: R.Tensor((10), dtype="float32") = R.arange(R.const(0, "int"), R.const(10, "int"), R.const(1, "int"), dtype="float32") + arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32") gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,) R.output(gv1) return gv1 @@ -511,7 +511,7 @@ def test( lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32") lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1]) lv3: R.Tensor((5,), dtype="float32") = R.arange( - R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32" + R.prim_value(T.int64(0)), R.prim_value(T.int64(5)), R.prim_value(T.int64(1)), dtype="float32" ) lv4: R.Tensor((5,), dtype="float32") = R.multiply( R.const(-9.2103404998779297, "float32"), lv3 From 9d7f40a6128b48da0be2fd0ad0e4e3b6bab4300e Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Fri, 19 Sep 2025 12:11:23 -0400 Subject: [PATCH 12/13] black format --- tests/python/relax/test_frontend_nn_op.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index e83b697cac1d..50c8225c0154 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -341,9 +341,7 @@ def test(self, x: Tensor): return chunk @R.function - def test( - x: R.Tensor((8,), dtype="float32"), _io: R.Object - ) -> R.Tuple( + def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple( R.Tuple( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), @@ -511,7 +509,10 @@ def test( lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32") lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1]) lv3: R.Tensor((5,), dtype="float32") = R.arange( - R.prim_value(T.int64(0)), R.prim_value(T.int64(5)), R.prim_value(T.int64(1)), dtype="float32" + R.prim_value(T.int64(0)), + R.prim_value(T.int64(5)), + R.prim_value(T.int64(1)), + dtype="float32", ) lv4: R.Tensor((5,), dtype="float32") = R.multiply( R.const(-9.2103404998779297, "float32"), lv3 @@ -553,9 +554,9 @@ def test( ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - scaled_dot_product_attention: R.Tensor( - (1, 32, 32, 32), dtype="float32" - ) = R.nn.attention(query, key, value, scale=None, causal_mask=None) + scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = ( + R.nn.attention(query, key, value, scale=None, causal_mask=None) + ) gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = ( scaled_dot_product_attention, (_io,), From 45e2dd147a3dd561979f63ba45d47496f1a71f66 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Fri, 19 Sep 2025 14:21:03 -0400 Subject: [PATCH 13/13] undo lint for 2 lines --- tests/python/relax/test_frontend_nn_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 50c8225c0154..28c11f6dfaf5 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -341,7 +341,9 @@ def test(self, x: Tensor): return chunk @R.function - def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple( + def test( + x: R.Tensor((8,), dtype="float32"), _io: R.Object + ) -> R.Tuple( R.Tuple( R.Tensor((2,), 
dtype="float32"), R.Tensor((2,), dtype="float32"), @@ -554,9 +556,9 @@ def test( ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = ( - R.nn.attention(query, key, value, scale=None, causal_mask=None) - ) + scaled_dot_product_attention: R.Tensor( + (1, 32, 32, 32), dtype="float32" + ) = R.nn.attention(query, key, value, scale=None, causal_mask=None) gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = ( scaled_dot_product_attention, (_io,),