From 71145a8dcff70ece7703e6f5f794547d29e81d49 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 15 Feb 2025 20:19:21 -0500 Subject: [PATCH 1/5] Update argument order for relax.op.pad to make it round-trippable --- python/tvm/relax/frontend/nn/op.py | 4 ++-- python/tvm/relax/op/nn/nn.py | 10 +++++----- src/relax/op/nn/nn.cc | 2 +- tests/python/relax/test_frontend_nn_op.py | 2 +- tests/python/relax/test_op_nn.py | 10 +++++----- tests/python/relax/test_transform_legalize_ops_nn.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 4664ec549388..20cecc65f5d1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1464,7 +1464,7 @@ def pad( result : Tensor Padded output tensor. """ - return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, pad_mode=mode), name) + return wrap_nested(_op.nn.pad(x._expr, pad_value=value, pad_width=pad, pad_mode=mode), name) def square(x: Tensor, name: str = "square") -> Tensor: @@ -1567,7 +1567,7 @@ def get_timestep_embedding( # Zero pad if embedding_dim % 2 == 1: - emb = _op.nn.pad(emb, (0, 1, 0, 0)) + emb = _op.nn.pad(emb, 0, (0, 1, 0, 0)) # Cast to proper output type emb = _op.astype(emb, dtype) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 62d8b84321ce..55c164980561 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -513,7 +513,7 @@ def conv2d_transpose( ) -def pad(data, pad_width, pad_value=0, pad_mode="constant"): +def pad(data: Expr, pad_value: Union[int, Expr], pad_width: Tuple[Tuple[int, int], ...], pad_mode: str = "constant"): r"""Padding This operator takes in a tensor and pads each axis by the specified @@ -523,11 +523,11 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"): ---------- data: relax.Expr The input data to the operator - pad_width: tuple of >, required + pad_value: Union[int, relax.Expr] + The value used for padding. Default is 0. 
+ pad_width: Tuple[Tuple[int, int], ...], required Number of values padded to the edges of each axis, in the format of ((before_1, after_1), ..., (before_N, after_N)) - pad_value: float - The value used for padding pad_mode: 'constant', 'edge', 'reflect' 'constant' pads with constant_value pad_value 'edge' pads using the edge values of the input array @@ -539,7 +539,7 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"): """ if not isinstance(pad_value, Expr): pad_value = const(pad_value) - return _ffi_api.pad(data, pad_width, pad_value, pad_mode) + return _ffi_api.pad(data, pad_value, pad_width, pad_mode) def max_pool1d( diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7eccf47e4b06..7044775ab4d7 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -136,7 +136,7 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ TVM_REGISTER_NODE_TYPE(PadAttrs); -Expr pad(Expr data, Array pad_width, Expr pad_value, String pad_mode) { +Expr pad(Expr data, Expr pad_value, Array pad_width, String pad_mode) { auto attrs = make_object(); attrs->pad_width = std::move(pad_width); attrs->pad_mode = std::move(pad_mode); diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 6a337b34c114..5c49d7a75fb7 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -234,7 +234,7 @@ def test( ): R.func_attr({"num_input": 4}) with R.dataflow(): - lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0, 0, 0, 0, 1, 1, 1, 1)) + lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, 0, (0, 0, 0, 0, 1, 1, 1, 1)) lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d( lv0, weight, diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 7adfc8428355..fade6de9883a 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -32,7 +32,7 @@ def test_op_correctness(): assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") - assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad") + assert relax.op.nn.pad(x, 0, (1, 1, 1, 1)).op == Op.get("relax.nn.pad") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) gamma = relax.Var("gamma", R.Tensor((3,), "float32")) @@ -1792,19 +1792,19 @@ def test_pad_infer_struct_info(): pad_width1 = (1, 1, 1, 1) pad_width2 = (0, 1, 1, 0) - _check_inference(bb, relax.op.nn.pad(x, pad_width0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.pad(x, 0, pad_width0), relax.TensorStructInfo((2, 3), "float32")) _check_inference( bb, - relax.op.nn.pad(x, pad_width1), + relax.op.nn.pad(x, 0, pad_width1), relax.TensorStructInfo((4, 5), dtype="float32"), ) _check_inference( bb, - relax.op.nn.pad(x, pad_width2), + relax.op.nn.pad(x, 0, pad_width2), relax.TensorStructInfo((3, 4), dtype="float32"), ) _check_inference( - bb, relax.op.nn.pad(x1, pad_width1), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.nn.pad(x1, 0, pad_width1), relax.TensorStructInfo(dtype="float32", ndim=2) ) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 12436cf8023f..c57da65c524b 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -3614,7 +3614,7 @@ def test_pad(): class Pad: 
@R.function def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float32"): - gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, (0, 0, 1, 1, 1, 1)) + gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, 0, (0, 0, 1, 1, 1, 1)) return gv @tvm.script.ir_module From 49e6e3086888c20e3db70c9bdd42fdb9fb1446fd Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sat, 15 Feb 2025 20:52:09 -0500 Subject: [PATCH 2/5] fix lint --- python/tvm/relax/op/nn/nn.py | 7 ++++++- tests/python/relax/test_frontend_nn_op.py | 14 +++++++------- tests/python/relax/test_op_nn.py | 4 +++- .../python/relax/test_transform_legalize_ops_nn.py | 2 +- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 55c164980561..ed2508aba779 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -513,7 +513,12 @@ def conv2d_transpose( ) -def pad(data: Expr, pad_value: Union[int, Expr], pad_width: Tuple[Tuple[int, int], ...], pad_mode: str = "constant"): +def pad( + data: Expr, + pad_value: Union[int, Expr], + pad_width: Tuple[Tuple[int, int], ...], + pad_mode: str = "constant", +): r"""Padding This operator takes in a tensor and pads each axis by the specified diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 5c49d7a75fb7..e5df800cb1d8 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -234,7 +234,9 @@ def test( ): R.func_attr({"num_input": 4}) with R.dataflow(): - lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, 0, (0, 0, 0, 0, 1, 1, 1, 1)) + lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad( + x, 0, (0, 0, 0, 0, 1, 1, 1, 1) + ) lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d( lv0, weight, @@ -293,9 +295,7 @@ def test(self, x: Tensor): return chunk @R.function - def test( - x: R.Tensor((8,), dtype="float32"), _io: R.Object - ) -> R.Tuple( + def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple( R.Tuple( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), @@ -490,9 +490,9 @@ def test( ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - scaled_dot_product_attention: R.Tensor( - (1, 32, 32, 32), dtype="float32" - ) = R.nn.attention(query, key, value, scale=None, causal_mask=None) + scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = ( + R.nn.attention(query, key, value, scale=None, causal_mask=None) + ) gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = ( scaled_dot_product_attention, (_io,), diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index fade6de9883a..8038b7899276 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -1792,7 +1792,9 @@ def test_pad_infer_struct_info(): pad_width1 = (1, 1, 1, 1) pad_width2 = (0, 1, 1, 0) - _check_inference(bb, relax.op.nn.pad(x, 0, pad_width0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.nn.pad(x, 0, pad_width0), relax.TensorStructInfo((2, 3), "float32") + ) _check_inference( bb, relax.op.nn.pad(x, 0, pad_width1), diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index c57da65c524b..da3d13e30b12 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py 
@@ -3621,7 +3621,7 @@ def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float3 class Expected: @R.function def main( - x: R.Tensor((2, 128, 28), dtype="float32") + x: R.Tensor((2, 128, 28), dtype="float32"), ) -> R.Tensor((2, 130, 30), dtype="float32"): gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130, 30), dtype="float32")) return gv From b8d6c7c089ba32012ea4d5dd74b520371d8808e3 Mon Sep 17 00:00:00 2001 From: Renat Idrisov Date: Sun, 26 Jan 2025 23:30:00 +0000 Subject: [PATCH 3/5] Fix inconsistent IR when printing R.nn.pad operator --- include/tvm/relax/attrs/nn.h | 2 ++ python/tvm/relax/transform/legalize_ops/nn.py | 2 +- src/relax/op/nn/nn.cc | 8 ++++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index e26cee26584b..2f861631c209 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -562,12 +562,14 @@ struct AttentionAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { Array pad_width; + runtime::Float pad_value; tvm::String pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { TVM_ATTR_FIELD(pad_width).describe( "Number of values padded to the edges of each axis, " "in the format of (before_1, after_1, ..., before_N, after_N)"); + TVM_ATTR_FIELD(pad_value).describe("The value to fill in padded area with"); TVM_ATTR_FIELD(pad_mode) .set_default("constant") .describe( diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 8317d4504e1e..d9fb4701f7e9 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -231,7 +231,7 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: call.args[0], pad_before=pad_before, pad_after=pad_after, - pad_value=float(call.args[1].data.numpy()), + pad_value=call.attrs.pad_value, primfunc_name_hint="pad", ) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7044775ab4d7..8a28ebc9336d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -136,12 +136,13 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ TVM_REGISTER_NODE_TYPE(PadAttrs); -Expr pad(Expr data, Expr pad_value, Array pad_width, String pad_mode) { +Expr pad(Expr data, Array pad_width, runtime::Float pad_value, String pad_mode) { auto attrs = make_object(); attrs->pad_width = std::move(pad_width); + attrs->pad_value = pad_value; attrs->pad_mode = std::move(pad_mode); static const Op& op = Op::Get("relax.nn.pad"); - return Call(op, {data, pad_value}, Attrs(attrs), {}); + return Call(op, {data}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); @@ -171,9 +172,8 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { } TVM_REGISTER_OP("relax.nn.pad") - .set_num_inputs(2) + .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") - .add_argument("pad_value", "Tensor", "The value to fill in padded area with.") .set_attrs_type() .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", Bool(true)); From 5e5df02fbc2136568a232bea5a047ea51a38c0c2 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sun, 16 Feb 2025 16:47:32 -0500 Subject: [PATCH 4/5] Change value to be the last arg --- include/tvm/relax/attrs/nn.h | 6 +++--- python/tvm/relax/frontend/nn/op.py | 6 +++--- python/tvm/relax/op/nn/nn.py | 17 +++++++++-------- src/relax/op/nn/nn.cc | 4 ++-- 
tests/python/relax/test_frontend_nn_op.py | 4 +--- tests/python/relax/test_op_nn.py | 12 +++++------- .../relax/test_transform_legalize_ops_nn.py | 2 +- 7 files changed, 24 insertions(+), 27 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 2f861631c209..832934417484 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -562,14 +562,14 @@ struct AttentionAttrs : public tvm::AttrsNode { /*! \brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { Array pad_width; - runtime::Float pad_value; + runtime::Float pad_value = 0.0; tvm::String pad_mode; - TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { + TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") { TVM_ATTR_FIELD(pad_width).describe( "Number of values padded to the edges of each axis, " "in the format of (before_1, after_1, ..., before_N, after_N)"); - TVM_ATTR_FIELD(pad_value).describe("The value to fill in padded area with"); + TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value to fill in padded area with"); TVM_ATTR_FIELD(pad_mode) .set_default("constant") .describe( diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 20cecc65f5d1..9e708ca874f0 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1438,7 +1438,7 @@ def pad( x: Tensor, pad: List[int], mode: str = "constant", - value: int = 0, + value: float = 0.0, name: str = "pad", ) -> Tensor: """ @@ -1454,7 +1454,7 @@ def pad( mod : str Padding mode to use, constant implies padded elements will use value argument. - value : int + value : float What to pad with in constant mode. name : str Name hint for this operator. @@ -1464,7 +1464,7 @@ def pad( result : Tensor Padded output tensor. """ - return wrap_nested(_op.nn.pad(x._expr, pad_value=value, pad_width=pad, pad_mode=mode), name) + return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_mode=mode, pad_value=value), name) def square(x: Tensor, name: str = "square") -> Tensor: diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index ed2508aba779..211d6661e062 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -515,9 +515,9 @@ def conv2d_transpose( def pad( data: Expr, - pad_value: Union[int, Expr], pad_width: Tuple[Tuple[int, int], ...], - pad_mode: str = "constant", + pad_mode: Optional[str] = "constant", + pad_value: Optional[Union[float, Expr]] = 0.0, ): r"""Padding @@ -528,23 +528,24 @@ def pad( ---------- data: relax.Expr The input data to the operator - pad_value: Union[int, relax.Expr] - The value used for padding. Default is 0. pad_width: Tuple[Tuple[int, int], ...], required Number of values padded to the edges of each axis, in the format of ((before_1, after_1), ..., (before_N, after_N)) - pad_mode: 'constant', 'edge', 'reflect' + pad_mode: Optional[str] + 'constant', 'edge', or 'reflect' 'constant' pads with constant_value pad_value 'edge' pads using the edge values of the input array 'reflect' pads by reflecting values with respect to the edge + Default is 'constant' + pad_value: Optional[Union[float, Expr]] + The value used for padding. Default is 0. + Returns ------- result : relax.Expr The computed result. 
""" - if not isinstance(pad_value, Expr): - pad_value = const(pad_value) - return _ffi_api.pad(data, pad_value, pad_width, pad_mode) + return _ffi_api.pad(data, pad_width, pad_mode, pad_value) def max_pool1d( diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 8a28ebc9336d..526b816d0945 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -136,11 +136,11 @@ TVM_REGISTER_OP("relax.nn.log_softmax") /* relax.nn.pad */ TVM_REGISTER_NODE_TYPE(PadAttrs); -Expr pad(Expr data, Array pad_width, runtime::Float pad_value, String pad_mode) { +Expr pad(Expr data, Array pad_width, String pad_mode, runtime::Float pad_value) { auto attrs = make_object(); attrs->pad_width = std::move(pad_width); - attrs->pad_value = pad_value; attrs->pad_mode = std::move(pad_mode); + attrs->pad_value = pad_value; static const Op& op = Op::Get("relax.nn.pad"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index e5df800cb1d8..a61d3a416ca2 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -234,9 +234,7 @@ def test( ): R.func_attr({"num_input": 4}) with R.dataflow(): - lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad( - x, 0, (0, 0, 0, 0, 1, 1, 1, 1) - ) + lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0, 0, 0, 0, 1, 1, 1, 1)) lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d( lv0, weight, diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 8038b7899276..7adfc8428355 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -32,7 +32,7 @@ def test_op_correctness(): assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax") assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout") - assert relax.op.nn.pad(x, 0, (1, 1, 1, 1)).op == Op.get("relax.nn.pad") + assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad") x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) gamma = relax.Var("gamma", R.Tensor((3,), "float32")) @@ -1792,21 +1792,19 @@ def test_pad_infer_struct_info(): pad_width1 = (1, 1, 1, 1) pad_width2 = (0, 1, 1, 0) - _check_inference( - bb, relax.op.nn.pad(x, 0, pad_width0), relax.TensorStructInfo((2, 3), "float32") - ) + _check_inference(bb, relax.op.nn.pad(x, pad_width0), relax.TensorStructInfo((2, 3), "float32")) _check_inference( bb, - relax.op.nn.pad(x, 0, pad_width1), + relax.op.nn.pad(x, pad_width1), relax.TensorStructInfo((4, 5), dtype="float32"), ) _check_inference( bb, - relax.op.nn.pad(x, 0, pad_width2), + relax.op.nn.pad(x, pad_width2), relax.TensorStructInfo((3, 4), dtype="float32"), ) _check_inference( - bb, relax.op.nn.pad(x1, 0, pad_width1), relax.TensorStructInfo(dtype="float32", ndim=2) + bb, relax.op.nn.pad(x1, pad_width1), relax.TensorStructInfo(dtype="float32", ndim=2) ) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index da3d13e30b12..d83d0567e482 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -3614,7 +3614,7 @@ def test_pad(): class Pad: @R.function def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 30), "float32"): - gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, 0, (0, 0, 1, 1, 1, 1)) + gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, (0, 0, 1, 1, 
1, 1)) return gv @tvm.script.ir_module From 94fc6619e8456ba4bc61784cb1149a96201cb7fe Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Sun, 16 Feb 2025 19:02:10 -0500 Subject: [PATCH 5/5] Fix lint --- python/tvm/relax/frontend/nn/op.py | 2 +- python/tvm/relax/op/nn/nn.py | 2 +- tests/python/relax/test_frontend_nn_op.py | 10 ++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 9e708ca874f0..4c6d921db79b 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1567,7 +1567,7 @@ def get_timestep_embedding( # Zero pad if embedding_dim % 2 == 1: - emb = _op.nn.pad(emb, 0, (0, 1, 0, 0)) + emb = _op.nn.pad(emb, (0, 1, 0, 0)) # Cast to proper output type emb = _op.astype(emb, dtype) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 211d6661e062..5a1895cbc14f 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -20,7 +20,7 @@ from tvm import DataType from tvm.tir import FloatImm -from ...expr import Expr, const +from ...expr import Expr from . import _ffi_api diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index a61d3a416ca2..6a337b34c114 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -293,7 +293,9 @@ def test(self, x: Tensor): return chunk @R.function - def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple( + def test( + x: R.Tensor((8,), dtype="float32"), _io: R.Object + ) -> R.Tuple( R.Tuple( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), @@ -488,9 +490,9 @@ def test( ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): - scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = ( - R.nn.attention(query, key, value, scale=None, causal_mask=None) - ) + scaled_dot_product_attention: R.Tensor( + (1, 32, 32, 32), dtype="float32" + ) = R.nn.attention(query, key, value, scale=None, causal_mask=None) gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = ( scaled_dot_product_attention, (_io,),
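
# --- Editor's usage sketch (not part of the patch series) -------------------
# A minimal round-trip check, assuming the final signature from PATCH 4/5:
# pad(data, pad_width, pad_mode="constant", pad_value=0.0). The padding value
# is now carried as a float attribute rather than a relax.Expr argument, so a
# printed call such as R.nn.pad(x, (0, 0, 1, 1, 1, 1)) parses back unchanged.
# Variable names below (x, call) are illustrative, not from the patch.
import tvm
from tvm import relax
from tvm.script import relax as R

# Build a call with the reordered keyword arguments.
x = relax.Var("x", R.Tensor((2, 128, 28), "float32"))
call = relax.op.nn.pad(x, pad_width=(0, 0, 1, 1, 1, 1), pad_mode="constant", pad_value=0.0)

# The operator takes a single tensor input; pad_value lives in the attributes.
assert call.op == tvm.ir.Op.get("relax.nn.pad")
assert float(call.attrs.pad_value) == 0.0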