
[Bug] [ROCm] dense op error when compiling model with rocm #16160

@JiaJiDuan

Description

Expected behavior

The resnet-18 or resnet-50 model compiles successfully when built for the ROCm target.
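
For context, the compile step in onnx_rocm_test.py follows the standard Relay build flow. The snippet below is only an assumed reconstruction (the ONNX loading code and file name are hypothetical; the real script is not included in this report):

import onnx
import tvm
from tvm import relay

# Hypothetical ONNX model path; the real script loads resnet-18/resnet-50.
onnx_model = onnx.load("resnet18.onnx")
mod, params = relay.frontend.from_onnx(onnx_model)

target = tvm.target.Target("rocm")
with tvm.transform.PassContext(opt_level=3):
    # This is the call that fails with the error shown below.
    lib = relay.build(mod, target=target, params=params)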

Actual behavior

The following error was encountered while compiling the model:

Traceback (most recent call last):
  File "onnx_rocm_test.py", line 178, in <module>
    lib = relay.build(mod, target=target, params=params)
  File "/root/tvm/python/tvm/relay/build_module.py", line 366, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/root/tvm/python/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "/root/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/root/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  10: tvm::relay::backend::RelayBuildModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
  9: tvm::relay::backend::RelayBuildModule::Build(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&, tvm::Target const&, tvm::relay::Executor const&, tvm::relay::Runtime const&, tvm::WorkspaceMemoryPools const&, tvm::ConstantMemoryPools const&, tvm::runtime::String)
  8: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  7: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  6: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  5: _ZN3tvm7runtime13Pac
  4: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  3: tvm::codegen::BuildAMDGPU(tvm::IRModule, tvm::Target)
  2: tvm::codegen::CodeGenLLVM::Finish()
  1: tvm::codegen::CodeGenLLVM::Verify() const
  0: _ZN3tvm7runtime6detail
  File "/root/tvm/src/target/llvm/codegen_llvm.cc", line 354
TVMError: LLVM module verification failed with the following errors: 
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %83 = icmp slt i32 %29, 2
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %97 = mul nsw i32 %29, 4
Instruction does not dominate all uses!
  %29 = call i32 @llvm.amdgcn.workitem.id.x()
  %110 = mul nsw i32 %29, 4

Environment

  • GPU: AMD Radeon RX 7600 (gfx1102)
  • TVM: v0.14.0
  • System: Ubuntu 22.04
  • Network: resnet-18 and resnet-50
  • ROCm: 5.7.1

Steps to reproduce

After some debugging, I believe the problem is caused by the dense op: the error occurs with dense_large_batch, but not with dense_small_batch.

My test script:

"""Test code for dense operator"""
import numpy as np

import tvm
import tvm.testing
from tvm import te, topi
from tvm.topi.utils import get_const_tuple


def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype):
    """Generate random inputs and the NumPy reference output for dense."""
    if "float" in in_dtype:
        a_np = np.random.uniform(
            low=-1, high=1, size=(batch_size, in_dim)).astype(in_dtype)
        b_np = np.random.uniform(
            low=-1, high=1, size=(out_dim, in_dim)).astype(in_dtype)
        c_np = np.random.uniform(
            low=-1, high=1, size=(out_dim,)).astype(out_dtype)
    elif in_dtype == "int8":
        a_np = np.random.randint(
            low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype)
        b_np = np.random.randint(
            low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype)
        c_np = np.random.randint(
            low=-128, high=127, size=(out_dim,)).astype(out_dtype)
    else:
        raise ValueError(
            "No method to generate test data for data type '{}'".format(in_dtype))

    matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))

    if use_bias:
        matmul += c_np

    d_np = matmul
    return (a_np, b_np, c_np, d_np)


if __name__ == "__main__":
    batch_size = 1
    in_dim = 1024
    out_dim = 1000
    in_dtype = "float32"
    out_dtype = "float32"

    use_bias = True

    A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype)
    B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype)
    C = te.placeholder((out_dim,), name="C", dtype=out_dtype)

    a_np, b_np, c_np, d_np = dense_ref_data(
        batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype)

    # Switching to the small-batch implementation avoids the error:
    # fcompute = topi.gpu.dense_small_batch
    # fschedule = topi.gpu.schedule_dense_small_batch
    fcompute = topi.gpu.dense_large_batch
    fschedule = topi.gpu.schedule_dense_large_batch
    
    target = "rocm"
    with tvm.target.Target(target):
        D = fcompute(A, B, C if use_bias else None, out_dtype)
        s = fschedule([D])

    func = tvm.build(s, [A, B, C, D], target, name="dense")
    

    dev = tvm.rocm(0)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    c = tvm.nd.array(c_np, dev)
    d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)

    func(a, b, c, d)
    tvm.testing.assert_allclose(d.numpy(), d_np, atol=1e-5, rtol=1e-5)
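
For debugging, the lowered TIR of the failing large-batch schedule can also be printed before the ROCm codegen runs. This is only a debugging sketch reusing the placeholders from the script above; tvm.lower succeeds because it stops before the LLVM step that fails:

# Debugging sketch: inspect the lowered TIR of the large-batch schedule.
# tvm.lower runs only the TIR passes, so it works even though tvm.build
# fails with the LLVM verification error above.
with tvm.target.Target("rocm"):
    D_dbg = topi.gpu.dense_large_batch(A, B, C if use_bias else None, out_dtype)
    s_dbg = topi.gpu.schedule_dense_large_batch([D_dbg])
    print(tvm.lower(s_dbg, [A, B, C, D_dbg], simple_mode=True))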

I think this problem arises from #13847.
If I revert the changes from #13847 so that dense_small_batch is also used on ROCm, resnet-18 and resnet-50 compile successfully and the results are correct.
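
For comparison, this is the small-batch path that works on my setup, reusing the placeholders, device arrays, and reference data from the script above (sketch only):

# Known-good path on this machine: build and verify the small-batch schedule.
with tvm.target.Target("rocm"):
    D_small = topi.gpu.dense_small_batch(A, B, C if use_bias else None, out_dtype)
    s_small = topi.gpu.schedule_dense_small_batch([D_small])

func_small = tvm.build(s_small, [A, B, C, D_small], "rocm", name="dense_small")
d_small = tvm.nd.array(np.zeros(get_const_tuple(D_small.shape), dtype=out_dtype), dev)
func_small(a, b, c, d_small)
tvm.testing.assert_allclose(d_small.numpy(), d_np, atol=1e-5, rtol=1e-5)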
