Skip to content

[Bug] [Relax] RealizeVDevice cannot remove the hint_on_device #17205

@MellowArtisan

Description

@MellowArtisan

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/7bug_assert.py", line 25, in <module>
    tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  3: _ZN3tvm20SEqualHandlerDefault5EqualERKNS_
  2: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  1: tvm::SEqualHandlerDefault::Impl::RunTasks()
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = y
                                                                       ^
            R.output(lv0)
        return lv0
and rhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            R.output(lv0)
        return lv0

Environment

TVM: 0.17.dev0

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Reproducer module: a single Relax function whose only binding is an
# R.hint_on_device call, plus module-level attrs and a vdevice table.
@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    # Four virtual devices are declared, in this order: llvm (cpu), cuda sm_50,
    # metal, cuda sm_80.
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    # foo takes three (2, 3) float32 tensors; only y is used, x and z are not
    # referenced in the body.
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module  # bound here but not referenced again in this function
        with R.dataflow():
            # The device hint under test.
            # NOTE(review): dev_type=1 presumably selects the CPU device; the
            # modules printed in the traceback annotate this binding with
            # vdevice="llvm:0" -- confirm.
            lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
            R.output(lv0)
        return lv0

mod = Module
# Apply RealizeVDevice twice to the same input: first wrapped in a
# tvm.transform.Sequential, then called directly. The two results are
# expected to be structurally equal.
mod_seq = tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod)
# `mod` is rebound to the direct-call result, so from this line on it no
# longer names the original Module object.
mod = relax.transform.RealizeVDevice()(mod)
mod_seq.show()
# Directly-applied result: per the rhs printed in the traceback, the binding
# is still R.hint_on_device(...), i.e. the hint was not removed.
mod.show()  # cannot remove the 'hint_on_device'
# lhs is mod_seq (binding value is plain `y`), rhs is mod (binding value is
# still the hint_on_device call); the mismatch raises ValueError.
tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed

cc @junrushao

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it
    type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions