Skip to content

Conversation

@CharlieFRuan
Copy link
Member

This PR changes topi.power(x,3) in gelu_tanh to x * x * x. Subsequently, we add a warning when pow(x, y), where y >= 3, is detected.

This is motivated by, in ROCm, pow(x, y) returns NaN when x < 0 and y >= 3. This is shown below:

@I.ir_module
class Module:
    @T.prim_func
    def pow3(A: T.Buffer((1,), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("T_multiply_3"):
                    T.where(ax0_fused_0 * 1024 + ax0_fused_1 < 1)
                    T.reads(A[0])
                    T.writes(A[0])
                    # A[0] = T.pow(A[0], T.float16(2))  # Gives 1
                    # A[0] = T.pow(T.float16(-1), T.float16(3))  # Gives -1
                    # A[0] = T.pow(A[0], T.int32(3))  # gives nan
                    A[0] = T.pow(A[0], T.float16(3))  # gives nan
                    # A[0] = T.pow(A[0], T.float16(4.0))  # gives nan

mod = Module
rt_mod = tvm.build(mod, target="rocm")
A_tvm = tvm.nd.array(np.array([-1.0], dtype="float16"), tvm.rocm())
print(f"Before: {A_tvm.numpy()}")
rt_mod["pow3"](A_tvm)
print(f"After: {A_tvm.numpy()}")

@CharlieFRuan
Copy link
Member Author

cc @tqchen @MasterJH5574 @junrushao

@tqchen tqchen merged commit 5ebdd49 into apache:main Feb 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants