diff --git a/python/tvm/contrib/cutlass/__init__.py b/python/tvm/contrib/cutlass/__init__.py index 69d3e9c4bd7c..4b56ac4e164a 100644 --- a/python/tvm/contrib/cutlass/__init__.py +++ b/python/tvm/contrib/cutlass/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """BYOC support for CUTLASS.""" -from .build import tune_cutlass_kernels, build_cutlass_kernels, build_cutlass_kernels_vm +from .build import has_cutlass, num_cutlass_partitions, finalize_modules, finalize_modules_vm diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index bd372572c403..0c8c2ad0b2b9 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -22,6 +22,7 @@ import tvm from tvm import runtime, relay from tvm.contrib.nvcc import get_cuda_version +from tvm._ffi.registry import register_func from .gen_gemm import CutlassGemmProfiler from .gen_conv2d import CutlassConv2DProfiler from .library import ConvKind @@ -29,6 +30,11 @@ logger = logging.getLogger("cutlass") +def has_cutlass(): + """Returns true if the CUTLASS custom codegen is available""" + return tvm.get_global_func("relay.ext.cutlass.create_c_source_module", True) is not None + + def _get_cutlass_path(): tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../../") cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass") @@ -49,6 +55,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): kwargs = {} kwargs["cc"] = "nvcc" kwargs["options"] = [ + "-c", "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), "-Xcompiler=-fPIC", @@ -77,7 +84,7 @@ def __init__(self): def visit_call(self, call): op = call.op - if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + if isinstance(op, relay.Function) and "Composite" in op.attrs: self.signature["op_type"] = op.attrs["Composite"] for i, arg in enumerate(op.params): self.signature["arg%d_shape" % i] = arg.checked_type.shape @@ -285,6 +292,10 @@ def handle_conv2d( } +def num_cutlass_partitions(mod): + return sum([(1 if "cutlass" in var.name_hint else 0) for var in mod.get_global_vars()]) + + def tune_cutlass_kernels( mod, sm, @@ -346,187 +357,279 @@ def tune_cutlass_kernels( for var in mod.get_global_vars(): fun_name = var.name_hint func = mod[fun_name] - annotator = OpAnnotator() if "cutlass" in fun_name: num_cutlass_partition += 1 - annotator.visit(func) - out_shape = annotator.signature["ret_shape"] - out_dtype = annotator.signature["ret_dtype"] - op_type = annotator.signature["op_type"] - - new_attrs = {"op_type": op_type} - new_attrs.update(annotator.signature) - new_attrs.update(func.attrs) - arg0_shape = new_attrs["arg0_shape"] - arg1_shape = new_attrs["arg1_shape"] - arg0_dtype = new_attrs["arg0_dtype"] - arg1_dtype = new_attrs["arg1_dtype"] - - if "conv2d" in op_type: - new_attrs["padding"] = annotator.op_attrs.padding - new_attrs["strides"] = annotator.op_attrs.strides - new_attrs["dilation"] = annotator.op_attrs.dilation - - if "conv2d_transpose" in op_type: - d_shape = out_shape - w_shape = arg1_shape - elif "conv2d_backward_weight" in op_type: - d_shape = arg1_shape - w_shape = out_shape - else: - d_shape = arg0_shape - w_shape = arg1_shape - - new_attrs.update( - handle_conv2d( - conv2d_profiler, - op_type, - d_shape, - w_shape, - annotator.op_attrs.padding, - annotator.op_attrs.strides, - annotator.op_attrs.dilation, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - 
split_k_slices, - profile_all_alignments, - find_first_valid, - use_multiprocessing, - ) - ) - elif "batch_matmul" in op_type: - new_attrs.update( - handle_batch_matmul( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - elif "dense" in op_type: - new_attrs.update( - handle_dense( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - else: - raise ValueError("%s unsupported composite" % op_type) - - new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) - new_func = relay.Function( - func.params, - func.body, - ret_type=func.ret_type, - type_params=func.type_params, - attrs=new_attrs, + new_func = tune_cutlass_function( + func, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + gemm_profiler, + conv2d_profiler, ) mod.update_func(var, new_func) return mod, num_cutlass_partition -def build_cutlass_kernels( - lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False +def tune_cutlass_function( + func, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + gemm_profiler, + conv2d_profiler, ): - """Compile CUTLASS kernels in lib and return the runtime module ready to run. + """Given a function intended to be offloaded to CUTLASS, profile each workload to select which + kernels to emit. Parameters ---------- - lib : GraphExecutorFactoryModule - The output from relay.build containing compiled host code and non-cutlass kernels. + func : relay.Function + The Relay function to tune. - sm : int - An integer specifying the compute capability. For example, 75 for Turing and - 80 or 86 for Ampere. + use_3xtf32 : bool + Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for + fp32 inputs on tensorcore. - tmp_dir : string, optional - A temporary directory where intermediate compiled artifacts will be stored. + split_k_slices : list of int + Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in + parallel across split-K blocks, and a separate global reduction kernel is launched to + accumulate partial reductions. The profiler will pick the best split-K factor from the + given candidate list. Note that a larger split-K factor requires a larger workspace. + Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d + kinds, split_k_slices is ignored. - lib_path : string, optional - The path to a shared library which will be generated as the result of the build process. + profile_all_alignments : bool + When True, profile all kernel variants with smaller alignments than the largest possible. + + find_first_valid : bool + If True, stop profiling after the first applicable kernel is found; otherwise profile + all candidate kernels. + + use_multiprocessing : bool + Whether or not to compile profiler executables for different kernels in parallel. - threads : int, optional - The number of threads to use for compiling generated kernels. Only available for - CUDA 11.2 or later. Use all physical cores by default. + gemm_profiler : CutlassGemmProfiler + Profiler for dense operators. May cache results between tuned functions. - use_fast_math : bool, optional - Whether or not to use faster but less accurate math intrinsics. + conv2d_profiler : CutlassConv2DProfiler + Profiler for conv2d operators.
May cache results between tuned functions. Returns ------- - updated_lib : runtime.Module - The updated module with compiled cutlass kernels. + annot_func : Function + The input function with attributes capturing the best CUTLASS kernel found by tuning. """ - kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math) - lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) - return runtime.load_module(lib_path) + annotator = OpAnnotator() + annotator.visit(func) + out_shape = annotator.signature["ret_shape"] + out_dtype = annotator.signature["ret_dtype"] + op_type = annotator.signature["op_type"] + + new_attrs = {"op_type": op_type} + new_attrs.update(annotator.signature) + new_attrs.update(func.attrs) + arg0_shape = new_attrs["arg0_shape"] + arg1_shape = new_attrs["arg1_shape"] + arg0_dtype = new_attrs["arg0_dtype"] + arg1_dtype = new_attrs["arg1_dtype"] + + if "conv2d" in op_type: + new_attrs["padding"] = annotator.op_attrs.padding + new_attrs["strides"] = annotator.op_attrs.strides + new_attrs["dilation"] = annotator.op_attrs.dilation + + if "conv2d_transpose" in op_type: + d_shape = out_shape + w_shape = arg1_shape + elif "conv2d_backward_weight" in op_type: + d_shape = arg1_shape + w_shape = out_shape + else: + d_shape = arg0_shape + w_shape = arg1_shape + + new_attrs.update( + handle_conv2d( + conv2d_profiler, + op_type, + d_shape, + w_shape, + annotator.op_attrs.padding, + annotator.op_attrs.strides, + annotator.op_attrs.dilation, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + ) + ) + elif "batch_matmul" in op_type: + new_attrs.update( + handle_batch_matmul( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + elif "dense" in op_type: + new_attrs.update( + handle_dense( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + else: + raise ValueError("%s unsupported composite" % op_type) + + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) + return relay.Function( + func.params, + func.body, + ret_type=func.ret_type, + type_params=func.type_params, + attrs=new_attrs, + ) -def build_cutlass_kernels_vm( - vm_exec, - sm, - tmp_dir="./tmp", - lib_path="compile.so", - vmcode_path="vmcode.ro", - threads=-1, - use_fast_math=False, -): - """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run. +@register_func("relay.ext.cutlass.compile_for_cutlass") +def compile_for_cutlass(mod, cutlass_target): + """Given an IRModule with at least one Compiler='cutlass' Relay function, return a + LibraryModule with all such functions compiled into their PackedFunc-compatible form. + - First runs CUTLASS tuning to decide on the best kernels, which itself requires the + repeated compilation and execution of CUDA code using nvcc. The results of this + are captured as annotations on each relevant function. Kernel performance is cached + across all functions. + - Then generates a single CSourceModule containing C code implementing all the + Compiler='cutlass' Relay functions, accounting for the tuning done above. + - Then compiles that CSourceModule with the appropriate nvcc arguments to yield + a static .o library. An export_library step will be required on the final runtime + module to link that library into the overall .so library.
+ See CompileForCutlass in src/relay/backend/contrib/cutlass/codegen.cc for where this + helper function is used to implement the RelayToTIR pass hook for CUTLASS.""" + + # Recover options from the current 'cutlass' Target + assert cutlass_target.kind.name == "cutlass" + tuning_config = { + key: cutlass_target.attrs.get(key) + for key in [ + "sm", + "use_3xtf32", + "split_k_slices", + "profile_all_alignments", + "find_first_valid", + "use_multiprocessing", + ] + } + compile_config = { + key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"] + } + tmp_dir = cutlass_target.attrs.get("tmp_dir") + + # Tune + logger.info("Tuning for CUTLASS") + mod, _ = tune_cutlass_kernels(mod, tmp_dir=tmp_dir, **tuning_config) + + # Compile + logger.info("Creating CSource module for CUTLASS") + create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module") + c_module = create_c_source_module(mod) + function_names = c_module.get_function("get_func_names")() + compile_options = _get_cutlass_compile_options(**compile_config) + lib_path = os.path.join(tmp_dir, "cutlass.o") + logger.info("Compiling generated CUTLASS code") + c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options) + + # Recover static library + logger.info("Loading compiled CUTLASS code") + final_mod = tvm.runtime.load_static_library(lib_path, function_names) + + logger.info("Done with CUTLASS compilation") + return final_mod + + +def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"): + """Returns lib with any C source, LLVM and static library modules compiled and linked in, ready + for use by the graph or AOT executors. This method is not specific to CUTLASS; however, it does + assume nvcc will be used for final compilation and linking. It is provided here for + convenience. Parameters ---------- - vm_exec : vm.Executable - The output from relay.vm.compile containing compiled host code and non-cutlass kernels. + lib : runtime.Module + The output from relay.build. - sm : int - An integer specifying the compute capability. For example, 75 for Turing and - 80 or 86 for Ampere. + lib_path : string + The path to a shared library which will be generated as the result of the build process. - tmp_dir : string, optional + tmp_dir : string A temporary directory where intermediate compiled artifacts will be stored. - lib_path : string, optional - The path to a shared library which will be generated as the result of the build process. + Returns + ------- + updated_lib : runtime.Module + The updated library with all compilation and linking completed. - vmcode_path : string, optional - The path where the VM bytecode will be serialized to. + """ + lib_path = os.path.join(tmp_dir, lib_path) + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + return runtime.load_module(lib_path) - threads : int, optional - The number of threads to use for compiling generated kernels. Only available for - CUDA 11.2 or later. Use all physical cores by default. - use_fast_math : bool, optional - Whether or not to use faster but less accurate math intrinsics. +def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp"): + """Returns vm_exec with any C source, LLVM and static library modules compiled and linked in, + ready for use by the VM executor. This method is not specific to CUTLASS; however, it does + assume nvcc will be used for final compilation and linking. It is provided here for + convenience.
+ + Parameters + ---------- + vm_exec : vm.Executable + The output from relay.vm.compile containing compiled host code and kernels. + + lib_path : string + The path to a shared library which will be generated as the result of the build process. + + vmcode_path : string + The path where the VM bytecode will be serialized to as a side-effect. + + tmp_dir : string + A temporary directory where intermediate compiled artifacts will be stored. Returns ------- - updated_vm_exec: vm.Executable - The updated exectuable with compiled cutlass kernels. + updated_vm_exec : vm.Executable + The updated VM executable with all compilation and linking completed. """ code, lib = vm_exec.save() - kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math) lib_path = os.path.join(tmp_dir, lib_path) vmcode_path = os.path.join(tmp_dir, vmcode_path) - lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") with open(vmcode_path, "wb") as fo: fo.write(code) lib = tvm.runtime.load_module(lib_path) - code = bytearray(open(vmcode_path, "rb").read()) return tvm.runtime.vm.Executable.load_exec(code, lib) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index b3f40f09419c..3c7e1aba2a19 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -332,10 +332,11 @@ def _compile(self, op): opath = os.path.join(self.binary_prefix, op["name"]) if os.path.exists(opath): return - fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") + fi = tempfile.NamedTemporaryFile("w", delete=False, prefix=self.binary_prefix, suffix=".cu") fi.write(op["src"]) fi.close() cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) + logger.info("invoking compilation %s", cmd) os.system(cmd) os.unlink(fi.name) @@ -361,6 +362,7 @@ def evaluate(self, op, args): for arg in args: cmd.append(str(arg)) try: + logger.info("invoking evaluation %s", cmd) sp = subprocess.run(cmd, capture_output=True, check=True) rt = float(sp.stdout) if rt == 0.0: diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 5c906f7e69be..1a441a6f03c2 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -21,6 +21,7 @@ from tvm.ir.transform import Sequential, PassContext from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name +from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from ...dataflow_pattern import wildcard, is_op, is_constant @@ -200,8 +201,10 @@ def check_conv2d_residual(call, binary_op): return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape)) -def partition_for_cutlass(mod, params=None): - """Partition the input module into CUTLASS-supported subgraphs.""" +@register_pattern_table("cutlass") +def pattern_table(): + """Returns list of triples describing the name, dataflow pattern and predicate for all + the CUTLASS-supported operators.""" dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm) dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm) dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm) @@ -273,9 +276,11 @@ def partition_for_cutlass(mod, params=None): ) ) - cutlass_patterns = ( - residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns - ) + return 
residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns + + +def partition_for_cutlass(mod, params=None): + """Partition the input module into CUTLASS-supported subgraphs.""" if params is not None: mod["main"] = bind_params_by_name(mod["main"], params) @@ -290,6 +295,8 @@ def partition_for_cutlass(mod, params=None): with PassContext(opt_level=3): mod = remove_bn_pass(mod) + cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass") + seq = Sequential( [ transform.InferType(), diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 59ff93cfea5c..569ea0cca7ff 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -926,6 +926,9 @@ def _any_gpu_exists(): # Mark a test as requiring microTVM to run requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO") +# Mark a test as requiring CUTLASS to run +requires_cutlass = Feature("cutlass", "CUTLASS", cmake_flag="USE_CUTLASS") + # Mark a test as requiring rpc to run requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC") diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index b12da1ac62cb..db36d02896a2 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -19,28 +19,30 @@ /*! * \file src/relay/backend/contrib/cutlass/codegen.cc - * \brief Implementation of CUTLASS codegen. + * \brief The 'custom' compilation pass for CUTLASS (invoked by the RelayToTIRTargetHook pass). */ +#include #include -#include #include #include #include #include -#include #include #include +#include "../../../transforms/compiler_function_utils.h" #include "../../utils.h" #include "../codegen_c/codegen_c.h" namespace tvm { namespace relay { namespace contrib { +namespace cutlass { + +namespace { -using namespace backend; using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, @@ -507,7 +509,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, return conv2d_decl.str(); } -class CodegenCutlass : public MemoizedExprTranslator>, public CodegenCBase { +class CodegenCutlass : public backend::MemoizedExprTranslator>, + public CodegenCBase { public: CodegenCutlass(const std::string& id, const Map& attrs) { this->ext_func_id_ = id; @@ -593,6 +596,8 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, const CallNode* caller) { + using backend::GetRootCall; + const auto pattern_name = callee->GetAttr(attr::kComposite); ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; @@ -780,22 +785,22 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi std::vector buf_decl_; }; // class CodegenCutlass -class CutlassModuleCodegen : public CSourceModuleCodegenBase { +class CutlassModuleCodegen { public: - std::pair> GenCutlassFunc(const Function& func) { - ICHECK(func.defined()) << "Input error: expect a Relay function."; - // Record the external symbol for runtime lookup. 
- auto sid = GetExtSymbol(func); - const auto* attrs = func->attrs.as(); - ICHECK(attrs != nullptr); - const auto dict = attrs->dict; - CodegenCutlass builder(sid, dict); - auto out = builder.VisitExpr(func->body); - code_stream_ << builder.JIT(out); - return {sid, {}}; + explicit CutlassModuleCodegen(IRModule mod) : mod_(std::move(mod)) {} + + runtime::Module CreateCSourceModule() { + EmitPreamble(); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = GetCutlassFunctionNode(kv.second)) { + GenCutlassFunc(GetRef(function_node)); + } + } + return Finalize(); } - runtime::Module CreateCSourceModule(const ObjectRef& ref) override { + private: + void EmitPreamble() { // create header code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -825,34 +830,101 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + } - ICHECK(ref->IsInstance()); - auto res = GenCutlassFunc(Downcast(ref)); - std::string code = code_stream_.str(); - String sym = std::get<0>(res); - Array variables = std::get<1>(res); - // Create a CSource module + void GenCutlassFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + // Record the external symbol for runtime lookup. + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "CUTLASS functions must have a " << tvm::attr::kGlobalSymbol << " attribute"; + std::string sid = opt_global_symbol.value(); + if (std::find(func_names_.begin(), func_names_.end(), sid) != func_names_.end()) { + // Already emitted. + return; + } + func_names_.push_back(sid); + + const auto* attrs = function->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict); + VLOG(1) << "Creating cutlass C code for '" << sid << "' from:\n" << PrettyPrint(function); + auto out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + } + + runtime::Module Finalize() { + ICHECK(!func_names_.empty()) + << "Should only create CUTLASS CSourceModule if have at least one CUTLASS partition"; const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; - return (*pf)(code, "cu", Array{sym}, variables); + VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); + return (*pf)(code_stream_.str(), "cu", func_names_, const_vars_); } - private: - /*! \brief The code stream that will be compiled by NVCC */ + /*! + * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute + * value "cutlass". + */ + const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { + return function_node; + } + } + return nullptr; + } + + /*! \brief Module we are compiling. */ + IRModule mod_; + /*! \brief The accumulated code stream that will be compiled by NVCC */ std::ostringstream code_stream_; + /*! \brief The accumulated function names. */ + Array func_names_; + /*! \brief The accumulated constant names. */ + Array const_vars_; }; // CutlassModuleCodegen /*! - * \brief The external cutlass compiler/codegen tool. It takes a Relay - * expression/module and compile it into a runtime module. 
+ * \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python + function which does the main CUTLASS tuning, C code generation and compilation steps. */ -runtime::Module CutlassCompiler(const ObjectRef& ref) { - CutlassModuleCodegen cutlass; - return cutlass.CreateCSourceModule(ref); +transform::Pass CompileForCutlassImpl() { + auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { + VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); + const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); + ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; + Optional opt_cutlass_target = Target::Current(); + ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available"; + ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass"); + runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value()); + Array external_mods = + mod->GetAttr>("external_mods", Array()).value(); + external_mods.push_back(runtime_mod); + return WithAttr(mod, "external_mods", external_mods); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {}); +} + +runtime::Module CreateCSourceModule(const IRModule& mod) { + VLOG(1) << "Creating CUTLASS CSource module from:" << std::endl << PrettyPrint(mod); + return CutlassModuleCodegen(mod).CreateCSourceModule(); } -TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); +} // namespace + +TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule); + +transform::Pass CompileForCutlass() { + return transform::Sequential( + {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), + CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")}); +} +} // namespace cutlass } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/cutlass/codegen.h b/src/relay/backend/contrib/cutlass/codegen.h new file mode 100644 index 000000000000..e70e97a2fafa --- /dev/null +++ b/src/relay/backend/contrib/cutlass/codegen.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/cutlass/codegen.h + * \brief The 'custom' compilation pass for CUTLASS (invoked by the RelayToTIRTargetHook pass). + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ + +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace cutlass { + +/*!
+ * \brief Returns the pass which replaces all calls to "Primitive" functions with "Compiler" + attribute of "cutlass" with a call to an extern, and binds a \p runtime::StaticLibrary + to the IRModule's "external_mods" attribute containing compiled implementations of + those functions using the CUTLASS C++ template library. + */ +transform::Pass CompileForCutlass(); + +} // namespace cutlass +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 3a7384fb19cc..7b377f340a57 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -24,9 +24,12 @@ #include +#include "./codegen.h" + namespace tvm { namespace relay { namespace contrib { +namespace cutlass { /*! * \brief This external codegen target can use the CUTLASS template library included in @@ -36,8 +39,36 @@ namespace contrib { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr("RelayToTIR", CompileForCutlass()) + // An integer specifying the compute capability. For example, 75 for Turing and + // 80 or 86 for Ampere. + .add_attr_option("sm", Integer(80)) + // Whether to use the slower but very accurate (compared to tf32) 3xtf32 mode for + // fp32 inputs on tensorcore. + .add_attr_option("use_3xtf32", Bool(true)) + // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in + // parallel across split-K blocks, and a separate global reduction kernel is launched to + // accumulate partial reductions. The profiler will pick the best split-K factor from the + // given candidate list. Note that a larger split-K factor requires a larger workspace. + // Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d + // kinds, split_k_slices is ignored. + .add_attr_option>("split_k_slices", Array({1})) + // When True, profile all kernel variants with smaller alignments than the largest possible. + .add_attr_option("profile_all_alignments", Bool(false)) + // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel + // is found. + .add_attr_option("find_first_valid", Bool(false)) + // Whether to compile profiler executables for different kernels in parallel. + .add_attr_option("use_multiprocessing", Bool(false)) + // Number of threads to use during compilation, or -1 to use the number of CPUs. + .add_attr_option("threads", Integer(-1)) + // Whether to replace sigmoid with tanh. + .add_attr_option("use_fast_math", Bool(false)) + // A temporary directory where intermediate compiled artifacts will be stored. + .add_attr_option("tmp_dir", String("./tmp")); +} // namespace cutlass } // namespace contrib } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 8e5238b17399..753ee178f9d3 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -16,7 +16,6 @@ # under the License.
import logging import math -import pytest import tvm from tvm import relay from tvm.contrib.cudnn import conv_output_shape @@ -25,10 +24,12 @@ from tvm.relay.op.contrib.cutlass import partition_for_cutlass from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType from tvm.contrib.cutlass import ( - tune_cutlass_kernels, - build_cutlass_kernels, - build_cutlass_kernels_vm, + has_cutlass, + num_cutlass_partitions, + finalize_modules, + finalize_modules_vm, ) +import tvm.testing logging.basicConfig(level=logging.INFO) @@ -37,10 +38,6 @@ def has_cublas(): return tvm.get_global_func("tvm.contrib.cublas.matmul", True) != None -def has_cutlass(): - return tvm.get_global_func("relay.ext.cutlass", True) != None - - def get_ref_rt_mod(mod, params, target="cuda"): with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) @@ -258,24 +255,33 @@ def profile_and_build( sm, split_k_slices=[1], tmp_dir="./tmp", - lib_path="compile.so", use_fast_math=False, use_3xtf32=True, ): + logging.info("before partitioning:\n%s", mod) mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels( - mod, - sm, - use_3xtf32=use_3xtf32, - split_k_slices=split_k_slices, - profile_all_alignments=False, - find_first_valid=True, - use_multiprocessing=True, - tmp_dir=tmp_dir, + logging.info("after partitioning:\n%s", mod) + + num_cutlass_partition = num_cutlass_partitions(mod) + host = tvm.target.Target("llvm") + cuda = tvm.target.Target("cuda", host=host) + cutlass = tvm.target.Target( + { + "kind": "cutlass", + "sm": sm, + "use_3xtf32": use_3xtf32, + "split_k_slices": split_k_slices, + "profile_all_alignments": False, + "find_first_valid": True, + "use_multiprocessing": True, + "use_fast_math": use_fast_math, + "tmp_dir": tmp_dir, + }, + host=host, ) with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target="cuda", params=params) - lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path, use_fast_math=use_fast_math) + lib = relay.build(mod, target=[cuda, cutlass], params=params) + lib = finalize_modules(lib, "compile.so", tmp_dir) dev = tvm.device("cuda", 0) rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) return rt_mod, dev, num_cutlass_partition @@ -287,26 +293,30 @@ def profile_and_build_vm( sm, split_k_slices=[1], tmp_dir="./tmp", - lib_path="compile.so", - vmcode_path="vmcode.ro", use_fast_math=False, use_3xtf32=True, ): mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels( - mod, - sm, - split_k_slices=split_k_slices, - use_3xtf32=use_3xtf32, - profile_all_alignments=False, - find_first_valid=True, - tmp_dir=tmp_dir, + num_cutlass_partition = num_cutlass_partitions(mod) + host = tvm.target.Target("llvm") + cuda = tvm.target.Target("cuda", host=host) + cutlass = tvm.target.Target( + { + "kind": "cutlass", + "sm": sm, + "use_3xtf32": use_3xtf32, + "split_k_slices": split_k_slices, + "profile_all_alignments": False, + "find_first_valid": True, + "use_multiprocessing": True, + "use_fast_math": use_fast_math, + "tmp_dir": tmp_dir, + }, + host=host, ) with tvm.transform.PassContext(opt_level=3): - vm_exec = relay.vm.compile(mod, target="cuda", params=params) - vm_exec = build_cutlass_kernels_vm( - vm_exec, sm, tmp_dir, lib_path, vmcode_path, use_fast_math=use_fast_math - ) + vm_exec = relay.vm.compile(mod, target=[cuda, cutlass], params=params) + vm_exec = finalize_modules_vm(vm_exec, "compile.so", tmp_dir) dev = tvm.device("cuda", 0) return VirtualMachine(vm_exec, dev), 
dev, num_cutlass_partition @@ -325,8 +335,7 @@ def verify_dense( weight_dtype="float16", use_3xtf32=True, ): - if not has_cutlass(): - return + assert has_cutlass() if sm < 80 and data_dtype == "float32": return @@ -377,8 +386,7 @@ def verify_dense( def verify_batch_matmul( func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False ): - if not has_cutlass(): - return + assert has_cutlass() mod = tvm.IRModule.from_expr(func) typ = relay.transform.InferType()(mod)["main"].body.checked_type use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) @@ -415,6 +423,7 @@ def verify_batch_matmul( K = 64 +@tvm.testing.requires_cutlass def test_dense(): verify_dense(get_dense(M, N, K), M, N, K) verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) @@ -449,21 +458,25 @@ def test_dense(): ) +@tvm.testing.requires_cutlass def test_dense_bias(): verify_dense(get_dense_bias(M, N, K), M, N, K) verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) +@tvm.testing.requires_cutlass def test_dense_bias_relu(): verify_dense(get_dense_bias_relu(M, N, K), M, N, K) verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) +@tvm.testing.requires_cutlass def test_dense_bias_gelu(): verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) +@tvm.testing.requires_cutlass def test_dense_dynamic(): data_shape = (relay.Any(), K) weight_shape = (relay.Any(), K) @@ -490,6 +503,7 @@ def test_dense_dynamic(): ) +@tvm.testing.requires_cutlass def test_batch_matmul(): batch = 8 verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K) @@ -527,8 +541,7 @@ def verify_conv2d_common( ref_target="cuda", use_vm=False, ): - if not has_cutlass(): - return + assert has_cutlass() if sm < 80 and inputs[0].dtype == "float32": return @@ -666,6 +679,7 @@ def verify_conv2d_backward_weight( ) +@tvm.testing.requires_cutlass def test_conv2d(): padding = (1, 1) for IC in [3, 16]: @@ -746,6 +760,7 @@ def test_conv2d(): ) +@tvm.testing.requires_cutlass def test_conv2d_fusion(): d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) @@ -793,6 +808,7 @@ def test_conv2d_fusion(): ) +@tvm.testing.requires_cutlass def test_conv2d_residual_block(): d_shape = (16, 16, 32, 32) w_shape = (16, 16, 3, 3) @@ -813,6 +829,7 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) +@tvm.testing.requires_cutlass def test_conv2d_transpose(): OC = 8 IC = 16 @@ -852,6 +869,7 @@ def test_conv2d_transpose(): ) +@tvm.testing.requires_cutlass def test_conv2d_backward_weight(): OC = 8 IC = 16 @@ -890,6 +908,7 @@ def test_conv2d_backward_weight(): ) +@tvm.testing.requires_cutlass def test_conv2d_bwd(): IC = 16 OC = 8
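
Taken together, these changes replace the old tune_cutlass_kernels / build_cutlass_kernels flow with a Target-based one. The sketch below is not part of the diff; it is a minimal usage example for the graph executor path, mirroring profile_and_build in tests/python/contrib/test_cutlass.py, and assumes a CUTLASS-enabled TVM build, an Ampere GPU (sm=80), and an existing Relay module `mod` with parameters `params`.

```python
# Minimal sketch of the new Target-based CUTLASS flow (graph executor path).
# Assumes TVM was built with USE_CUTLASS=ON and that `mod` and `params` already
# hold a Relay module and its parameters (both are placeholders here).
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import has_cutlass, num_cutlass_partitions, finalize_modules

assert has_cutlass()  # checks that "relay.ext.cutlass.create_c_source_module" is registered

mod = partition_for_cutlass(mod, params)
print("CUTLASS partitions:", num_cutlass_partitions(mod))

host = tvm.target.Target("llvm")
cuda = tvm.target.Target("cuda", host=host)
# Tuning and compilation options now travel on the "cutlass" target itself;
# unspecified options fall back to the defaults registered in target.cc.
cutlass = tvm.target.Target(
    {
        "kind": "cutlass",
        "sm": 80,
        "use_3xtf32": True,
        "find_first_valid": True,
        "use_multiprocessing": True,
        "tmp_dir": "./tmp",
    },
    host=host,
)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=[cuda, cutlass], params=params)

# Link the static CUTLASS library produced by the RelayToTIR hook into a .so.
lib = finalize_modules(lib, "compile.so", tmp_dir="./tmp")

dev = tvm.device("cuda", 0)
rt_mod = graph_executor.GraphModule(lib["default"](dev))
```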
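For dynamic-shape workloads the tests use the VM path instead; the sketch below mirrors profile_and_build_vm and reuses `mod`, `params`, `cuda` and `cutlass` from the previous example.

```python
# Minimal sketch of the VM path, reusing mod/params and the cuda/cutlass targets
# defined above. finalize_modules_vm links the CUTLASS static library and writes
# the serialized VM bytecode (vmcode.ro) under tmp_dir as a side effect.
from tvm.contrib.cutlass import finalize_modules_vm
from tvm.runtime.vm import VirtualMachine

with tvm.transform.PassContext(opt_level=3):
    vm_exec = relay.vm.compile(mod, target=[cuda, cutlass], params=params)

vm_exec = finalize_modules_vm(vm_exec, "compile.so", "vmcode.ro", tmp_dir="./tmp")

dev = tvm.device("cuda", 0)
vm = VirtualMachine(vm_exec, dev)
```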