From d3df1d338b259d2ca41491a55bd1837b9dcd909c Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Mon, 6 Jun 2022 15:04:23 -0700 Subject: [PATCH 1/3] [BYOC] Make CUTLASS BYOC integration 'Collage friendly' (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). Currently CUTLASS has four entry points: - The usual 'partition_for_cutlass' partitioning function, using the standard pattern table and pass machinery (see cutlass/build.py). - A 'tune_cutlass_kernels' function which augments CUTLASS partition functions with the results of building and running test kernels (see cutlass/build.py). - A 'relay.ext.cutlass' external codegen function which inspects the turning results and generates a CSourceModule for each partitions (see cutlass/codegen.cc). - A 'build_cutlass_kernels_vm' function which runs 'export_library' with all the nvcc compiler options needed to build all the CSourceModules (see cutlass/bild.py). For Collage we'd like CUTLASS to have only two entry points: 'partition_for_cutlass', and 'relay.ext.cutlass' or equivalent. This makes the CUTLASS external codegen integration composable with other integrations, which in turn helps Collage avoid having to understand any external codegen APIs other than the global pattern table and the custom compilation function/pass. Collage also tends to end up requiring multiple partitions for the same backend since it is more aggressive at mixing-and-matching smaller sub-graphs between backends. Thus we'd also like to make sure all tuning, generated code and compilation overhead is shared between all such CUTLASS partitions. So, in this PR: - We add all the CUTLASS-specific tuning and compilation options as new Target attributes for the 'external codegen' "cutlass" TargetKind (cutlass/target.cc). The user now has one place to provide those settings, and we've already done the legwork to plumb the target instance. - We replace 'relay.ext.cutlass' with a 'RelayToTIR' custom pass hook 'CompileForCutlass' (see cutlass/codegen.cc). This pass obviously can see all the CUTLASS partitions in the IRModule, so we can now share tuning results between them all and can be sure to generate a single CSourceModule. The pass can also invoke the compiler to yield a StaticModule, which we've also already done the legwork to support. In this way all CUTLASS-specific steps are handled at once. - For convenience we supply 'finalize_modules' and 'finalize_modules_vm' which invoke nvcc for final linking (using export_library as usual). However, there's now nothing CUTLASS specific in those helpers other than their overriding of the 'compiler' to be nvcc. - test_cutlass.py is updated to use the new API. Though this is a breaking change for existing users of the CUTLASS integration the change is pretty minor, as shown in test_cutlass.py. --- python/tvm/contrib/cutlass/__init__.py | 2 +- python/tvm/contrib/cutlass/build.py | 385 ++++++++++++------- python/tvm/contrib/cutlass/gen_tensor_op.py | 4 +- python/tvm/relay/op/contrib/cutlass.py | 17 +- python/tvm/testing/utils.py | 3 + src/relay/backend/contrib/cutlass/codegen.cc | 138 +++++-- src/relay/backend/contrib/cutlass/codegen.h | 48 +++ src/relay/backend/contrib/cutlass/target.cc | 33 +- tests/python/contrib/test_cutlass.py | 99 +++-- 9 files changed, 503 insertions(+), 226 deletions(-) create mode 100644 src/relay/backend/contrib/cutlass/codegen.h diff --git a/python/tvm/contrib/cutlass/__init__.py b/python/tvm/contrib/cutlass/__init__.py index 69d3e9c4bd7c..4b56ac4e164a 100644 --- a/python/tvm/contrib/cutlass/__init__.py +++ b/python/tvm/contrib/cutlass/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """BYOC support for CUTLASS.""" -from .build import tune_cutlass_kernels, build_cutlass_kernels, build_cutlass_kernels_vm +from .build import has_cutlass, num_cutlass_partitions, finalize_modules, finalize_modules_vm diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index bd372572c403..d954fb813fd4 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -22,6 +22,7 @@ import tvm from tvm import runtime, relay from tvm.contrib.nvcc import get_cuda_version +from tvm._ffi.registry import register_func from .gen_gemm import CutlassGemmProfiler from .gen_conv2d import CutlassConv2DProfiler from .library import ConvKind @@ -29,6 +30,11 @@ logger = logging.getLogger("cutlass") +def has_cutlass(): + """Returns true if the CUTLASS custom codegen is available""" + return tvm.get_global_func("relay.ext.cutlass.create_c_source_module", True) is not None + + def _get_cutlass_path(): tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../../") cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass") @@ -49,6 +55,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): kwargs = {} kwargs["cc"] = "nvcc" kwargs["options"] = [ + "-c", "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), "-Xcompiler=-fPIC", @@ -77,7 +84,7 @@ def __init__(self): def visit_call(self, call): op = call.op - if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + if isinstance(op, relay.Function) and "Composite" in op.attrs: self.signature["op_type"] = op.attrs["Composite"] for i, arg in enumerate(op.params): self.signature["arg%d_shape" % i] = arg.checked_type.shape @@ -285,6 +292,10 @@ def handle_conv2d( } +def num_cutlass_partitions(mod): + return sum([(1 if "cutlass" in var.name_hint else 0) for var in mod.get_global_vars()]) + + def tune_cutlass_kernels( mod, sm, @@ -346,187 +357,271 @@ def tune_cutlass_kernels( for var in mod.get_global_vars(): fun_name = var.name_hint func = mod[fun_name] - annotator = OpAnnotator() if "cutlass" in fun_name: num_cutlass_partition += 1 - annotator.visit(func) - out_shape = annotator.signature["ret_shape"] - out_dtype = annotator.signature["ret_dtype"] - op_type = annotator.signature["op_type"] - - new_attrs = {"op_type": op_type} - new_attrs.update(annotator.signature) - new_attrs.update(func.attrs) - arg0_shape = new_attrs["arg0_shape"] - arg1_shape = new_attrs["arg1_shape"] - arg0_dtype = new_attrs["arg0_dtype"] - arg1_dtype = new_attrs["arg1_dtype"] - - if "conv2d" in op_type: - new_attrs["padding"] = annotator.op_attrs.padding - new_attrs["strides"] = annotator.op_attrs.strides - new_attrs["dilation"] = annotator.op_attrs.dilation - - if "conv2d_transpose" in op_type: - d_shape = out_shape - w_shape = arg1_shape - elif "conv2d_backward_weight" in op_type: - d_shape = arg1_shape - w_shape = out_shape - else: - d_shape = arg0_shape - w_shape = arg1_shape - - new_attrs.update( - handle_conv2d( - conv2d_profiler, - op_type, - d_shape, - w_shape, - annotator.op_attrs.padding, - annotator.op_attrs.strides, - annotator.op_attrs.dilation, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - split_k_slices, - profile_all_alignments, - find_first_valid, - use_multiprocessing, - ) - ) - elif "batch_matmul" in op_type: - new_attrs.update( - handle_batch_matmul( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - elif "dense" in op_type: - new_attrs.update( - handle_dense( - gemm_profiler, - op_type, - arg0_shape, - arg1_shape, - out_dtype, - arg0_dtype, - arg1_dtype, - use_3xtf32, - find_first_valid, - use_multiprocessing, - ) - ) - else: - raise ValueError("%s unsupported composite" % op_type) - - new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) - new_func = relay.Function( - func.params, - func.body, - ret_type=func.ret_type, - type_params=func.type_params, - attrs=new_attrs, + new_func = tune_cutlass_function( + func, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + gemm_profiler, + conv2d_profiler, ) mod.update_func(var, new_func) return mod, num_cutlass_partition -def build_cutlass_kernels( - lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False +def tune_cutlass_function( + func, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + gemm_profiler, + conv2d_profiler, ): - """Compile CUTLASS kernels in lib and return the runtime module ready to run. + """Given a function intended to be offloaded to CUTLASS, profile each workload to select which + kernels to emit. Parameters ---------- - lib : GraphExecutorFactoryModule - The output from relay.build containing compiled host code and non-cutlass kernels. + func : IRModule + The Relay Function to tune for. - sm : int - An integer specifying the compute capability. For example, 75 for Turing and - 80 or 86 for Ampere. + use_3xtf32 : bool + Wheter or not use slower but very accurate (compared to tf32) 3xtf32 mode for + fp32 inputs on tensorcore. - tmp_dir : string, optional - A temporary directory where intermediate compiled artifacts will be stored. + split_k_slices : list of int + Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in + parallel accross split-K blocks, and a seperate global reduction kernel is launched to + accumulate partial reductions. The profiler will pick the best split-k factor from the + given candidate list. Note that the larger split-K factor requires a larger workspace. + Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d + kinds, split_k_slices is ignored. + + profile_all_alignments : bool + When True, profile all kernal variants with smaller alignments than the largest possible. - lib_path : string, optional - The path to a shared library which will be generated as the result of the build process. + find_first_valid : bool + Whether or not profile all candidate kernels, or stop profiling after + the first applicable kernel is found. - threads : int, optional - The number of threads to use for compiling generated kernels. Only available for - CUDA 11.2 or later. Use all physical cores by default. + use_multiprocessing : bool + Whether or not compile profiler executables for different kernels in parallel. + + gemm_profiler : CutlassGemmProfiler + Profiler for dense operators. May cache results between tuned functions. - use_fast_math : bool, optional - Whether or not to use faster but less accurate math intrinsics. + conv2d_profiler : CutlassConv2DProfiler + Profiler for conv2d operators. May cach results between tuned functions. Returns ------- - updated_lib : runtime.Module - The updated module with compiled cutlass kernels. + annot_func : Function + The input function with attributes capturing the best CUTLASS kernel found by tuning. """ - kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math) - lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) - return runtime.load_module(lib_path) + annotator = OpAnnotator() + annotator.visit(func) + out_shape = annotator.signature["ret_shape"] + out_dtype = annotator.signature["ret_dtype"] + op_type = annotator.signature["op_type"] + + new_attrs = {"op_type": op_type} + new_attrs.update(annotator.signature) + new_attrs.update(func.attrs) + arg0_shape = new_attrs["arg0_shape"] + arg1_shape = new_attrs["arg1_shape"] + arg0_dtype = new_attrs["arg0_dtype"] + arg1_dtype = new_attrs["arg1_dtype"] + + if "conv2d" in op_type: + new_attrs["padding"] = annotator.op_attrs.padding + new_attrs["strides"] = annotator.op_attrs.strides + new_attrs["dilation"] = annotator.op_attrs.dilation + + if "conv2d_transpose" in op_type: + d_shape = out_shape + w_shape = arg1_shape + elif "conv2d_backward_weight" in op_type: + d_shape = arg1_shape + w_shape = out_shape + else: + d_shape = arg0_shape + w_shape = arg1_shape + + new_attrs.update( + handle_conv2d( + conv2d_profiler, + op_type, + d_shape, + w_shape, + annotator.op_attrs.padding, + annotator.op_attrs.strides, + annotator.op_attrs.dilation, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + split_k_slices, + profile_all_alignments, + find_first_valid, + use_multiprocessing, + ) + ) + elif "batch_matmul" in op_type: + new_attrs.update( + handle_batch_matmul( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + elif "dense" in op_type: + new_attrs.update( + handle_dense( + gemm_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + arg0_dtype, + arg1_dtype, + use_3xtf32, + find_first_valid, + use_multiprocessing, + ) + ) + else: + raise ValueError("%s unsupported composite" % op_type) + + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) + return relay.Function( + func.params, + func.body, + ret_type=func.ret_type, + type_params=func.type_params, + attrs=new_attrs, + ) -def build_cutlass_kernels_vm( - vm_exec, - sm, - tmp_dir="./tmp", - lib_path="compile.so", - vmcode_path="vmcode.ro", - threads=-1, - use_fast_math=False, -): - """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run. +@register_func("relay.ext.cutlass.compile_for_cutlass") +def compile_for_cutlass(mod, cutlass_target): + """Given an IRModule with at least one Compiler='cutlass' Relay function, return a + LibraryModule with all such functions compiled into their PackedFunc-compatible form. + - First runs CUTLASS tuning to decide on the best kernels, which itself requires the + repeated compilation and execution of CUDA code using nvcc. The results of this + is captured as annotation on each relevant function. Kernel performance is cached + overall all functions. + - Then generates a single CSourceModule containing C code implementing all the + Compiler='cutlass' Relay functions, accounting for the tuning done above. + - Then compiles that CSourceModule with the appropriate nvcc arguments to yield + a static .o library. An export_library step will be required on the final runtime + module to link that library into the overall .so library. + See CompileForCutlass in src/relay/backend/contrib/cutlass/codegen.cc for where this + helper function is used to implement the RelayToTIR pass hook for CUTLASS.""" + + # Recover options from the current 'cutlass' Target + assert cutlass_target.kind.name == "cutlass" + tuning_config = { + key: cutlass_target.attrs.get(key) + for key in [ + "sm", + "use_3xtf32", + "split_k_slices", + "profile_all_alignments", + "find_first_valid", + "use_multiprocessing", + ] + } + compile_config = { + key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"] + } + tmp_dir = cutlass_target.attrs.get("tmp_dir") + + # Tune + logger.info("Tuning for CUTLASS") + mod, _ = tune_cutlass_kernels(mod, tmp_dir=tmp_dir, **tuning_config) + + # Compile + logger.info("Creating CSource module for CUTLASS") + create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module") + c_module = create_c_source_module(mod) + function_names = c_module.get_function("get_func_names")() + compile_options = _get_cutlass_compile_options(**compile_config) + lib_path = os.path.join(tmp_dir, "cutlass.o") + logger.info("Compiling generated CUTLASS code") + c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options) + + # Recover static library + logger.info("Loading compiled CUTLASS code") + final_mod = tvm.runtime.load_static_library(lib_path, function_names) + + logger.info("Done with CUTLASS compilation") + return final_mod + + +def finalize_modules(lib, lib_path, tmp_dir): + """Returns lib with any C source, LLVM and static library modules complied and linked in ready + for use by the graph or AOT executors. This method is not specific to CUTLASS, however it does + assume nvcc will be used for final compilation and linking. It is provided here for + convenience. Parameters ---------- - vm_exec : vm.Executable - The output from relay.vm.compile containing compiled host code and non-cutlass kernels. + lib : runtime.Module + The output from relay.build. - sm : int - An integer specifying the compute capability. For example, 75 for Turing and - 80 or 86 for Ampere. + lib_path : string + Name for temporary library .so file. - tmp_dir : string, optional - A temporary directory where intermediate compiled artifacts will be stored. + tmp_dir : Working temporary directory. + + Returns + ------- + updated_lib : runtime::Module + The given lib with any final compilation and linking steps completed. + + """ + lib_path = os.path.join(tmp_dir, lib_path) + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + return runtime.load_module(lib_path) - lib_path : string, optional - The path to a shared library which will be generated as the result of the build process. - vmcode_path : string, optional - The path where the VM bytecode will be serialized to. +def finalize_modules_vm(vm_exec, lib_path, tmp_dir): + """Returns vm_exec with any C source, LLVM and static library modules compiled and linked in + ready for use by the VM executor. This method is not specific to CUTLASS, however it does + assume nvcc will be used for final compilation and linking. It is provided here for + convenience. + + Parameters + ---------- + vm_exec : vm.Executable + The output from relay.vm.compile. - threads : int, optional - The number of threads to use for compiling generated kernels. Only available for - CUDA 11.2 or later. Use all physical cores by default. + lib_path : string + Name for temporary library .so file. - use_fast_math : bool, optional - Whether or not to use faster but less accurate math intrinsics. + tmp_dir : Working temporary directory. Returns ------- - updated_vm_exec: vm.Executable - The updated exectuable with compiled cutlass kernels. + updated_vm_exec : vm.Executable + The given lib with any final compilation and linking steps completed. """ code, lib = vm_exec.save() - kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math) lib_path = os.path.join(tmp_dir, lib_path) - vmcode_path = os.path.join(tmp_dir, vmcode_path) - lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) - with open(vmcode_path, "wb") as fo: - fo.write(code) + lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") lib = tvm.runtime.load_module(lib_path) - code = bytearray(open(vmcode_path, "rb").read()) return tvm.runtime.vm.Executable.load_exec(code, lib) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index b3f40f09419c..3c7e1aba2a19 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -332,10 +332,11 @@ def _compile(self, op): opath = os.path.join(self.binary_prefix, op["name"]) if os.path.exists(opath): return - fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") + fi = tempfile.NamedTemporaryFile("w", delete=False, prefix=self.binary_prefix, suffix=".cu") fi.write(op["src"]) fi.close() cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) + logger.info("invoking compilation %s", cmd) os.system(cmd) os.unlink(fi.name) @@ -361,6 +362,7 @@ def evaluate(self, op, args): for arg in args: cmd.append(str(arg)) try: + logger.info("invoking evaluation %s", cmd) sp = subprocess.run(cmd, capture_output=True, check=True) rt = float(sp.stdout) if rt == 0.0: diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 5c906f7e69be..1a441a6f03c2 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -21,6 +21,7 @@ from tvm.ir.transform import Sequential, PassContext from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name +from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from ...dataflow_pattern import wildcard, is_op, is_constant @@ -200,8 +201,10 @@ def check_conv2d_residual(call, binary_op): return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape)) -def partition_for_cutlass(mod, params=None): - """Partition the input module into CUTLASS-supported subgraphs.""" +@register_pattern_table("cutlass") +def pattern_table(): + """Returns list of triples describing the name, dataflow pattern and predicate for all + the CUTLASS-supported operators.""" dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm) dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm) dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm) @@ -273,9 +276,11 @@ def partition_for_cutlass(mod, params=None): ) ) - cutlass_patterns = ( - residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns - ) + return residual_block_patterns + dense_patterns + conv2d_patterns + conv2d_grad_patterns + + +def partition_for_cutlass(mod, params=None): + """Partition the input module into CUTLASS-supported subgraphs.""" if params is not None: mod["main"] = bind_params_by_name(mod["main"], params) @@ -290,6 +295,8 @@ def partition_for_cutlass(mod, params=None): with PassContext(opt_level=3): mod = remove_bn_pass(mod) + cutlass_patterns = relay.op.contrib.get_pattern_table("cutlass") + seq = Sequential( [ transform.InferType(), diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 59ff93cfea5c..569ea0cca7ff 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -926,6 +926,9 @@ def _any_gpu_exists(): # Mark a test as requiring microTVM to run requires_micro = Feature("micro", "MicroTVM", cmake_flag="USE_MICRO") +# Mark a test as requiring CUTLASS to run +requires_cutlass = Feature("cutlass", "CUTLASS", cmake_flag="USE_CUTLASS") + # Mark a test as requiring rpc to run requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC") diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index b12da1ac62cb..db36d02896a2 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -19,28 +19,30 @@ /*! * \file src/relay/backend/contrib/cutlass/codegen.cc - * \brief Implementation of CUTLASS codegen. + * \brief The 'custom' compilation pass for CUTLASS (invoked by the RelayToTIRTargetHook pass). */ +#include #include -#include #include #include #include #include -#include #include #include +#include "../../../transforms/compiler_function_utils.h" #include "../../utils.h" #include "../codegen_c/codegen_c.h" namespace tvm { namespace relay { namespace contrib { +namespace cutlass { + +namespace { -using namespace backend; using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, @@ -507,7 +509,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, return conv2d_decl.str(); } -class CodegenCutlass : public MemoizedExprTranslator>, public CodegenCBase { +class CodegenCutlass : public backend::MemoizedExprTranslator>, + public CodegenCBase { public: CodegenCutlass(const std::string& id, const Map& attrs) { this->ext_func_id_ = id; @@ -593,6 +596,8 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, const CallNode* caller) { + using backend::GetRootCall; + const auto pattern_name = callee->GetAttr(attr::kComposite); ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; @@ -780,22 +785,22 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi std::vector buf_decl_; }; // class CodegenCutlass -class CutlassModuleCodegen : public CSourceModuleCodegenBase { +class CutlassModuleCodegen { public: - std::pair> GenCutlassFunc(const Function& func) { - ICHECK(func.defined()) << "Input error: expect a Relay function."; - // Record the external symbol for runtime lookup. - auto sid = GetExtSymbol(func); - const auto* attrs = func->attrs.as(); - ICHECK(attrs != nullptr); - const auto dict = attrs->dict; - CodegenCutlass builder(sid, dict); - auto out = builder.VisitExpr(func->body); - code_stream_ << builder.JIT(out); - return {sid, {}}; + explicit CutlassModuleCodegen(IRModule mod) : mod_(std::move(mod)) {} + + runtime::Module CreateCSourceModule() { + EmitPreamble(); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = GetCutlassFunctionNode(kv.second)) { + GenCutlassFunc(GetRef(function_node)); + } + } + return Finalize(); } - runtime::Module CreateCSourceModule(const ObjectRef& ref) override { + private: + void EmitPreamble() { // create header code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -825,34 +830,101 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + } - ICHECK(ref->IsInstance()); - auto res = GenCutlassFunc(Downcast(ref)); - std::string code = code_stream_.str(); - String sym = std::get<0>(res); - Array variables = std::get<1>(res); - // Create a CSource module + void GenCutlassFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + + // Record the external symbol for runtime lookup. + Optional opt_global_symbol = function->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_global_symbol.defined()) + << "CUTLASS functions must have a " << tvm::attr::kGlobalSymbol << " attribute"; + std::string sid = opt_global_symbol.value(); + if (std::find(func_names_.begin(), func_names_.end(), sid) != func_names_.end()) { + // Already emitted. + return; + } + func_names_.push_back(sid); + + const auto* attrs = function->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict); + VLOG(1) << "Creating cutlass C code for '" << sid << "' from:\n" << PrettyPrint(function); + auto out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + } + + runtime::Module Finalize() { + ICHECK(!func_names_.empty()) + << "Should only create CUTLASS CSourceModule if have at least one CUTLASS partition"; const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; - return (*pf)(code, "cu", Array{sym}, variables); + VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); + return (*pf)(code_stream_.str(), "cu", func_names_, const_vars_); } - private: - /*! \brief The code stream that will be compiled by NVCC */ + /*! + * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute + * value "cutlass". + */ + const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { + return function_node; + } + } + return nullptr; + } + + /*! \brief Module we are compiling. */ + IRModule mod_; + /*! \brief The accumulated code stream that will be compiled by NVCC */ std::ostringstream code_stream_; + /*! \brief The accumulated function names. */ + Array func_names_; + /*! \brief The accumulated constant names. */ + Array const_vars_; }; // CutlassModuleCodegen /*! - * \brief The external cutlass compiler/codegen tool. It takes a Relay - * expression/module and compile it into a runtime module. + * \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python + * function which does the main CUTLASS training, c-code generation and compilation steps. */ -runtime::Module CutlassCompiler(const ObjectRef& ref) { - CutlassModuleCodegen cutlass; - return cutlass.CreateCSourceModule(ref); +transform::Pass CompileForCutlassImpl() { + auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { + VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); + const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); + ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; + Optional opt_cutlass_target = Target::Current(); + ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available"; + ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass"); + runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value()); + Array external_mods = + mod->GetAttr>("external_mods", Array()).value(); + external_mods.push_back(runtime_mod); + return WithAttr(mod, "external_mods", external_mods); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {}); +} + +runtime::Module CreateCSourceModule(const IRModule& mod) { + VLOG(1) << "Creating CUTLASS CSource module from:" << std::endl << PrettyPrint(mod); + return CutlassModuleCodegen(mod).CreateCSourceModule(); } -TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); +} // namespace + +TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule); + +transform::Pass CompileForCutlass() { + return transform::Sequential( + {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"), + CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")}); +} +} // namespace cutlass } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/cutlass/codegen.h b/src/relay/backend/contrib/cutlass/codegen.h new file mode 100644 index 000000000000..e70e97a2fafa --- /dev/null +++ b/src/relay/backend/contrib/cutlass/codegen.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/cutlass/codegen.h + * \brief The 'custom' compilation pass for CUTLASS (invoked by the RelayToTIRTargetHook pass). + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ + +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace cutlass { + +/*! + * \brief Returns the pass which replaces all calls to "Primitive" functions with "Compiler" + * attribute of "cutlass" with an call to an extern, and binds a \p runtime::StaticLibrary + * to the IRModule's "external_mods" attribute containing compiled implementations of + * those functions using the CUTLASS C++ template library. + */ +transform::Pass CompileForCutlass(); + +} // namespace cutlass +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_CUTLASS_CODEGEN_H_ diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 3a7384fb19cc..7b377f340a57 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -24,9 +24,12 @@ #include +#include "./codegen.h" + namespace tvm { namespace relay { namespace contrib { +namespace cutlass { /*! * \brief This external codegen target can use the CUTLASS template library included in @@ -36,8 +39,36 @@ namespace contrib { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr("RelayToTIR", CompileForCutlass()) + // An integer specifying the compute capability. For example, 75 for Turing and + // 80 or 86 for Ampere. + .add_attr_option("sm", Integer(80)) + // Whether to use slower but very accurate (compared to tf32) 3xtf32 mode for + // fp32 inputs on tensorcore. + .add_attr_option("use_3xtf32", Bool(true)) + // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in + // parallel across split-K blocks, and a separate global reduction kernel is launched to + // accumulate partial reductions. The profiler will pick the best split-k factor from the + // given candidate list. Note that the larger split-K factor requires a larger workspace. + // Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d + // kinds, split_k_slices is ignored. + .add_attr_option>("split_k_slices", Array({1})) + // When True, profile all kernel variants with smaller alignments than the largest possible. + .add_attr_option("profile_all_alignments", Bool(false)) + // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel + // is found. + .add_attr_option("find_first_valid", Bool(false)) + // Whether to compile profiler executables for different kernels in parallel. + .add_attr_option("use_multiprocessing", Bool(false)) + // Number of threads to use during compilation, or -1 to use number of cpus. + .add_attr_option("threads", Integer(-1)) + // Whether to replace sigmoid with tanh. + .add_attr_option("use_fast_math", Bool(false)) + // A temporary directory where intermediate compiled artifacts will be stored. + .add_attr_option("tmp_dir", String("./tmp")); +} // namespace cutlass } // namespace contrib } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 8e5238b17399..753ee178f9d3 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -16,7 +16,6 @@ # under the License. import logging import math -import pytest import tvm from tvm import relay from tvm.contrib.cudnn import conv_output_shape @@ -25,10 +24,12 @@ from tvm.relay.op.contrib.cutlass import partition_for_cutlass from tvm.relay.transform import FirstOrderGradient, ToMixedPrecision, InferType from tvm.contrib.cutlass import ( - tune_cutlass_kernels, - build_cutlass_kernels, - build_cutlass_kernels_vm, + has_cutlass, + num_cutlass_partitions, + finalize_modules, + finalize_modules_vm, ) +import tvm.testing logging.basicConfig(level=logging.INFO) @@ -37,10 +38,6 @@ def has_cublas(): return tvm.get_global_func("tvm.contrib.cublas.matmul", True) != None -def has_cutlass(): - return tvm.get_global_func("relay.ext.cutlass", True) != None - - def get_ref_rt_mod(mod, params, target="cuda"): with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) @@ -258,24 +255,33 @@ def profile_and_build( sm, split_k_slices=[1], tmp_dir="./tmp", - lib_path="compile.so", use_fast_math=False, use_3xtf32=True, ): + logging.info("before partitioning:\n%s", mod) mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels( - mod, - sm, - use_3xtf32=use_3xtf32, - split_k_slices=split_k_slices, - profile_all_alignments=False, - find_first_valid=True, - use_multiprocessing=True, - tmp_dir=tmp_dir, + logging.info("after partitioning:\n%s", mod) + + num_cutlass_partition = num_cutlass_partitions(mod) + host = tvm.target.Target("llvm") + cuda = tvm.target.Target("cuda", host=host) + cutlass = tvm.target.Target( + { + "kind": "cutlass", + "sm": sm, + "use_3xtf32": use_3xtf32, + "split_k_slices": split_k_slices, + "profile_all_alignments": False, + "find_first_valid": True, + "use_multiprocessing": True, + "use_fast_math": use_fast_math, + "tmp_dir": tmp_dir, + }, + host=host, ) with tvm.transform.PassContext(opt_level=3): - lib = relay.build(mod, target="cuda", params=params) - lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path, use_fast_math=use_fast_math) + lib = relay.build(mod, target=[cuda, cutlass], params=params) + lib = finalize_modules(lib, "compile.so", tmp_dir) dev = tvm.device("cuda", 0) rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) return rt_mod, dev, num_cutlass_partition @@ -287,26 +293,30 @@ def profile_and_build_vm( sm, split_k_slices=[1], tmp_dir="./tmp", - lib_path="compile.so", - vmcode_path="vmcode.ro", use_fast_math=False, use_3xtf32=True, ): mod = partition_for_cutlass(mod) - mod, num_cutlass_partition = tune_cutlass_kernels( - mod, - sm, - split_k_slices=split_k_slices, - use_3xtf32=use_3xtf32, - profile_all_alignments=False, - find_first_valid=True, - tmp_dir=tmp_dir, + num_cutlass_partition = num_cutlass_partitions(mod) + host = tvm.target.Target("llvm") + cuda = tvm.target.Target("cuda", host=host) + cutlass = tvm.target.Target( + { + "kind": "cutlass", + "sm": sm, + "use_3xtf32": use_3xtf32, + "split_k_slices": split_k_slices, + "profile_all_alignments": False, + "find_first_valid": True, + "use_multiprocessing": True, + "use_fast_math": use_fast_math, + "tmp_dir": tmp_dir, + }, + host=host, ) with tvm.transform.PassContext(opt_level=3): - vm_exec = relay.vm.compile(mod, target="cuda", params=params) - vm_exec = build_cutlass_kernels_vm( - vm_exec, sm, tmp_dir, lib_path, vmcode_path, use_fast_math=use_fast_math - ) + vm_exec = relay.vm.compile(mod, target=[cuda, cutlass], params=params) + vm_exec = finalize_modules_vm(vm_exec, "compile.so", tmp_dir) dev = tvm.device("cuda", 0) return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition @@ -325,8 +335,7 @@ def verify_dense( weight_dtype="float16", use_3xtf32=True, ): - if not has_cutlass(): - return + assert has_cutlass() if sm < 80 and data_dtype == "float32": return @@ -377,8 +386,7 @@ def verify_dense( def verify_batch_matmul( func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False ): - if not has_cutlass(): - return + assert has_cutlass() mod = tvm.IRModule.from_expr(func) typ = relay.transform.InferType()(mod)["main"].body.checked_type use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) @@ -415,6 +423,7 @@ def verify_batch_matmul( K = 64 +@tvm.testing.requires_cutlass def test_dense(): verify_dense(get_dense(M, N, K), M, N, K) verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) @@ -449,21 +458,25 @@ def test_dense(): ) +@tvm.testing.requires_cutlass def test_dense_bias(): verify_dense(get_dense_bias(M, N, K), M, N, K) verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) +@tvm.testing.requires_cutlass def test_dense_bias_relu(): verify_dense(get_dense_bias_relu(M, N, K), M, N, K) verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) +@tvm.testing.requires_cutlass def test_dense_bias_gelu(): verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) +@tvm.testing.requires_cutlass def test_dense_dynamic(): data_shape = (relay.Any(), K) weight_shape = (relay.Any(), K) @@ -490,6 +503,7 @@ def test_dense_dynamic(): ) +@tvm.testing.requires_cutlass def test_batch_matmul(): batch = 8 verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K) @@ -527,8 +541,7 @@ def verify_conv2d_common( ref_target="cuda", use_vm=False, ): - if not has_cutlass(): - return + assert has_cutlass() if sm < 80 and inputs[0].dtype == "float32": return @@ -666,6 +679,7 @@ def verify_conv2d_backward_weight( ) +@tvm.testing.requires_cutlass def test_conv2d(): padding = (1, 1) for IC in [3, 16]: @@ -746,6 +760,7 @@ def test_conv2d(): ) +@tvm.testing.requires_cutlass def test_conv2d_fusion(): d_shape = (16, 16, 32, 32) w_shape = (32, 16, 3, 3) @@ -793,6 +808,7 @@ def test_conv2d_fusion(): ) +@tvm.testing.requires_cutlass def test_conv2d_residual_block(): d_shape = (16, 16, 32, 32) w_shape = (16, 16, 3, 3) @@ -813,6 +829,7 @@ def test_conv2d_residual_block(): verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False) +@tvm.testing.requires_cutlass def test_conv2d_transpose(): OC = 8 IC = 16 @@ -852,6 +869,7 @@ def test_conv2d_transpose(): ) +@tvm.testing.requires_cutlass def test_conv2d_backward_weight(): OC = 8 IC = 16 @@ -890,6 +908,7 @@ def test_conv2d_backward_weight(): ) +@tvm.testing.requires_cutlass def test_conv2d_bwd(): IC = 16 OC = 8 From 9c16e05020512493d86c249965d08312ab51704c Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 10 Jun 2022 14:50:27 -0700 Subject: [PATCH 2/3] - Masa's comments --- python/tvm/contrib/cutlass/build.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index d954fb813fd4..b266eb7a7041 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -572,7 +572,7 @@ def compile_for_cutlass(mod, cutlass_target): return final_mod -def finalize_modules(lib, lib_path, tmp_dir): +def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"): """Returns lib with any C source, LLVM and static library modules complied and linked in ready for use by the graph or AOT executors. This method is not specific to CUTLASS, however it does assume nvcc will be used for final compilation and linking. It is provided here for @@ -584,14 +584,15 @@ def finalize_modules(lib, lib_path, tmp_dir): The output from relay.build. lib_path : string - Name for temporary library .so file. + The path to a shared library which will be generated as the result of the build process. - tmp_dir : Working temporary directory. + tmp_dir : string + A temporary directory where intermediate compiled artifacts will be stored. Returns ------- - updated_lib : runtime::Module - The given lib with any final compilation and linking steps completed. + updated_lib : runtime.Module + The updated library with all compilation and linking completed. """ lib_path = os.path.join(tmp_dir, lib_path) @@ -599,7 +600,7 @@ def finalize_modules(lib, lib_path, tmp_dir): return runtime.load_module(lib_path) -def finalize_modules_vm(vm_exec, lib_path, tmp_dir): +def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", tmp_dir="./tmp"): """Returns vm_exec with any C source, LLVM and static library modules compiled and linked in ready for use by the VM executor. This method is not specific to CUTLASS, however it does assume nvcc will be used for final compilation and linking. It is provided here for @@ -608,20 +609,28 @@ def finalize_modules_vm(vm_exec, lib_path, tmp_dir): Parameters ---------- vm_exec : vm.Executable - The output from relay.vm.compile. + The output from relay.vm.compile containing compiled host code and kernels. lib_path : string - Name for temporary library .so file. + The path to a shared library which will be generated as the result of the build process. + + vmcode_path : string + The path where the VM bytecode will be serialized to. - tmp_dir : Working temporary directory. + tmp_dir : string + A temporary directory where intermediate compiled artifacts will be stored. Returns ------- updated_vm_exec : vm.Executable - The given lib with any final compilation and linking steps completed. + The updated VM executable with all compilation and linking completed. """ code, lib = vm_exec.save() lib_path = os.path.join(tmp_dir, lib_path) + vmcode_path = os.path.join(tmp_dir, vmcode_path) lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") + with open(vmcode_path, "wb") as fo: + fo.write(code) lib = tvm.runtime.load_module(lib_path) + code = bytearray(open(vmcode_path, "rb").read()) return tvm.runtime.vm.Executable.load_exec(code, lib) From d6efea03c67d66ce9c82b49bc64e9ff3436d9f53 Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 10 Jun 2022 17:09:42 -0700 Subject: [PATCH 3/3] - Remove unnecessary save. --- python/tvm/contrib/cutlass/build.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index b266eb7a7041..0c8c2ad0b2b9 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -615,7 +615,7 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", The path to a shared library which will be generated as the result of the build process. vmcode_path : string - The path where the VM bytecode will be serialized to. + The path where the VM bytecode will be serialized to as a side-effect. tmp_dir : string A temporary directory where intermediate compiled artifacts will be stored. @@ -632,5 +632,4 @@ def finalize_modules_vm(vm_exec, lib_path="compile.so", vmcode_path="vmcode.ro", with open(vmcode_path, "wb") as fo: fo.write(code) lib = tvm.runtime.load_module(lib_path) - code = bytearray(open(vmcode_path, "rb").read()) return tvm.runtime.vm.Executable.load_exec(code, lib)