From d50633b9249a539b42cffec3d8b98455a09001ac Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Wed, 4 Jan 2023 02:24:03 +0000 Subject: [PATCH 1/4] Fix the roundtripability of pow intrinsic. --- python/tvm/script/ir_builder/tir/ir.py | 4 ++-- python/tvm/te/__init__.py | 2 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 2 +- tests/python/unittest/test_tvmscript_roundtrip.py | 9 +++++++++ 5 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 842e21378fd1..4adcb27886c5 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs): nearbyint = _op_wrapper(_tir_op.nearbyint) nextafter = _op_wrapper(_tir_op.nextafter) popcount = _op_wrapper(_tir_op.popcount) -power = _op_wrapper(_tir_op.power) +pow = _op_wrapper(_tir_op.pow) q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) @@ -1713,7 +1713,7 @@ def f(): "nearbyint", "nextafter", "popcount", - "power", + "pow", "q_multiply_shift", "q_multiply_shift_per_axis", "ret", diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 0907ea2ebf85..af1ad59d5cbb 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -22,7 +22,7 @@ from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil from tvm.tir import sinh, cosh, log2, log10 from tvm.tir import asin, asinh, acos, acosh, atan, atanh -from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else +from tvm.tir import trunc, abs, round, nearbyint, pow, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a2e341d82354..f9bd98c81c21 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -69,7 +69,7 @@ from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import trunc, abs, round, nextafter, nearbyint, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index e1adc0a6bbd7..408fa29c5a0c 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2214,7 +2214,7 @@ def isinf(x, span=None): return _ffi_api.isinf(x, span) # type: ignore -def power(x, y, span=None): +def pow(x, y, span=None): """x power y Parameters diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c0174a0671c0..a4699194bd85 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3622,5 +3622,14 @@ def test_return_none_no_trailing_type(): assert "-> None" not in script +def test_pow_roundtripable(): + @T.prim_func + def original_pow(): + T.pow(T.float32(1), T.float32(1)) + + after_roundtrip = tvm.script.from_source(original_pow.script(show_meta=True)) + tvm.ir.assert_structural_equal(original_pow, after_roundtrip, True) + + if __name__ == "__main__": tvm.testing.main() From 9e70c44859752532a5de78050d14bf52c7e49f7a Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Wed, 4 Jan 2023 03:20:46 +0000 Subject: [PATCH 2/4] fix the lint. --- python/tvm/script/ir_builder/tir/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 4adcb27886c5..ac1e990a96e2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs): nearbyint = _op_wrapper(_tir_op.nearbyint) nextafter = _op_wrapper(_tir_op.nextafter) popcount = _op_wrapper(_tir_op.popcount) -pow = _op_wrapper(_tir_op.pow) +pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) From 42affc653e58e1fde147af252a0c37211815af4d Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Wed, 4 Jan 2023 07:06:05 +0000 Subject: [PATCH 3/4] Fix the lint. --- python/tvm/script/ir_builder/tir/ir.py | 2 +- python/tvm/te/__init__.py | 2 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 2 +- .../unittest/test_tvmscript_roundtrip.py | 18 +++++++++--------- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ac1e990a96e2..c6f9ac6263ea 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs): nearbyint = _op_wrapper(_tir_op.nearbyint) nextafter = _op_wrapper(_tir_op.nextafter) popcount = _op_wrapper(_tir_op.popcount) -pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin +pow = _op_wrapper(_tir_op.power) # pylint: disable=redefined-builtin q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index af1ad59d5cbb..0907ea2ebf85 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -22,7 +22,7 @@ from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil from tvm.tir import sinh, cosh, log2, log10 from tvm.tir import asin, asinh, acos, acosh, atan, atanh -from tvm.tir import trunc, abs, round, nearbyint, pow, popcount, fmod, if_then_else +from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod from tvm.tir import comm_reducer, min, max, sum diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f9bd98c81c21..a2e341d82354 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -69,7 +69,7 @@ from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, pow, popcount, fmod, if_then_else +from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 408fa29c5a0c..e1adc0a6bbd7 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2214,7 +2214,7 @@ def isinf(x, span=None): return _ffi_api.isinf(x, span) # type: ignore -def pow(x, y, span=None): +def power(x, y, span=None): """x power y Parameters diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index a4699194bd85..0e9be0463943 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3550,6 +3550,14 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]): return mod["main"] +def intrinsic_pow(): + @T.prim_func + def func(): + T.pow(T.float32(1), T.float32(1)) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3607,6 +3615,7 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]): elif_chain_with_else, *nested_boolean_expressions(), multi_env_threads, + intrinsic_pow, ) @@ -3622,14 +3631,5 @@ def test_return_none_no_trailing_type(): assert "-> None" not in script -def test_pow_roundtripable(): - @T.prim_func - def original_pow(): - T.pow(T.float32(1), T.float32(1)) - - after_roundtrip = tvm.script.from_source(original_pow.script(show_meta=True)) - tvm.ir.assert_structural_equal(original_pow, after_roundtrip, True) - - if __name__ == "__main__": tvm.testing.main() From 8ad05033a55aabeecfaa0c6ac35aedb411ccb38c Mon Sep 17 00:00:00 2001 From: lightzhan-intellif Date: Thu, 5 Jan 2023 01:56:06 +0000 Subject: [PATCH 4/4] add tir.pow to make it consistent. --- python/tvm/script/ir_builder/tir/ir.py | 2 +- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 22 ++++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c6f9ac6263ea..ac1e990a96e2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs): nearbyint = _op_wrapper(_tir_op.nearbyint) nextafter = _op_wrapper(_tir_op.nextafter) popcount = _op_wrapper(_tir_op.popcount) -pow = _op_wrapper(_tir_op.power) # pylint: disable=redefined-builtin +pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a2e341d82354..9522181432f2 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -69,7 +69,7 @@ from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index e1adc0a6bbd7..131e91de876e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2236,6 +2236,28 @@ def power(x, y, span=None): return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore +def pow(x, y, span=None): + """x power y + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + The exponent + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + z : PrimExpr + The result. + """ + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + + def popcount(x): """Count the number of set bits in input x.