diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index cf3530c0afce..9acd1ca903d1 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -63,7 +63,9 @@ class ROCMModuleNode : public runtime::ModuleNode { } const char* type_key() const final { return "hip"; } - + int GetPropertyMask() const final { + return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable; + } PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; void SaveToFile(const String& file_name, const String& format) final { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 67c81d2803b6..02d203b7e97a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -702,8 +702,8 @@ llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t s llvm::GlobalValue::LinkageTypes linkage) { llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size); llvm::GlobalVariable* global = - new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr, - llvm::GlobalValue::NotThreadLocal, shared_address_space); + new llvm::GlobalVariable(*module_, type, false, linkage, llvm::UndefValue::get(type), "shmem", + nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(alignment)); #else diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index fba62a0c18ac..abc288f0eb24 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -729,7 +729,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // rocm only supports 32 bit operands for shuffling at the moment if ((target_->kind->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) { - if (ty.is_vector()) return true; + if ((ty.is_vector()) || !ty.is_int()) return true; return ty.bits() != 32; }))) { return false;