diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 842e21378fd1..ac1e990a96e2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs): nearbyint = _op_wrapper(_tir_op.nearbyint) nextafter = _op_wrapper(_tir_op.nextafter) popcount = _op_wrapper(_tir_op.popcount) -power = _op_wrapper(_tir_op.power) +pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) @@ -1713,7 +1713,7 @@ def f(): "nearbyint", "nextafter", "popcount", - "power", + "pow", "q_multiply_shift", "q_multiply_shift_per_axis", "ret", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a2e341d82354..9522181432f2 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -69,7 +69,7 @@ from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv from .op import comm_reducer, min, max, sum diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index e1adc0a6bbd7..131e91de876e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2236,6 +2236,28 @@ def power(x, y, span=None): return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore +def pow(x, y, span=None): + """x power y + + Parameters + ---------- + x : PrimExpr + Input argument. + + y : PrimExpr + The exponent + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + z : PrimExpr + The result. + """ + return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + + def popcount(x): """Count the number of set bits in input x. diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c0174a0671c0..0e9be0463943 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3550,6 +3550,14 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]): return mod["main"] +def intrinsic_pow(): + @T.prim_func + def func(): + T.pow(T.float32(1), T.float32(1)) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3607,6 +3615,7 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]): elif_chain_with_else, *nested_boolean_expressions(), multi_env_threads, + intrinsic_pow, )