Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions src/relax/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,16 @@ class CodegenCutlass : public relax::MemoizedExprTranslator<OutputType>,

class CutlassModuleCodegen {
public:
runtime::Module CreateCSourceModule(Function f, const Map<String, ObjectRef>& options) {
runtime::Module CreateCSourceModule(Array<Function> functions,
const Map<String, ObjectRef>& options) {
std::string headers = "";
auto [code, op_headers] = GenCutlassFunc(f, options);
for (const auto& header : op_headers) {
headers += "#include <" + header + ">\n";
std::string code = "";
for (const auto& f : functions) {
auto [f_code, op_headers] = GenCutlassFunc(f, options);
code += "\n" + f_code;
for (const auto& header : op_headers) {
headers += "#include <" + header + ">\n";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we might be adding duplicate headers. It probably won't matter for compilation speed, but the generated file might get ugly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; however, since this is a generated file, I felt it is OK to have duplicate entries in the headers. We can improve on this in follow-up PRs, though.

}
}
return Finalize(headers + "\n" + code, func_names_);
}
Expand Down Expand Up @@ -254,17 +259,13 @@ Array<runtime::Module> CUTLASSCompiler(Array<Function> functions, Map<String, Ob

Array<Function> annotated_functions = (*tune_func)(functions, options);

Array<runtime::Module> compiled_functions;
for (const auto& func : annotated_functions) {
auto func_name = GetExtSymbol(func);
auto source_mod = CutlassModuleCodegen().CreateCSourceModule(func, options);
const auto* pf = runtime::Registry::Get("contrib.cutlass.compile");
ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import "
"tvm.contrib.cutlass.build";
compiled_functions.push_back((*pf)(source_mod, options));
}
auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options);
const auto* pf = runtime::Registry::Get("contrib.cutlass.compile");
ICHECK(pf != nullptr) << "The packed function contrib.cutlass.compile not found, please import "
"tvm.contrib.cutlass.build";
runtime::Module cutlass_mod = (*pf)(source_mod, options);

return compiled_functions;
return {cutlass_mod};
}

TVM_REGISTER_GLOBAL("relax.ext.cutlass").set_body_typed(CUTLASSCompiler);
Expand Down