diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index b99cae77ec56..8ef68baf6832 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -219,11 +219,16 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, class CutlassModuleCodegen { public: - runtime::Module CreateCSourceModule(Function f, const Map& options) { + runtime::Module CreateCSourceModule(Array functions, + const Map& options) { std::string headers = ""; - auto [code, op_headers] = GenCutlassFunc(f, options); - for (const auto& header : op_headers) { - headers += "#include <" + header + ">\n"; + std::string code = ""; + for (const auto& f : functions) { + auto [f_code, op_headers] = GenCutlassFunc(f, options); + code += "\n" + f_code; + for (const auto& header : op_headers) { + headers += "#include <" + header + ">\n"; + } } return Finalize(headers + "\n" + code, func_names_); } @@ -254,17 +259,13 @@ Array CUTLASSCompiler(Array functions, Map annotated_functions = (*tune_func)(functions, options); - Array compiled_functions; - for (const auto& func : annotated_functions) { - auto func_name = GetExtSymbol(func); - auto source_mod = CutlassModuleCodegen().CreateCSourceModule(func, options); - const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); - ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import " - "tvm.contrib.cutlass.build"; - compiled_functions.push_back((*pf)(source_mod, options)); - } + auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); + const auto* pf = runtime::Registry::Get("contrib.cutlass.compile"); + ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import " + "tvm.contrib.cutlass.build"; + runtime::Module cutlass_mod = (*pf)(source_mod, options); - return compiled_functions; + return {cutlass_mod}; } TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler);