diff --git a/python/tvm/contrib/msc/core/codegen/__init__.py b/python/tvm/contrib/msc/core/codegen/__init__.py
index 1df257ed6d6b..78da1b3fdd69 100644
--- a/python/tvm/contrib/msc/core/codegen/__init__.py
+++ b/python/tvm/contrib/msc/core/codegen/__init__.py
@@ -17,3 +17,4 @@
 """tvm.contrib.msc.core.codegen"""
 
 from .codegen import *
+from .sources import *
diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py
index 76b3fdce54e7..9245e13f42f0 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -16,6 +16,8 @@
 # under the License.
 """tvm.contrib.msc.core.codegen.codegen"""
 
+import os
+import subprocess
 from typing import Dict, List, Optional, Any, Callable
 
 import tvm
@@ -40,6 +42,8 @@ class CodeGen(object):
         The config to print code.
     build_folder: MSCDirectory
         The codegen folder.
+    code_format: str
+        The code format, cpp or python.
     """
 
     def __init__(
@@ -49,17 +53,21 @@ def __init__(
         codegen_config: Optional[Dict[str, str]] = None,
         print_config: Optional[Dict[str, str]] = None,
         build_folder: msc_utils.MSCDirectory = None,
+        code_format: str = "python",
     ):
         self._graph = graph
         self._source_getter = source_getter
         self._codegen_config = msc_utils.dump_dict(codegen_config)
         self._print_config = msc_utils.dump_dict(print_config)
         self._build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True)
+        self._code_format = code_format
 
     def load(
         self,
         inputs: Optional[List[Any]] = None,
-        weights_binder: Optional[Callable[[MSCGraph, Any, msc_utils.MSCDirectory], Any]] = None,
+        pre_load: Optional[Callable[[msc_utils.MSCDirectory], Any]] = None,
+        post_load: Optional[Callable[[Any, msc_utils.MSCDirectory], Any]] = None,
+        build_model: bool = True,
     ) -> Any:
         """Generate source and load the model
 
         Parameters
         -------
         inputs: list
             The inputs to build the model.
-        weights_binder: Callable
-            The method for binding weights to the model.
+        pre_load: Callable
+            The pre-processing method, called before the model is loaded.
+        post_load: Callable
+            The post-processing method, called after the model is loaded.
+        build_model: bool
+            Whether to build the model.
 
         Returns
         -------
@@ -79,13 +91,33 @@ def load(
         sources = self._source_getter(self._graph, self._codegen_config, self._print_config)
         inputs = inputs or []
         with self._build_folder as folder:
+            # pre processing
+            if pre_load:
+                pre_load(folder)
             for name, source in sources.items():
                 folder.add_file(name, source)
-            builder = msc_utils.load_callable(self._graph.name + ".py:" + self._graph.name)
-            obj = builder(*inputs)
-            # load weights
-            if weights_binder:
-                obj = weights_binder(obj, folder)
+            if build_model:
+                if self._code_format == "cpp":
+                    with folder.create_dir("build"):
+                        command = "cmake ../ && make && mv {} ../".format(self._graph.name)
+                        process = subprocess.Popen(command, shell=True)
+                        process.wait()
+                        assert process.returncode == 0, "Failed to build {} under {}".format(
+                            self._graph.name, os.getcwd()
+                        )
+                    obj = self._graph.name
+                elif self._code_format == "python":
+                    builder = msc_utils.load_callable(self._graph.name + ".py:" + self._graph.name)
+                    obj = builder(*inputs)
+                else:
+                    raise NotImplementedError(
+                        "Code format {} is not supported".format(self._code_format)
+                    )
+                # post processing
+                if post_load:
+                    obj = post_load(obj, folder)
+            else:
+                obj = None
         return obj
 
@@ -125,8 +157,7 @@ def relay_to_relax(
         opt_config=opt_config,
     )
     source_getter = tvm.get_global_func("msc.framework.tvm.GetRelaxSources")
-    codegen_config = {"from_relay": True}
-    codegen = CodeGen(graph, source_getter, codegen_config)
+    codegen = CodeGen(graph, source_getter, codegen_config={"from_relay": True})
     inputs = [
         tvm.relax.Var(i.alias, tvm.relax.TensorStructInfo(i.get_shape(), i.dtype_name))
         for i in graph.get_inputs()
@@ -136,4 +167,4 @@ def relay_to_relax(
     def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
         return BindParams("main", weights)(mod)
 
-    return codegen.load(inputs, _bind_weights)
+    return codegen.load(inputs, post_load=_bind_weights)
diff --git a/python/tvm/contrib/msc/core/codegen/sources.py b/python/tvm/contrib/msc/core/codegen/sources.py
new file mode 100644
index 000000000000..948d76087cde
--- /dev/null
+++ b/python/tvm/contrib/msc/core/codegen/sources.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""tvm.contrib.msc.core.codegen.sources"""
+
+from typing import Dict
+
+
+def get_base_h_code() -> str:
+    """Create base header file code
+
+    Returns
+    -------
+    source: str
+        The base header source.
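+
+    Examples
+    --------
+    .. code-block:: python
+
+        # A minimal sketch (hypothetical usage, not part of this patch):
+        # write the header out so generated cpp sources can include "base.h".
+        with open("base.h", "w") as f:
+            f.write(get_base_h_code())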
+ """ + + return """#ifndef TVM_CONTRIB_MSC_UTILS_BASE_H_ +#define TVM_CONTRIB_MSC_UTILS_BASE_H_ + +#include +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +class FileUtils { + public: + static inline bool FileExist(const std::string& file); + + template + static bool ReadToBuffer(const std::string& file, T* buffer, size_t size); +}; + +class DatasetReader { + public: + DatasetReader(const std::string& folder, int max_size = -1); + + void Reset(); + + bool ReadNext(void* buffers[], int num_datas = -1); + + private: + std::string folder_; + size_t max_size_; + size_t cur_cnt_; + std::vector> tensor_info_; +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_UTILS_BASE_H_ +""" + + +def get_base_cc_code() -> str: + """Create base cc file codes + + Returns + ------- + source: str + The base cc source. + """ + + return """#include "base.h" + +#include +#include + +namespace tvm { +namespace contrib { +namespace msc { + +bool FileUtils::FileExist(const std::string& file) { + std::ifstream in_file(file, std::ifstream::binary); + if (in_file.is_open()) { + in_file.close(); + return true; + } + return false; +} + +template +bool FileUtils::ReadToBuffer(const std::string& file, T* buffer, size_t size) { + std::ifstream in_file(file, std::ifstream::binary); + if (!in_file.is_open()) { + return false; + } + try { + in_file.read((char*)(&buffer[0]), size * sizeof(T)); + } catch (std::exception const& e) { + in_file.close(); + return false; + } + in_file.close(); + return true; +} + +DatasetReader::DatasetReader(const std::string& folder, int max_size) { + folder_ = folder; + const std::string info_file = folder_ + "/tensor_info"; + std::ifstream input(info_file, std::ios::binary); + assert(input.is_open() && ("Failed to open file " + info_file).c_str()); + std::string line; + while (getline(input, line)) { + int pos = line.find(":"); + assert(pos > 0 && ("Can not find : in line " + line).c_str()); + const auto& name = line.substr(0, pos); + const auto& byte_size = line.substr(pos + 1, line.size()); + tensor_info_.push_back(std::make_pair(name, static_cast(std::stoi(byte_size)))); + } + size_t file_cnt = 0; + while (true) { + bool all_exists = true; + for (const auto& pair : tensor_info_) { + const auto& d_file = + folder_ + "/" + pair.first + "/batch_" + std::to_string(file_cnt) + ".bin"; + if (!FileUtils::FileExist(d_file)) { + all_exists = false; + break; + } + } + if (!all_exists) { + break; + } + file_cnt++; + } + max_size_ = max_size > 0 ? static_cast(max_size) : file_cnt; + max_size_ = std::min(max_size_, file_cnt); + Reset(); +} + +void DatasetReader::Reset() { cur_cnt_ = 0; } + +bool DatasetReader::ReadNext(void* buffers[], int num_datas) { + if (cur_cnt_ >= max_size_) { + return false; + } + size_t max_num = num_datas > 0 ? static_cast(num_datas) : tensor_info_.size(); + max_num = std::min(max_num, tensor_info_.size()); + for (size_t i = 0; i < max_num; i++) { + const auto& pair = tensor_info_[i]; + const auto& d_file = folder_ + "/" + pair.first + "/batch_" + std::to_string(cur_cnt_) + ".bin"; + if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], pair.second)) { + return false; + } + } + cur_cnt_++; + return true; +} + +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + +def get_base_sources() -> Dict[str, str]: + """Create base sources for cpp codegen + + Returns + ------- + sources: dict + The base utils sources. 
+ """ + + return {"base.h": get_base_h_code(), "base.cc": get_base_cc_code()} diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index cafff112ce3c..154703d332ce 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -210,6 +210,22 @@ def output_at(self, idx: int) -> MSCTensor: return _ffi_api.MSCJointOutputAt(self, idx) + def weight_at(self, wtype: str) -> MSCTensor: + """Get weight from reference. + + Parameters + ---------- + wtype: str + The type of weight. + + Returns + ------- + weight: MSCTensor + The weight Tensor. + """ + + return _ffi_api.MSCJointWeightAt(self, wtype) + def get_inputs(self) -> List[MSCTensor]: """Get all the inputs. @@ -242,7 +258,7 @@ def get_weights(self) -> Dict[str, MSCTensor]: """ src_weights = _ffi_api.MSCJointGetWeights(self) - return {ref: src_weights[ref] for ref in src_weights} + return {wtype: src_weights[wtype] for wtype in src_weights} def get_attrs(self) -> Dict[str, str]: """Get all the attributes from node @@ -406,6 +422,22 @@ def __init__( output_names, ) + def has_node(self, name: str) -> bool: + """Check if node in the graph. + + Parameters + ---------- + name: string + The name of the node. + + Returns + ------- + has_node: bool + Whether the node is in the graph + """ + + return bool(_ffi_api.MSCGraphHasNode(self, name)) + def find_node(self, name: str) -> MSCJoint: """Find node by name. @@ -422,6 +454,22 @@ def find_node(self, name: str) -> MSCJoint: return _ffi_api.MSCGraphFindNode(self, name) + def has_tensor(self, name: str) -> bool: + """Check if tensor in the graph. + + Parameters + ---------- + name: string + The name of the tensor. + + Returns + ------- + has_tensor: bool + Whether the tensor is in the graph + """ + + return bool(_ffi_api.MSCGraphHasTensor(self, name)) + def find_tensor(self, name: str) -> MSCTensor: """Find tensor by name. diff --git a/python/tvm/contrib/msc/core/ir/translate.py b/python/tvm/contrib/msc/core/ir/translate.py index 8b6b48ebd465..082859519b35 100644 --- a/python/tvm/contrib/msc/core/ir/translate.py +++ b/python/tvm/contrib/msc/core/ir/translate.py @@ -16,12 +16,15 @@ # under the License. 
"""tvm.contrib.msc.core.ir.translate""" -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, List import tvm from tvm.relax.transform import BindParams +from tvm.relax import PyExprVisitor from tvm.relax.backend.pattern_registry import get_patterns_with_prefix +from tvm.relay.expr_functor import ExprVisitor from tvm.relay.build_module import bind_params_by_name +from tvm.relay import dataflow_pattern as relay_pattern from tvm.contrib.msc.core import transform as msc_transform from tvm.contrib.msc.core import _ffi_api from tvm.contrib.msc.core import utils as msc_utils @@ -56,13 +59,13 @@ def _to_data(ref_t, data): ref_layout, weight_layout = ref_t.layout.name, weight_t.layout.name if ref_layout != weight_layout: assert all( - l.name in ref_layout for l in weight_layout + l in ref_layout for l in weight_layout ), "layout mismatch {} compare to {}".format(ref_t, weight_t) permute = [ref_layout.index(l) for l in weight_layout] return tvm.nd.array(data.asnumpy().transpose(*permute)) return data - weights = {t.name: _to_data(t, d) for t, d in t_weights.items()} + weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)} return weights @@ -71,6 +74,7 @@ def from_relax( params: Optional[Dict[str, tvm.nd.array]] = None, trans_config: Optional[Dict[str, str]] = None, build_config: Optional[Dict[str, str]] = None, + opt_config: Optional[Dict[str, str]] = None, ) -> Tuple[MSCGraph, Dict[str, tvm.nd.array]]: """Change IRModule to MSCGraph. @@ -84,6 +88,8 @@ def from_relax( The config for transfrorm IRModule. build_config: dict The config for build MSCGraph. + opt_config: dict + The config for optimize the relax before translate. Returns ------- @@ -95,23 +101,85 @@ def from_relax( trans_config = trans_config or {} build_config = build_config or {} - # TODO(tong.meng): optimize before translate? + opt_config = opt_config or {} + entry = trans_config.get("entry", "main") if params: mod = BindParams("main", params)(mod) + opt_level = opt_config.get("opt_level", 1) + if opt_level > 0: + mod = tvm.transform.Sequential( + [ + tvm.relax.transform.FoldConstant(), + ] + )(mod) patterns = get_patterns_with_prefix("msc") passes = [ tvm.relax.transform.FuseOpsByPattern( patterns, bind_constants=False, annotate_codegen=False ), - msc_transform.SetExprName(), - msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), + msc_transform.SetExprName(entry_name=entry, target=trans_config.get("target", "")), + msc_transform.SetExprLayout( + trans_config.get("allow_layout_missing", True), entry_name=entry + ), ] mod = tvm.transform.Sequential(passes)(mod) - graph = _ffi_api.BuildFromRelax(mod, "main", msc_utils.dump_dict(build_config)) - t_weights = _ffi_api.GetRelaxWeights(mod, "main") + graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config)) + t_weights = _ffi_api.GetRelaxWeights(mod, entry) return graph, normalize_weights(t_weights, graph) +def get_relay_patterns( + mod: tvm.IRModule, + entry_name: str = "main", +) -> List[Tuple[str, relay_pattern.DFPattern, callable]]: + """Filter relay patterns based on mod. + + Parameters + ---------- + mod: IRModule + The IRModule of relay. + entry_name: str + The entry name. 
+    """
+
+    class OpExtractor(ExprVisitor):
+        """Extract ops from expr."""
+
+        def extract(self, expr):
+            self._optypes = set()
+            super().visit(expr)
+            return self._optypes
+
+        def visit_call(self, expr):
+            super().visit_call(expr)
+            if isinstance(expr.op, tvm.ir.Op):
+                self._optypes.add(expr.op.name)
+
+    op_names = OpExtractor().extract(mod[entry_name])
+    skip_tags, patterns = set(), list(tvm.relay.op.contrib.get_pattern_table("msc"))
+    if "nn.conv1d" not in op_names or "add" not in op_names:
+        skip_tags.add("msc.conv1d_bias")
+    if "nn.conv2d" not in op_names or "add" not in op_names:
+        skip_tags.add("msc.conv2d_bias")
+    if "nn.batch_matmul" not in op_names or "add" not in op_names:
+        skip_tags.add("msc.linear_bias")
+    if "nn.batch_matmul" not in op_names:
+        skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.linear"))
+    if "nn.dense" not in op_names:
+        skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.matmul"))
+    if "take" not in op_names:
+        skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.embedding"))
+    if "erf" not in op_names:
+        skip_tags |= set(p[0] for p in patterns if p[0].startswith("msc.gelu"))
+    valid_patterns = [p for p in patterns if p[0] not in skip_tags]
+    return valid_patterns
+
+
 def from_relay(
     mod: tvm.IRModule,
     params: Optional[Dict[str, tvm.nd.array]] = None,
@@ -147,10 +215,9 @@ def from_relay(
     opt_config = opt_config or {}
     # TODO(tong.meng): optimize before translate?
     opt_level = opt_config.get("opt_level", 0)
-    if opt_level == 0:
-        if params:
-            mod["main"] = bind_params_by_name(mod["main"], params)
-    else:
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+    if opt_level > 0:
         target = opt_config.get("target", "llvm")
         disabled_pass = opt_config.get("disabled_pass", []) + [
             "SimplifyInference",
@@ -160,7 +227,7 @@ def from_relay(
         ]
         with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
             mod, params = tvm.relay.optimize(mod, target=target, params=params)
-    patterns = tvm.relay.op.contrib.get_pattern_table("msc")
+    patterns = get_relay_patterns(mod)
     passes = [
         tvm.relay.transform.InferType(),
         tvm.relay.transform.MergeComposite(patterns),
@@ -170,3 +237,99 @@ def from_relay(
     graph = _ffi_api.BuildFromRelay(mod, "main", msc_utils.dump_dict(build_config))
     t_weights = _ffi_api.GetRelayWeights(mod, "main")
     return graph, normalize_weights(t_weights, graph)
+
+
+@tvm.relax.expr_functor.visitor
+class BYOCChecker(PyExprVisitor):
+    """Checker to check if any non-target ops exist"""
+
+    def check(self, func_names, expr):
+        self._func_names = func_names
+        self._non_target_exprs = []
+        if isinstance(expr, tvm.relax.Expr):
+            self.visit_expr(expr)
+        elif isinstance(expr, tvm.relax.BindingBlock):
+            self.visit_binding_block(expr)
+        assert len(self._non_target_exprs) == 0, "Exprs not on target {}".format(
+            self._non_target_exprs
+        )
+
+    def visit_var_binding_(self, binding) -> None:
+        super().visit_var_binding_(binding)
+        if isinstance(binding.value, tvm.relax.Call):
+            if isinstance(binding.value.op, tvm.relax.GlobalVar):
+                if binding.value.op.name_hint not in self._func_names:
+                    self._non_target_exprs.append(binding.value)
+            else:
+                self._non_target_exprs.append(binding.value)
+
+
+def byoc_partition(
+    target: str,
+    mod: tvm.IRModule,
+    params: Optional[Dict[str, tvm.nd.array]] = None,
+    trans_config: Optional[Dict[str, str]] = None,
+    build_config: Optional[Dict[str, str]] = None,
+    allow_incomplete: bool = True,
+) -> Tuple[tvm.IRModule, List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]]]:
+    """Partition module to target sub functions.
+
+    Parameters
+    ----------
+    target: str
+        The target for the BYOC.
+    mod: IRModule
+        The IRModule of relax.
+    params: dict of <string:tvm.nd.array>
+        The parameters of the IRModule.
+    trans_config: dict
+        The config for transform IRModule.
+    build_config: dict
+        The config for build MSCGraph.
+    allow_incomplete: bool
+        Whether to allow ops that are not supported by the target.
+
+    Returns
+    -------
+    mod: IRModule
+        The IRModule of partitioned relax.
+    graphs_info: list
+        The list of (name, graph, weights) tuples, one for each sub graph.
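+
+    Examples
+    --------
+    .. code-block:: python
+
+        # A minimal sketch (hypothetical usage): offload supported ops to
+        # a BYOC target and inspect the extracted sub graphs.
+        mod, graphs_info = byoc_partition("tensorrt", mod, params)
+        for name, graph, weights in graphs_info:
+            print(name, graph)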
+ """ + + ndim_conv = len(context.annotated_expr["conv"].struct_info.shape.values) + ndim_bias = len(context.annotated_expr["bias"].struct_info.shape.values) + ndim_out = len(context.annotated_expr["out"].struct_info.shape.values) + return ndim_conv == ndim_bias and ndim_bias == ndim_out + + +def make_opt_relax_linear_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: + """Create patterns for an linear, for mod after optimize. + + Returns + ------- + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a conv_bias operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + data = relax_pattern.wildcard() + weight = relax_pattern.is_const() + out = relax_pattern.is_op("relax.matmul")(data, weight) + annotations = {"weight": weight} + return out, annotations + + +def _check_opt_relax_linear(context: PatternCheckContext) -> bool: + """Check if linear fuse pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. + """ + + ndim_weight = len(context.annotated_expr["weight"].struct_info.shape.values) + return ndim_weight == 2 + + +def make_opt_relax_linear_bias_pattern() -> Tuple[ + relax_pattern.DFPattern, Mapping[str, relax_pattern.DFPattern] +]: + """Create patterns for an linear_bias, for mod after optimize. + + Returns + ------- + out: tvm.relax.dpl.pattern.DFPattern + The resulting pattern describing a conv_bias operation. + + annotations: Mapping[str, tvm.relax.dpl.pattern.DFPattern] + A mapping from name to sub pattern. It can be used to extract + important expressions from match result, to power the partition + check function and codegen. + """ + + data = relax_pattern.wildcard() + weight = relax_pattern.is_const() + linear = relax_pattern.is_op("relax.matmul")(data, weight) + bias = relax_pattern.is_const() + out = relax_pattern.is_op("relax.add")(linear, bias) + annotations = {"weight": weight, "bias": bias, "linear": linear, "out": out} + return out, annotations + + +def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: + """Check if linear fuse pattern is correct. + + Returns + ------- + pass: bool + Whether the pattern is correct. 
+ """ + + if not _check_opt_relax_linear(context): + return False + ndim_bias = len(context.annotated_expr["bias"].struct_info.shape.values) + ndim_out = len(context.annotated_expr["out"].struct_info.shape.values) + return ndim_bias == 1 or ndim_bias == ndim_out + + # TODO(tong.meng): support patterns after optimize register_patterns( [ + ( + "msc.conv1d_bias", + *make_opt_relax_conv_bias_pattern( + "relax.nn.conv1d", + ), + _check_opt_relax_conv_bias, + ), + ( + "msc.conv2d_bias", + *make_opt_relax_conv_bias_pattern( + "relax.nn.conv2d", + ), + _check_opt_relax_conv_bias, + ), + ( + "msc.linear", + *make_opt_relax_linear_pattern(), + _check_opt_relax_linear, + ), + ( + "msc.linear_bias", + *make_opt_relax_linear_bias_pattern(), + _check_opt_relax_linear_bias, + ), ( "msc.conv1d_bias", *make_relax_conv_bias_pattern( diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index 355922d6def2..a94a044633b2 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -22,7 +22,9 @@ from tvm.relay.transform import _ffi_api as relay_api -def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass: +def SetExprName( + as_relax: bool = True, entry_name: str = "main", target: str = "" +) -> tvm.ir.transform.Pass: """Set name for the call and constant in IRModule. Parameters @@ -31,7 +33,8 @@ def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass: Whether set names for relax, otherwise for relay. entry_name: str The entry name - + target: str + The target prefix for target functions Returns ------- @@ -39,11 +42,33 @@ def SetExprName(as_relax=True, entry_name="main") -> tvm.ir.transform.Pass: """ if as_relax: - return relax_api.SetRelaxExprName(entry_name) # type: ignore + return relax_api.SetRelaxExprName(entry_name, target) # type: ignore return relay_api.SetRelayExprName(entry_name) # type: ignore -def SetExprLayout(allow_missing=True, entry_name="main") -> tvm.ir.transform.Pass: +def BindExprName( + name_key: str = "", seperator: str = ",", entry_name: str = "main" +) -> tvm.ir.transform.Pass: + """Bind name for the call and constant in IRModule. + + Parameters + ---------- + name_key: str + The key to find name + seperator: str + The seperator + entry_name: str + The entry name + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + + return relay_api.BindRelayExprName(name_key, seperator, entry_name) # type: ignore + + +def SetExprLayout(allow_missing: bool = True, entry_name: str = "main") -> tvm.ir.transform.Pass: """Set layout for the var and constant in IRModule. Parameters diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py index ad459e78325d..8e4f8bd1b7eb 100644 --- a/python/tvm/contrib/msc/core/utils/expr.py +++ b/python/tvm/contrib/msc/core/utils/expr.py @@ -16,6 +16,8 @@ # under the License. 
"""tvm.contrib.msc.core.utils.expr""" +import copy + import tvm from tvm import relax from tvm.relax import PyExprVisitor @@ -41,6 +43,7 @@ class SpanVisitor(PyExprVisitor): def extract(self, expr: relax.Expr) -> dict: self._span_info = {} + self._local_funcs = {} if isinstance(expr, relax.Expr): self.visit_expr(expr) elif isinstance(expr, relax.BindingBlock): @@ -53,11 +56,21 @@ def _update_attrs(self, expr: relax.Expr, name: str = "") -> None: name = name or _ffi_api.SpanGetAttr(expr.span, "name") if not name: return - self._span_info[name] = _ffi_api.SpanGetAttrs(expr.span) + self._span_info[name] = dict(_ffi_api.SpanGetAttrs(expr.span)) def visit_var_binding_(self, binding: relax.VarBinding) -> None: - super().visit_var_binding_(binding) - self._update_attrs(binding.value, binding.var.name_hint) + if isinstance(binding.value, relax.expr.Function): + self._local_funcs[binding.var] = binding.value + elif ( + isinstance(binding.value, relax.expr.Call) and binding.value.op in self._local_funcs + ): + cache_info = copy.deepcopy(self._span_info) + func_info = self.extract(self._local_funcs[binding.value.op]) + self._span_info = cache_info + self._span_info[binding.value.op.name_hint] = func_info + else: + super().visit_var_binding_(binding) + self._update_attrs(binding.value, binding.var.name_hint) def visit_constant_(self, op: relax.Constant) -> None: super().visit_constant_(op) diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index c726240075be..70f0d5692c7f 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -59,7 +59,7 @@ def clean_up(self): if self._cleanup and os.path.isdir(self._path): shutil.rmtree(self._path) - def add_file(self, name: str, contains: str): + def add_file(self, name: str, contains: str) -> str: """Add a file under the folder Parameters @@ -72,11 +72,85 @@ def add_file(self, name: str, contains: str): Returns ------- path: str - The concatenated path. + The abs file path. """ - with open(self.relpath(name), "w") as f: + file_path = self.relpath(name) + with open(file_path, "w") as f: f.write(contains) + return file_path + + def move_file(self, src_file: str, dst_folder: object, dst_file: str = None): + """Move a file to another folder + + Parameters + ---------- + src_file: str + The name of the source file. + dst_folder: MSCDirectory + The target folder. + dst_file: str + The target file name. + + Returns + ------- + path: str + The abs file path. + """ + + src_path = os.path.join(self.relpath(src_file)) + assert os.path.isfile(src_path), "Source file {} not exist".format(src_path) + dst_path = dst_folder.relpath(dst_file or src_file) + os.rename(src_path, dst_path) + return dst_path + + def copy_file(self, src_file: str, dst_folder: object, dst_file: str = None): + """Copy a file to another folder + + Parameters + ---------- + src_file: str + The name of the source file. + dst_folder: MSCDirectory + The target folder. + dst_file: str + The target file name. + + Returns + ------- + path: str + The abs file path. + """ + + src_path = os.path.join(self.relpath(src_file)) + assert os.path.isfile(src_path), "Source file {} not exist".format(src_path) + dst_path = dst_folder.relpath(dst_file or src_file) + shutil.copy2(src_path, dst_path) + return dst_path + + def create_dir(self, name: str, keep_history: bool = True, cleanup: bool = False) -> object: + """Add a dir under the folder + + Parameters + ---------- + name: str + The name of the file. 
+        """
+
+        dir_path = self.relpath(name)
+        if os.path.isfile(dir_path):
+            os.remove(dir_path)
+        return self.__class__(dir_path, keep_history=keep_history, cleanup=cleanup)
 
     def relpath(self, name: str) -> str:
         """Relative path in dir
diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py
index 9e06a6d909f6..98967c361881 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -18,6 +18,11 @@
 
 import os
 import json
+from typing import List
+from distutils.version import LooseVersion
+
+import tvm
+from .namespace import MSCFramework
 
 
 def load_dict(str_dict: str, flavor: str = "json") -> dict:
@@ -96,3 +101,38 @@ def dict_equal(dict_a: dict, dict_b: dict) -> bool:
         if v != dict_b[k]:
             return False
     return True
+
+
+def get_version(framework: str) -> List[int]:
+    """Get the version list of framework.
+
+    Parameters
+    ----------
+    framework: string
+        Should be from MSCFramework.
+
+    Returns
+    -------
+    version: list
+        The version in <major, minor, patch> format.
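+
+    Examples
+    --------
+    .. code-block:: python
+
+        # A minimal sketch: version lists compare element-wise, so a
+        # minimum-version gate can be written as a plain comparison.
+        if get_version(MSCFramework.TORCH) >= [2, 0, 0]:
+            pass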
+    """
+
+    try:
+        if framework in (MSCFramework.MSC, MSCFramework.TVM):
+            raw_version = tvm.__version__
+        elif framework == MSCFramework.TORCH:
+            import torch  # pylint: disable=import-outside-toplevel
+
+            raw_version = torch.__version__
+        elif framework == MSCFramework.TENSORFLOW:
+            import tensorflow  # pylint: disable=import-outside-toplevel
+
+            raw_version = tensorflow.__version__
+        elif framework == MSCFramework.TENSORRT:
+            raw_version = ".".join(tvm.get_global_func("relax.get_tensorrt_version")())
+        else:
+            raw_version = "1.0.0"
+    except:  # pylint: disable=bare-except
+        raw_version = "1.0.0"
+
+    return LooseVersion(raw_version).version
diff --git a/python/tvm/contrib/msc/core/utils/namespace.py b/python/tvm/contrib/msc/core/utils/namespace.py
index 984af184172c..e9d72f1a708f 100644
--- a/python/tvm/contrib/msc/core/utils/namespace.py
+++ b/python/tvm/contrib/msc/core/utils/namespace.py
@@ -63,3 +63,5 @@ class MSCFramework:
     MSC = "msc"
     TVM = "tvm"
     TORCH = "torch"
+    TENSORFLOW = "tensorflow"
+    TENSORRT = "tensorrt"
diff --git a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
index 1918ad2f4d68..6bfe86056e27 100644
--- a/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/torch/codegen/codegen.py
@@ -69,4 +69,4 @@ def _bind_weights(model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> tor
         return model
 
     codegen = CodeGen(graph, _ffi_api.GetTorchSources, codegen_config, print_config, build_folder)
-    return codegen.load([], _bind_weights)
+    return codegen.load([], post_load=_bind_weights)
diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
index 3ee932578d25..0edce88365d6 100644
--- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
@@ -67,4 +67,4 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo
         return mod
 
     codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
-    return codegen.load(inputs, _bind_weights)
+    return codegen.load(inputs, post_load=_bind_weights)
diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h
index f40eda36dc7e..fb1f72b928c2 100644
--- a/src/contrib/msc/core/codegen/base_codegen.h
+++ b/src/contrib/msc/core/codegen/base_codegen.h
@@ -30,6 +30,7 @@
 #include <memory>
 #include <string>
 #include <utility>
+#include <vector>
 
 #include "../ir/graph.h"
 #include "code_stack.h"
@@ -61,26 +62,29 @@ class BaseOpCode {
     config_ = config;
   }
 
+  /*! \brief Get docs for the node*/
+  virtual const Array<Doc> GetDocs() = 0;
+
   /*! \brief Get return describe for default node*/
-  virtual const String IdxNode(bool as_raw = true) { return IdxNode(node_, as_raw); }
+  virtual const String IdxNode(bool as_raw = true) { return IdxNodeBase(node_, as_raw); }
 
   /*! \brief Get describe for default node input*/
-  virtual const String IdxInput(int idx = 0, bool as_raw = false) {
-    return IdxInput(node_, idx, as_raw);
+  const String IdxInput(int idx = 0, bool as_raw = false) {
+    return IdxInputBase(node_, idx, as_raw);
   }
 
   /*! \brief Get describe for default node output*/
-  virtual const String IdxOutput(int idx = 0, bool as_raw = false) {
-    return IdxOutput(node_, idx, as_raw);
+  const String IdxOutput(int idx = 0, bool as_raw = false) {
+    return IdxOutputBase(node_, idx, as_raw);
  }
 
   /*! \brief Get describe for default node weight*/
-  virtual const String IdxWeight(const String& wtype, bool as_raw = false) {
-    return IdxWeight(node_, wtype, as_raw);
+  const String IdxWeight(const String& wtype, bool as_raw = false) {
+    return IdxWeightBase(node_, wtype, as_raw);
   }
 
   /*! \brief Get comment for default node*/
-  virtual const String Comment() { return Comment(node_); }
+  const String Comment() { return Comment(node_); }
 
   /*! \brief Get func_name for the default node*/
   const String func_name() { return func_name_; }
@@ -88,6 +92,9 @@ class BaseOpCode {
   /*! \brief Get valid func name for the default node*/
   virtual const String callee_name() { return func_name(); }
 
+  /*! \brief Get valid return name for the default node*/
+  virtual const String ret_name() { return IdxNode(true); }
+
   /*! \brief Get the default node*/
   const MSCJoint node() { return node_; }
@@ -164,6 +171,26 @@ class BaseCodeGen {
     }
   }
 
+  /*!
+   * \brief Compare version with version in config
+   * 0 for same version, 1 for greater version, -1 for less version
+   */
+  int CompareVersion(size_t major, size_t minor, size_t patch) {
+    if (config_->version.size() == 0) {
+      return 0;
+    }
+    ICHECK_EQ(config_->version.size(), 3) << "Version should be in format major,minor,patch";
+    std::vector<size_t> given_version{major, minor, patch};
+    for (size_t i = 0; i < 3; i++) {
+      if (given_version[i] > config_->version[i]) {
+        return 1;
+      } else if (given_version[i] < config_->version[i]) {
+        return -1;
+      }
+    }
+    return 0;
+  }
+
   /*!
\brief Get the docs for the op*/ virtual const Array GetOpCodes(const MSCJoint& node) = 0; diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index 430dafc4c560..6d8ced83dee6 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -46,27 +46,16 @@ void BaseStack::AssignBase(const String& lhs, const ExprDoc& rhs, const String& } } -void BaseStack::AssignIndexBase(const String& lhs, const String& rhs, const Array& indices, - const String& annotation) { - Array doc_indices; - for (const auto& i : indices) { - doc_indices.push_back(i); - } - AssignBase(lhs, IndexDoc(IdDoc(rhs), doc_indices), annotation); +void BaseStack::Declare(const String& type, const String& variable, size_t len, + bool use_constructor) { + PushDoc(DocUtils::ToDeclareDoc(type, variable, len, use_constructor)); } -void BaseStack::AttrAccess(const String& attr) { - const auto& host = PopDoc(); - if (host.as()) { - const auto& assign = Downcast(host); - ICHECK(assign->rhs.defined()) << "AttrAccess with assign missing rhs"; - const auto& access = AttrAccessDoc(assign->rhs.value(), attr); - PushDoc(AssignDoc(assign->lhs, access, assign->annotation)); - } else if (host.as()) { - PushDoc(AttrAccessDoc(Downcast(host), attr)); - } else { - LOG(FATAL) << "Unexpected attr access host " << host->GetTypeKey(); - } +void BaseStack::DeclareArgBase(const ExprDoc& value) { + const auto& declare = PopCheckedDoc(); + Array init_args = declare->init_args; + init_args.push_back(value); + PushDoc(DeclareDoc(declare->type, declare->variable, init_args, declare->use_constructor)); } void BaseStack::FuncDef(const String& func_name, const String& ret_type) { @@ -143,87 +132,97 @@ void BaseStack::ClassEnd() { PushDoc(ClassDoc(class_doc->name, class_doc->decorators, body)); } -void BaseStack::CallStart(const String& callee) { - PushDoc(CallDoc(IdDoc(callee), Array(), Array(), Array())); +void BaseStack::FuncCall(const String& callee, Optional assign_to, + Optional caller) { + if (!caller.defined()) { + PushDoc(CallDoc(IdDoc(callee), Array(), Array(), Array())); + } else { + const auto& new_access = AttrAccessDoc(caller.value(), callee); + PushDoc(CallDoc(new_access, Array(), Array(), Array())); + } + if (assign_to.defined()) { + const auto& last_call = PopCheckedDoc(); + const auto& declare = Downcast(assign_to.value()); + PushDoc(AssignDoc(declare->variable, last_call, declare->type)); + } } -void BaseStack::CallEnd(const String& assign) { - if (assign.size() > 0) { - const auto& last_call = PopCheckedDoc(); - PushDoc(AssignDoc(IdDoc(assign), last_call, NullOpt)); +void BaseStack::FuncCall(const String& callee, const String& assign_to, const String& caller) { + Optional assign_doc; + if (assign_to.size() == 0) { + assign_doc = NullOpt; + } else { + assign_doc = DocUtils::ToDeclareDoc("", assign_to); } + Optional caller_doc; + if (caller.size() == 0) { + caller_doc = NullOpt; + } else { + caller_doc = IdDoc(caller); + } + FuncCall(callee, assign_doc, caller_doc); } -void BaseStack::InplaceStart(const String& callee) { +void BaseStack::MethodCall(const String& callee) { const auto& host = PopDoc(); - if (host.as()) { - const auto& call = AttrAccessDoc(Downcast(host), callee); - PushDoc(CallDoc(call, Array(), Array(), Array())); + if (host->IsInstance()) { + FuncCall(callee, NullOpt, Downcast(host)); } else if (const auto* a_node = host.as()) { ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; - const auto& assign = 
AssignDoc(a_node->lhs, IdDoc("msc::inplace"), a_node->annotation); - assign->comment = "msc::inplace"; - PushDoc(assign); - const auto& call = AttrAccessDoc(a_node->rhs.value(), callee); - PushDoc(CallDoc(call, Array(), Array(), Array())); + FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, Array(), true), + a_node->rhs); } else { LOG(FATAL) << "Unexpected host type for inplace " << host->GetTypeKey(); } } -void BaseStack::InplaceEnd() { - const auto& call = PopCheckedDoc(); - if (HasDoc() && TopDoc()->IsInstance() && - Downcast(TopDoc())->comment == "msc::inplace") { - const auto& assign = PopCheckedDoc(); - PushDoc(AssignDoc(assign->lhs, call, assign->annotation)); +void BaseStack::PopNest(const String& key) { + const auto& last = PopDoc(); + if (last->IsInstance()) { + CallArgBase(Downcast(last), key); } else { - PushDoc(call); + LOG(FATAL) << "Unexpected nest type " << last->GetTypeKey(); } } void BaseStack::CallArgBase(const ExprDoc& value, const String& key) { - const auto& call = PopCheckedDoc(); + const auto& last = PopDoc(); + Array args; + Array kwargs_keys; + Array kwargs_values; + // get args and kwargs + if (const auto* call = last.as()) { + args = call->args; + kwargs_keys = call->kwargs_keys; + kwargs_values = call->kwargs_values; + } else if (const auto* assign = last.as()) { + const auto& call = Downcast(assign->rhs); + args = call->args; + kwargs_keys = call->kwargs_keys; + kwargs_values = call->kwargs_values; + } else { + LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); + } + // push args or kwargs if (key.size() == 0) { - ICHECK(call->kwargs_keys.size() == 0) << "kwargs followed by args " << value; - Array args = call->args; + ICHECK(kwargs_keys.size() == 0) << "kwargs followed by args " << value; args.push_back(value); - PushDoc(CallDoc(call->callee, args, call->kwargs_keys, call->kwargs_values)); } else { - Array kwargs_keys = call->kwargs_keys; - Array kwargs_values = call->kwargs_values; kwargs_keys.push_back(key); kwargs_values.push_back(value); - PushDoc(CallDoc(call->callee, call->args, kwargs_keys, kwargs_values)); - } -} - -void BaseStack::CallStrArg(const String& value, const String& key) { - if (value.size() > 0) { - CallArgBase(DocUtils::ToStrDoc(value), key); } -} - -void BaseStack::CallListArgBase(const Array& values, const String& key, bool allow_empty, - bool as_list) { - if (values.size() > 0 || allow_empty) { - if (as_list) { - CallArgBase(ListDoc(values), key); - } else { - for (const auto& v : values) { - CallArgBase(v); - } - } + // push doc + if (const auto* call = last.as()) { + PushDoc(CallDoc(call->callee, args, kwargs_keys, kwargs_values)); + } else if (const auto* assign = last.as()) { + const auto& call = Downcast(assign->rhs); + const auto& new_call = CallDoc(call->callee, args, kwargs_keys, kwargs_values); + PushDoc(AssignDoc(assign->lhs, new_call, assign->annotation)); + } else { + LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); } } -void BaseStack::CallInplaceStart(const String& callee) { CallStart(callee); } - -void BaseStack::CallInplaceEnd(const String& key) { - const auto& inplace = PopCheckedDoc(); - CallArgBase(inplace, key); -} - void BaseStack::ConditionIf(const String& predicate) { Array else_branch{ExprStmtDoc(IdDoc("pass"))}; PushDoc(IfDoc(IdDoc(predicate), Array(), else_branch)); @@ -248,6 +247,36 @@ void BaseStack::ConditionEnd() { } } +void BaseStack::ForStart(const String& lhs, const String& rhs) { + PushDoc(ForDoc(IdDoc(lhs), IdDoc(rhs), Array())); + 
BlockStart(); +} + +void BaseStack::ForStart(const String& lhs, size_t start, size_t end) { + Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; + PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), Array())); + BlockStart(); +} + +void BaseStack::ForEnd() { + const auto& block = PopBlock(); + const auto& for_doc = PopCheckedDoc(); + const auto& body = DocUtils::ToStmts(block); + PushDoc(ForDoc(for_doc->lhs, for_doc->rhs, body)); +} + +void BaseStack::WhileStart(const String& predicate) { + PushDoc(WhileDoc(IdDoc(predicate), Array())); + BlockStart(); +} + +void BaseStack::WhileEnd() { + const auto& block = PopBlock(); + const auto& while_doc = PopCheckedDoc(); + const auto& body = DocUtils::ToStmts(block); + PushDoc(WhileDoc(while_doc->predicate, body)); +} + void BaseStack::BlockStart() { Array block; blocks_.push(block); diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h index acfd1f0aac87..bf659f927de7 100644 --- a/src/contrib/msc/core/codegen/code_stack.h +++ b/src/contrib/msc/core/codegen/code_stack.h @@ -30,6 +30,7 @@ #include #include +#include "../printer/msc_doc.h" #include "../printer/print_utils.h" #include "codegen_utils.h" @@ -68,50 +69,43 @@ class BaseStack { /*! \brief Push Comment Doc*/ void Comment(const String& comment = ""); - /*! \brief Push assign Doc*/ - void AssignBase(const String& lhs, const ExprDoc& rhs, const String& annotation = ""); - /*! \brief Push typed assign Doc*/ - template - void Assign(const String& lhs, const T& rhs, const String& annotation = "") { - AssignBase(lhs, DocUtils::ToDoc(rhs), annotation); - } + void AssignBase(const String& lhs, const ExprDoc& rhs, const String& annotation = ""); - /*! \brief Push assign for list Doc*/ template - inline void AssignList(const String& lhs, const std::vector& rhs, - const String& annotation = "") { - AssignBase(lhs, DocUtils::ToListDoc(rhs), annotation); + inline void Assign(const String& lhs, const T& rhs, const String& annotation = "") { + const auto& doc_rhs = DocUtils::ToDoc(rhs); + if (doc_rhs.defined()) { + AssignBase(lhs, doc_rhs, annotation); + } } - template - inline void AssignList(const String& lhs, const Array& rhs, const String& annotation = "") { - AssignBase(lhs, DocUtils::ToListDoc(rhs), annotation); - } + /*! \brief Push declare for variable Doc*/ + void Declare(const String& type, const String& variable, size_t len = 0, + bool use_constructor = true); - /*! \brief Push assign for index Doc*/ - void AssignIndexBase(const String& lhs, const String& rhs, const Array& indices, - const String& annotation = ""); + /*! \brief Cache declare typed argument*/ + void DeclareArgBase(const ExprDoc& value); template - inline void AssignIndex(const String& lhs, const String& rhs, const std::vector& indices, - const String& annotation = "") { - AssignIndexBase(lhs, rhs, DocUtils::ToDocList(indices), annotation); + inline void DeclareArg(const T& value) { + const auto& doc_value = DocUtils::ToDoc(value); + if (doc_value.defined()) { + DeclareArgBase(doc_value); + } } - template - inline void AssignIndex(const String& lhs, const String& rhs, const Array& indices, - const String& annotation = "") { - AssignIndexBase(lhs, rhs, DocUtils::ToDocList(indices), annotation); - } + /*! \brief Cache class Doc*/ + void ClassDef(const String& class_name); - inline void AssignIndex(const String& lhs, const String& rhs, const Array& indices, - const String& annotation = "") { - AssignIndexBase(lhs, rhs, indices, annotation); - } + /*! 
\brief Cache class decorator*/ + void ClassDecorator(const String& decorator); + + /*! \brief Start class body block*/ + void ClassStart(); - /*! \brief Push attr access Doc*/ - void AttrAccess(const String& attr); + /*! \brief End class body block*/ + void ClassEnd(); /*! \brief Cache function Doc*/ void FuncDef(const String& func_name, const String& ret_type = ""); @@ -128,67 +122,38 @@ class BaseStack { /*! \brief End function body block*/ void FuncEnd(const String& ret_val = ""); - /*! \brief Cache class Doc*/ - void ClassDef(const String& class_name); + /*! \brief Push call and maybe assign Doc*/ + void FuncCall(const String& callee, Optional assign_to, + Optional caller = NullOpt); - /*! \brief Cache class decorator*/ - void ClassDecorator(const String& decorator); + /*! \brief Push call and maybe assign Doc*/ + void FuncCall(const String& callee, const String& assign_to = "", const String& caller = ""); - /*! \brief Start class body block*/ - void ClassStart(); - - /*! \brief End class body block*/ - void ClassEnd(); + /*! \brief Push method call Doc*/ + void MethodCall(const String& callee); - /*! \brief Cache call Doc*/ - void CallStart(const String& callee); + /*! \brief Push nested expr to last Doc*/ + void PopNest(const String& key = ""); - /*! \brief Push call or/and assign Doc*/ - void CallEnd(const String& assign = ""); - - /*! \brief Cache inplace call Doc*/ - void InplaceStart(const String& callee); - - /*! \brief Push inplace call or/and assign Doc*/ - void InplaceEnd(); - - /*! \brief Cache call argument*/ + /*! \brief Cache call typed argument*/ void CallArgBase(const ExprDoc& value, const String& key = ""); + /*! \brief Cache call normal argument*/ template inline void CallArg(T value, const String& key = "") { - CallArgBase(DocUtils::ToDoc(value), key); - } - - void CallStrArg(const String& value, const String& key = ""); - - /*! \brief Cache call list argument*/ - void CallListArgBase(const Array& values, const String& key = "", - bool allow_empty = false, bool as_list = true); - - template - inline void CallListArg(const std::vector& values, const String& key = "", - bool allow_empty = false, bool as_list = true) { - return CallListArgBase(DocUtils::ToDocList(values), key, allow_empty, as_list); - } - - template - inline void CallListArg(const Array& values, const String& key = "", bool allow_empty = false, - bool as_list = true) { - return CallListArgBase(DocUtils::ToDocList(values), key, allow_empty, as_list); + const auto& doc_value = DocUtils::ToDoc(value); + if (doc_value.defined()) { + CallArgBase(doc_value, key); + } } - - inline void CallListArg(const Array& values, const String& key = "", - bool allow_empty = false, bool as_list = true) { - return CallListArgBase(values, key, allow_empty, as_list); + inline void CallArg(const Array& values) { + for (const auto& v : values) { + if (v.defined()) { + CallArgBase(v); + } + } } - /*! \brief Cache call inplace func argument*/ - void CallInplaceStart(const String& callee); - - /*! \brief Push call inplace func argument*/ - void CallInplaceEnd(const String& key = ""); - /*! \brief Push if to cache and start if block*/ void ConditionIf(const String& predicate); @@ -198,6 +163,21 @@ class BaseStack { /*! \brief Push else branch to cached*/ void ConditionEnd(); + /*! \brief Push for to cache and start for block*/ + void ForStart(const String& lhs, const String& rhs); + + /*! \brief Push for range to cache and start for block*/ + void ForStart(const String& lhs, size_t start, size_t end); + + /*! 
\brief End a for block*/ + void ForEnd(); + + /*! \brief Push while to cache and start while block*/ + void WhileStart(const String& predicate); + + /*! \brief End a while block*/ + void WhileEnd(); + /*! \brief Start a new block*/ void BlockStart(); @@ -240,164 +220,148 @@ class BaseStack { std::stack> blocks_; }; -#define COMMON_WRAPPERS(Stack) \ - Stack& line(const Doc& doc) { \ - Line(doc); \ - return *this; \ - } \ - Stack& line(const String& line = "") { \ - Line(line); \ - return *this; \ - } \ - Stack& comment(const String& comment) { \ - Comment(comment); \ - return *this; \ - } \ - template \ - Stack& assign(const String& lhs, const T& rhs, const String& annotation = "") { \ - Assign(lhs, rhs, annotation); \ - return *this; \ - } \ - template \ - Stack& assign_list(const String& lhs, const std::vector& rhs, \ - const String& annotation = "") { \ - AssignList(lhs, rhs, annotation); \ - return *this; \ - } \ - template \ - Stack& assign_list(const String& lhs, const Array& rhs, const String& annotation = "") { \ - AssignList(lhs, rhs, annotation); \ - return *this; \ - } \ - template \ - Stack& assign_index(const String& lhs, const String& rhs, const std::vector& indices, \ - const String& annotation = "") { \ - AssignIndex(lhs, rhs, indices, annotation); \ - return *this; \ - } \ - template \ - Stack& assign_index(const String& lhs, const String& rhs, const Array& indices, \ - const String& annotation = "") { \ - AssignIndex(lhs, rhs, indices, annotation); \ - return *this; \ - } \ - Stack& attr_access(const String& attr) { \ - AttrAccess(attr); \ - return *this; \ - } \ - Stack& block_start() { \ - BlockStart(); \ - return *this; \ - } \ - Stack& block_end(bool block_docs = true) { \ - BlockEnd(block_docs); \ - return *this; \ - } \ - Stack& scope_start(const String& scope_def, const String& scope_ref = "") { \ - ScopeStart(scope_def, scope_ref); \ - return *this; \ - } \ - Stack& scope_end() { \ - ScopeEnd(); \ - return *this; \ - } \ - Stack& func_def(const String& func_name, const String& ret_type = "") { \ - FuncDef(func_name, ret_type); \ - return *this; \ - } \ - Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - FuncArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& func_decorator(const String& decorator) { \ - FuncDecorator(decorator); \ - return *this; \ - } \ - Stack& func_start() { \ - FuncStart(); \ - return *this; \ - } \ - Stack& func_end(const String& ret_val = "") { \ - FuncEnd(ret_val); \ - return *this; \ - } \ - Stack& class_def(const String& class_name) { \ - ClassDef(class_name); \ - return *this; \ - } \ - Stack& class_decorator(const String& decorator) { \ - ClassDecorator(decorator); \ - return *this; \ - } \ - Stack& class_start() { \ - ClassStart(); \ - return *this; \ - } \ - Stack& class_end() { \ - ClassEnd(); \ - return *this; \ - } \ - Stack& call_start(const String& callee) { \ - CallStart(callee); \ - return *this; \ - } \ - Stack& call_end(const String& assign = "") { \ - CallEnd(assign); \ - return *this; \ - } \ - Stack& inplace_start(const String& callee) { \ - InplaceStart(callee); \ - return *this; \ - } \ - Stack& inplace_end() { \ - InplaceEnd(); \ - return *this; \ - } \ - template \ - Stack& call_arg(T value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_str_arg(const String& value, const String& key = "") { \ - CallStrArg(value, key); \ - return *this; \ - } \ - Stack& call_list_arg(const Array& values, const String& key = "", \ - 
bool allow_empty = false, bool as_list = true) { \ - CallListArg(values, key, allow_empty, as_list); \ - return *this; \ - } \ - template \ - Stack& call_list_arg(const std::vector& values, const String& key = "", \ - bool allow_empty = false, bool as_list = true) { \ - CallListArg(values, key, allow_empty, as_list); \ - return *this; \ - } \ - template \ - Stack& call_list_arg(const Array& values, const String& key = "", bool allow_empty = false, \ - bool as_list = true) { \ - CallListArg(values, key, allow_empty, as_list); \ - return *this; \ - } \ - Stack& call_inplace_start(const String& callee) { \ - CallInplaceStart(callee); \ - return *this; \ - } \ - Stack& call_inplace_end(const String& key = "") { \ - CallInplaceEnd(key); \ - return *this; \ - } \ - Stack& cond_if(const String& predicate) { \ - ConditionIf(predicate); \ - return *this; \ - } \ - Stack& cond_else() { \ - ConditionElse(); \ - return *this; \ - } \ - Stack& cond_end() { \ - ConditionEnd(); \ - return *this; \ +#define COMMON_WRAPPERS(Stack) \ + Stack& line(const Doc& doc) { \ + Line(doc); \ + return *this; \ + } \ + Stack& line(const String& line = "") { \ + Line(line); \ + return *this; \ + } \ + Stack& comment(const String& comment) { \ + Comment(comment); \ + return *this; \ + } \ + template \ + Stack& assign(const String& lhs, const T& rhs, const String& annotation = "") { \ + Assign(lhs, rhs, annotation); \ + return *this; \ + } \ + Stack& declare(const String& type, const String& variable, size_t len = 0, \ + bool use_constructor = true) { \ + Declare(type, variable, len, use_constructor); \ + return *this; \ + } \ + template \ + Stack& declare_arg(const T& value) { \ + DeclareArg(value); \ + return *this; \ + } \ + Stack& class_def(const String& class_name) { \ + ClassDef(class_name); \ + return *this; \ + } \ + Stack& class_decorator(const String& decorator) { \ + ClassDecorator(decorator); \ + return *this; \ + } \ + Stack& class_start() { \ + ClassStart(); \ + return *this; \ + } \ + Stack& class_end() { \ + ClassEnd(); \ + return *this; \ + } \ + Stack& func_def(const String& func_name, const String& ret_type = "") { \ + FuncDef(func_name, ret_type); \ + return *this; \ + } \ + Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ + FuncArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& func_decorator(const String& decorator) { \ + FuncDecorator(decorator); \ + return *this; \ + } \ + Stack& func_start() { \ + FuncStart(); \ + return *this; \ + } \ + Stack& func_end(const String& ret_val = "") { \ + FuncEnd(ret_val); \ + return *this; \ + } \ + Stack& func_call(const String& callee, Optional assign_to, \ + Optional caller = NullOpt) { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& func_call(const String& callee, const String& assign_to = "", \ + const String& caller = "") { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& method_call(const String& callee) { \ + MethodCall(callee); \ + return *this; \ + } \ + Stack& pop_nest(const String& key = "") { \ + PopNest(key); \ + return *this; \ + } \ + template \ + Stack& call_arg(T value, const String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ExprDoc& value, const String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const Array& values) { \ + CallArg(values); \ + return *this; \ + } \ + Stack& cond_if(const String& predicate) { \ + ConditionIf(predicate); \ + return *this; 
\ + } \ + Stack& cond_else() { \ + ConditionElse(); \ + return *this; \ + } \ + Stack& cond_end() { \ + ConditionEnd(); \ + return *this; \ + } \ + Stack& for_start(const String& lhs, const String& rhs) { \ + ForStart(lhs, rhs); \ + return *this; \ + } \ + Stack& for_start(const String& lhs, size_t start, size_t end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_end() { \ + ForEnd(); \ + return *this; \ + } \ + Stack& while_start(const String& predicate) { \ + WhileStart(predicate); \ + return *this; \ + } \ + Stack& while_end() { \ + WhileEnd(); \ + return *this; \ + } \ + Stack& block_start() { \ + BlockStart(); \ + return *this; \ + } \ + Stack& block_end(bool block_docs = true) { \ + BlockEnd(block_docs); \ + return *this; \ + } \ + Stack& scope_start(const String& scope_def, const String& scope_ref = "") { \ + ScopeStart(scope_def, scope_ref); \ + return *this; \ + } \ + Stack& scope_end() { \ + ScopeEnd(); \ + return *this; \ } /*! @@ -434,16 +398,12 @@ class OpCodeStack : public BaseStack { COMMON_WRAPPERS(OpCodeStack) - /*! \brief Cache op_call Doc*/ - OpCodeStack& op_start(const String& callee = "msc::auto") { - const String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; - return call_start(v_callee); - } - /*! \brief Push op_call Doc*/ - OpCodeStack& op_end(const String& assign_str = "msc::auto") { - const String& v_assign = assign_str == "msc::auto" ? codegen_->IdxNode(true) : assign_str; - return call_end(v_assign); + OpCodeStack& op_call(const String& callee = "msc::auto", + const String& assign_to = "msc::auto") { + const String& v_callee = callee == "msc::auto" ? codegen_->callee_name() : callee; + const String& v_assign = assign_to == "msc::auto" ? codegen_->ret_name() : assign_to; + return func_call(v_callee, v_assign); } /*! \brief Push op comment Doc*/ @@ -452,7 +412,7 @@ class OpCodeStack : public BaseStack { return comment(v_comment); } - /*! \brief Cache attribute as argument*/ + /*! \brief Cache typed attribute as argument*/ template OpCodeStack& op_arg(const String& attr_key, const String& key = "msc::auto") { T attr_val; @@ -463,11 +423,12 @@ class OpCodeStack : public BaseStack { return *this; } + /*! \brief Cache str attribute as argument*/ OpCodeStack& op_str_arg(const String& attr_key, const String& key = "msc::auto") { std::string attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { const String& valid_key = key == "msc::auto" ? attr_key : key; - return call_str_arg(attr_val, valid_key); + return call_arg(DocUtils::ToStrDoc(attr_val), valid_key); } return *this; } @@ -475,11 +436,11 @@ class OpCodeStack : public BaseStack { /*! \brief Cache list attribute as argument*/ template OpCodeStack& op_list_arg(const String& attr_key, const String& key = "msc::auto", - bool allow_empty = false, bool as_list = true) { + bool allow_empty = false) { std::vector attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { const String& valid_key = key == "msc::auto" ? attr_key : key; - return call_list_arg(attr_val, valid_key, allow_empty, as_list); + return call_arg(DocUtils::ToListDoc(attr_val, allow_empty), valid_key); } return *this; } @@ -495,7 +456,11 @@ class OpCodeStack : public BaseStack { for (size_t i = 0; i < codegen_->node()->inputs.size(); i++) { inputs.push_back(codegen_->IdxInput(i, false)); } - return call_list_arg(inputs, key, false, as_list); + if (as_list) { + return call_arg(DocUtils::ToListDoc(inputs), key); + } else { + return call_arg(DocUtils::ToDocList(inputs)); + } } /*! 
\brief Cache output as argument*/ @@ -511,7 +476,15 @@ OpCodeStack : public BaseStack { return *this; } - OpCodeStack& call_dtype_arg(const DataType& dtype, const String& key = "") { + /*! \brief Cache name as argument*/ + OpCodeStack& op_name_arg(const String& key = "msc::auto", + const String& name = "msc::auto") { + const String& valid_key = key == "msc::auto" ? "name" : key; + const String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; + return call_arg(DocUtils::ToStrDoc(valid_name), valid_key); + } + + OpCodeStack& op_dtype_arg(const DataType& dtype, const String& key = "") { return call_arg(codegen_->DType(dtype), key); } diff --git a/src/contrib/msc/core/codegen/codegen_utils.h index 3cdd8dc78aa0..ce32ae5f8893 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -28,6 +28,7 @@ #include #include +#include #include "../ir/graph.h" #include "../utils.h" @@ -49,7 +50,7 @@ using namespace tvm::script::printer; std::string test_device{"cpu"}; \ std::string prefix{"res_"}; \ std::string baseline_folder{"baseline"}; \ - std::string version; + std::vector version{0, 0, 0}; #define CODEGEN_CONFIG_PARSE \ if (key == "is_train") { \ @@ -80,34 +81,34 @@ using namespace tvm::script::printer; LOG(FATAL) << "Do not support key " << key; \ } -#define CODEGEN_MEMBERS \ - public: \ - virtual const Array GetDocs() = 0; \ - \ - protected: \ - const std::shared_ptr config() { return config_; } \ - const String GetSuffix(bool as_raw = false) { \ - const String& suffix = as_raw && config()->need_process ? "_raw" : ""; \ - return suffix; \ - } \ - virtual const String IdxNode(const MSCJoint& node, bool as_raw = true) { \ - return CodeGenUtils::IdxNode(node, config()->prefix, GetSuffix(as_raw)); \ - } \ - virtual const String IdxInput(const MSCJoint& node, int idx = 0, bool as_raw = false) { \ - return CodeGenUtils::IdxInput(node, config()->prefix, idx, GetSuffix(as_raw)); \ - } \ - virtual const String IdxOutput(const MSCJoint& node, int idx = 0, bool as_raw = false) { \ - return CodeGenUtils::IdxOutput(node, config()->prefix, idx, GetSuffix(as_raw)); \ - } \ - virtual const String IdxWeight(const MSCJoint& node, const String& wtype, bool as_raw = false) { \ - return CodeGenUtils::IdxWeight(node, wtype, GetSuffix(as_raw)); \ - } \ - virtual const String DType(const DataType& dtype) { return runtime::DLDataType2String(dtype); } \ - virtual const String Comment(const MSCJoint& node) { \ - return CodeGenUtils::CommentNode(node, config()->prefix); \ - } \ - \ - private: \ +#define CODEGEN_MEMBERS \ + public: \ + virtual const String DType(const DataType& dtype) { return runtime::DLDataType2String(dtype); } \ + \ + protected: \ + const std::shared_ptr config() { return config_; } \ + const String GetSuffix(bool as_raw = false) { \ + const String& suffix = as_raw && config()->need_process ?
"_raw" : ""; \ + return suffix; \ + } \ + virtual const String IdxNodeBase(const MSCJoint& node, bool as_raw = true) { \ + return CodeGenUtils::IdxNode(node, config()->prefix, GetSuffix(as_raw)); \ + } \ + virtual const String IdxInputBase(const MSCJoint& node, int idx = 0, bool as_raw = false) { \ + return CodeGenUtils::IdxInput(node, config()->prefix, idx, GetSuffix(as_raw)); \ + } \ + virtual const String IdxOutputBase(const MSCJoint& node, int idx = 0, bool as_raw = false) { \ + return CodeGenUtils::IdxOutput(node, config()->prefix, idx, GetSuffix(as_raw)); \ + } \ + virtual const String IdxWeightBase(const MSCJoint& node, const String& wtype, \ + bool as_raw = false) { \ + return CodeGenUtils::IdxWeight(node, wtype, GetSuffix(as_raw)); \ + } \ + virtual const String Comment(const MSCJoint& node) { \ + return CodeGenUtils::CommentNode(node, config()->prefix); \ + } \ + \ + private: \ std::shared_ptr config_; /*! diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index ec2c7f9c4175..21812fdba20b 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -52,7 +52,7 @@ class PyCodeGen : public BaseCodeGen { : BaseCodeGen(graph, config) {} /*! \brief Stack the docs for the script*/ - virtual const Array GetDocs() { + virtual void CodeGenScript() { CodeGenHeader(); this->stack_.line().comment("Define the helpers"); CodeGenHelper(); @@ -62,14 +62,14 @@ class PyCodeGen : public BaseCodeGen { this->stack_.line().comment("Define the test"); CodeGenTest(); } - return this->stack_.GetDocs(); } /*! \brief Get sources*/ virtual const Map GetSources(const std::string& print_options = "") { Map sources; PythonPrinter printer(print_options); - for (const auto& d : this->GetDocs()) { + CodeGenScript(); + for (const auto& d : this->stack_.GetDocs()) { printer.Append(d); } sources.Set(this->graph()->name + ".py", printer.GetString()); @@ -100,25 +100,20 @@ class PyCodeGen : public BaseCodeGen { .func_arg("shape", "List[int]") .func_arg("dtype", "str") .func_start() - .call_start("os.path.join") - .call_str_arg(this->config()->baseline_folder) + .func_call("os.path.join", "path") + .call_arg(DocUtils::ToStrDoc(this->config()->baseline_folder)) .call_arg("name + \".bin\"") - .call_end("path") .cond_if("os.path.isfile(path)") - .call_start("np.fromfile") + .func_call("np.fromfile", "data") .call_arg("path") .call_arg("dtype", "dtype") - .call_end("data") - .inplace_start("reshape") + .method_call("reshape") .call_arg("shape") - .inplace_end() .cond_else() - .call_start("np.ones") + .func_call("np.ones", "data") .call_arg("(shape)") - .call_end("data") - .inplace_start("astype") + .method_call("astype") .call_arg("dtype") - .inplace_end() .cond_end() .func_end("data"); } @@ -132,19 +127,17 @@ class PyCodeGen : public BaseCodeGen { .assign("golden", "{}"); for (const auto& i : this->graph()->input_names) { const auto& input = this->graph()->FindTensor(i); - this->stack_.call_start("load_data") - .call_str_arg(input->alias) - .call_list_arg(input->shape, "", true) - .call_str_arg(runtime::DLDataType2String(input->dtype)) - .call_end("inputs[\"" + input->alias + "\"]"); + this->stack_.func_call("load_data", "inputs[\"" + input->alias + "\"]") + .call_arg(DocUtils::ToStrDoc(input->alias)) + .call_arg(DocUtils::ToListDoc(input->shape, true)) + .call_arg(DocUtils::ToStrDoc(runtime::DLDataType2String(input->dtype))); } for (const auto& o : this->graph()->output_names) { const auto& output = this->graph()->FindTensor(o); 
- this->stack_.call_start("load_data") - .call_str_arg(output->alias) - .call_list_arg(output->shape, "", true) - .call_str_arg(runtime::DLDataType2String(output->dtype)) - .call_end("golden[\"" + output->alias + "\"]"); + this->stack_.func_call("load_data", "golden[\"" + output->alias + "\"]") + .call_arg(DocUtils::ToStrDoc(output->alias)) + .call_arg(DocUtils::ToListDoc(output->shape, true)) + .call_arg(DocUtils::ToStrDoc(runtime::DLDataType2String(output->dtype))); } this->stack_.comment("Build and inference the graph"); CodeGenInference(); @@ -157,18 +150,16 @@ class PyCodeGen : public BaseCodeGen { if (this->config()->need_process) { for (size_t i = 0; i < node->inputs.size(); i++) { const auto& input = node->InputAt(i); - this->stack_.call_start("process_tensor") - .call_arg(this->IdxInput(node, i, true)) - .call_str_arg(input->name) - .call_str_arg(node->name) - .call_end(this->IdxInput(node, i, false)); + this->stack_.func_call("process_tensor", this->IdxInputBase(node, i, false)) + .call_arg(this->IdxInputBase(node, i, true)) + .call_arg(DocUtils::ToStrDoc(input->name)) + .call_arg(DocUtils::ToStrDoc(node->name)); } for (const auto& pair : node->weights) { - this->stack_.call_start("process_tensor") - .call_arg(this->IdxWeight(node, pair.first, true)) - .call_str_arg(pair.second->name) - .call_str_arg(node->name) - .call_end(this->IdxWeight(node, pair.first, false)); + this->stack_.func_call("process_tensor", this->IdxWeightBase(node, pair.first, false)) + .call_arg(this->IdxWeightBase(node, pair.first, true)) + .call_arg(DocUtils::ToStrDoc(pair.second->name)) + .call_arg(DocUtils::ToStrDoc(node->name)); } } for (const auto& d : this->GetOpCodes(node)) { diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 56081caa7eaa..f300eb016df3 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -638,6 +638,16 @@ const Array MSCGraphNode::GetExits() const { return exits; } +const bool MSCGraphNode::HasTensor(const String& name) const { + if (weight_holders.count(name)) { + return true; + } + const String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; + String host, index; + std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); + return nodes.count(host) > 0 ? 
true : false; +} + const MSCTensor MSCGraphNode::FindTensor(const String& name) const { if (weight_holders.count(name)) { const auto& node = FindNode(weight_holders[name][0]); @@ -991,11 +1001,21 @@ TVM_REGISTER_GLOBAL("msc.core.WeightGraph") }); // Graph APIS +TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") + .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { + return Bool(graph->nodes.count(name)); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindNode(name); }); +TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") + .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { + return Bool(graph->HasTensor(name)); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCTensor { return graph->FindTensor(name); @@ -1054,6 +1074,11 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") return node->OutputAt(index); }); +TVM_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") + .set_body_typed([](const MSCJoint& node, const String& wtype) -> MSCTensor { + return node->WeightAt(wtype); + }); + TVM_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") .set_body_typed([](const MSCJoint& node) -> Array { return node->GetInputs(); }); diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index 8179471d5a0f..7165d82ba764 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -628,6 +628,8 @@ class MSCGraphNode : public BaseGraphNode { const Array GetEntries() const; /*! \brief Get exits from the graph. */ const Array GetExits() const; + /*! \brief Check if tensor in the graph. */ + const bool HasTensor(const String& name) const; /*! \brief Find tensor from the graph. */ const MSCTensor FindTensor(const String& name) const; /*! \brief Find producer of tensor from the graph. 
*/ diff --git a/src/contrib/msc/core/ir/graph_builder.cc index cadebcfcb0ad..e123c4c2924c 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -47,6 +47,32 @@ void RelaxFuncAttrGetter::VisitExpr_(const relax::CallNode* op) { } } +void RelaxFuncParamsFinder::VisitBinding_(const relax::VarBindingNode* binding, + const relax::FunctionNode* val) { + local_funcs_.Set(binding->var, GetRef(val)); +} + +void RelaxFuncParamsFinder::VisitExpr_(const relax::CallNode* call_node) { + RelaxExprVisitor::VisitExpr_(call_node); + relax::Function func; + if (const auto* v_node = call_node->op.as()) { + func = Downcast(ref_module_->Lookup(v_node->name_hint)); + } else if (call_node->op.as()) { + ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; + func = local_funcs_[call_node->op]; + } + if (func.defined()) { + for (size_t i = 0; i < call_node->args.size(); i++) { + const auto& arg = call_node->args[i]; + if (arg->IsInstance() && params_.count(Downcast(arg))) { + params_.Set(func->params[i], params_[Downcast(arg)]); + } else { + params_.Set(func->params[i], arg); + } + } + } +} + const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { // Add input nodes and record inputs; Array input_names, output_names; @@ -64,21 +90,32 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { } // remove const nodes as weights Array valid_nodes; + std::set ignore_inputs; for (const auto& n : nodes_) { - if (!weights_.count(n->name)) { + if (!weights_.count(n->name) && !ignore_nodes_.count(n->name)) { n->index = valid_nodes.size(); valid_nodes.push_back(n); + } else if (n->optype == "input") { + ignore_inputs.insert(n->OutputAt(0)->name); } } - const auto& graph = MSCGraph(name_, valid_nodes, input_names, output_names); + // remove useless inputs + Array valid_inputs; + for (const auto& i : input_names) { + if (!ignore_inputs.count(i)) { + valid_inputs.push_back(i); + } + } + // build graph + const auto& graph = MSCGraph(name_, valid_nodes, valid_inputs, output_names); // set inputs and outputs alias - if (config_.input_aliases.size() == input_names.size()) { - for (size_t i = 0; i < input_names.size(); i++) { - graph->FindTensor(input_names[i])->alias = config_.input_aliases[i]; + if (config_.input_aliases.size() == valid_inputs.size()) { + for (size_t i = 0; i < valid_inputs.size(); i++) { + graph->FindTensor(valid_inputs[i])->alias = config_.input_aliases[i]; } } else { - for (size_t i = 0; i < input_names.size(); i++) { - graph->FindTensor(input_names[i])->alias = graph->FindProducer(input_names[i])->name; + for (size_t i = 0; i < valid_inputs.size(); i++) { + graph->FindTensor(valid_inputs[i])->alias = graph->FindProducer(valid_inputs[i])->name; } } if (config_.output_aliases.size() == output_names.size()) { @@ -123,6 +160,11 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional const auto& name_opt = func->GetAttr(relax::attr::kComposite); ICHECK(name_opt.defined()) << "Unexpected global func without composite"; optype = name_opt.value(); + } else if (call_node->op.as()) { + ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + const auto& func = target_funcs_[call_node->op]; + const auto& name_opt = func->GetAttr(relax::attr::kComposite); + optype = StringUtils::Replace(name_opt.value(), config_.target + ".", ""); } else if (const auto* f_node = call_node->op.as()) { const auto& name_opt =
f_node->GetAttr(relax::attr::kComposite); ICHECK(name_opt.defined()) << "Unexpected func without composite"; @@ -139,6 +181,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); attrs = RelaxFuncAttrGetter().GetAttrs(func); + } else if (call_node->op->IsInstance()) { + ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + const auto& func = target_funcs_[call_node->op]; + attrs = RelaxFuncAttrGetter().GetAttrs(func); } else if (call_node->op->IsInstance()) { attrs = RelaxFuncAttrGetter().GetAttrs(call_node->op); } else if (call_node->attrs.defined()) { @@ -173,29 +219,67 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); continue; } + if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { + const auto* s_node = func_params_[arg].as(); + attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); + ignore_nodes_.insert(Downcast(arg)->name_hint()); + continue; + } if (const auto* s_node = arg.as()) { ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype << " should has special type, get " << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); continue; } - ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg; + Array arg_names; + if (expr_tensor_map_.count(arg)) { + arg_names = expr_tensor_map_[arg]; + } else if (const auto* tuple_node = arg.as()) { + for (const auto& f : tuple_node->fields) { + ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; + for (const auto& in_name : expr_tensor_map_[f]) { + arg_names.push_back(in_name); + } + } + } + String weight_name; if (input_types[i] != "input" && arg->IsInstance()) { - const auto& t_name = expr_tensor_map_[arg][0]; - const auto& w_name = SpanUtils::GetAttr(arg->span, "name"); + weight_name = SpanUtils::GetAttr(arg->span, "name"); + } else if (input_types[i] != "input" && func_params_.count(arg) && + func_params_[arg]->IsInstance()) { + weight_name = SpanUtils::GetAttr(func_params_[arg]->span, "name"); + ignore_nodes_.insert(Downcast(arg)->name_hint()); + } + // set weights or inputs + if (weight_name.size() > 0) { + const auto& t_name = arg_names[0]; const auto& pair = tensor_input_map_[t_name]; const auto& producer = Downcast(pair.first); - if (!weights_.count(w_name)) { + if (!weights_.count(weight_name)) { const auto& ref = producer->OutputAt(pair.second); - const auto& weight = MSCTensor(w_name, ref->dtype, ref->layout.name(), ref->shape); - weights_.Set(w_name, weight); + MSCTensor weight; + if (input_types[i] == "bias") { + weight = MSCTensor(weight_name, ref->dtype, "O", Array{ref->GetSize()}); + } else if (input_types[i] == "weight" && + (optype == "msc.linear" || optype == "msc.linear_bias")) { + if (ref->layout.name() == "IO") { + String valid_layout = ref->layout[1].name() + ref->layout[0].name(); + const auto& valid_shape = Array({ref->shape[1], ref->shape[0]}); + weight = MSCTensor(weight_name, ref->dtype, valid_layout, valid_shape); + } else { + weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); + } + } else { + weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); + } + weights_.Set(weight_name, weight); } if (producer->HasAttr("scalar")) { attrs.Set(input_types[i], producer->GetTypeAttr("scalar")); } - 
node_weights.Set(input_types[i], weights_[w_name]); + node_weights.Set(input_types[i], weights_[weight_name]); } else { - for (const auto& in_name : expr_tensor_map_[arg]) { + for (const auto& in_name : arg_names) { input_names.push_back(in_name); } } @@ -280,19 +364,22 @@ void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) { void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, const relax::ConstantNode* val) { - AddNode(GetRef(val), binding->var); + const String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(GetRef(val), binding->var, name); } void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, const relax::ShapeExprNode* val) { - AddNode(GetRef(val), binding->var); + const String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(GetRef(val), binding->var, name); } void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, const relax::CallNode* call_node) { RelaxExprVisitor::VisitBinding_(binding, call_node); + const String& name = config_.use_var_name ? binding->var->name_hint() : ""; try { - AddNode(GetRef(call_node), binding->var); + AddNode(GetRef(call_node), binding->var, name); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to add node from " << binding->var << " : " << binding->value << ", reason: " << err.message(); @@ -303,13 +390,15 @@ void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, const relax::TupleNode* val) { RelaxExprVisitor::VisitBinding_(binding, val); - AddNode(GetRef(val), binding->var); + const String& name = config_.use_var_name ? binding->var->name_hint() : ""; + AddNode(GetRef(val), binding->var, name); } void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, const relax::TupleGetItemNode* val) { RelaxExprVisitor::VisitBinding_(binding, val); - AddNode(GetRef(val), binding->var); + const String& name = config_.use_var_name ? 
binding->var->name_hint() : ""; + AddNode(GetRef(val), binding->var, name); } void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, @@ -328,6 +417,15 @@ void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } +void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, + const relax::FunctionNode* val) { + const auto& name_opt = val->GetAttr(relay::attr::kComposite); + ICHECK(name_opt.defined()) << "Unexpected target func without composite"; + ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) + << "Target should be given for target function"; + target_funcs_.Set(binding->var, GetRef(val)); +} + Map RelaxWeightsExtractor::GetWeights(const relax::Function& func) { VisitExpr(func); return weights_; @@ -476,18 +574,23 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg; if (input_types[i] != "input" && arg->IsInstance()) { const auto& t_name = expr_tensor_map_[arg][0]; - const auto& w_name = SpanUtils::GetAttr(arg->span, "name"); + const auto& weight_name = SpanUtils::GetAttr(arg->span, "name"); const auto& pair = tensor_input_map_[t_name]; const auto& producer = Downcast(pair.first); - if (!weights_.count(w_name)) { + if (!weights_.count(weight_name)) { const auto& ref = producer->OutputAt(pair.second); - const auto& weight = MSCTensor(w_name, ref->dtype, ref->layout.name(), ref->shape); - weights_.Set(w_name, weight); + MSCTensor weight; + if (input_types[i] == "bias") { + weight = MSCTensor(weight_name, ref->dtype, "O", Array{ref->GetSize()}); + } else { + weight = MSCTensor(weight_name, ref->dtype, ref->layout.name(), ref->shape); + } + weights_.Set(weight_name, weight); } if (producer->HasAttr("scalar")) { attrs.Set(input_types[i], producer->GetTypeAttr("scalar")); } - node_weights.Set(input_types[i], weights_[w_name]); + node_weights.Set(input_types[i], weights_[weight_name]); } else { for (const auto& in_name : expr_tensor_map_[arg]) { input_names.push_back(in_name); @@ -606,6 +709,9 @@ void RelayGraphBuilder::VisitExpr_(const relay::CallNode* op) { const auto& name_opt = f_node->GetAttr(relay::attr::kComposite); if (name_opt.defined()) { for (size_t i = 0; i < op->args.size(); i++) { + if (!expr_tensor_map_.count(op->args[i])) { + RelayExprVisitor::VisitExpr(op->args[i]); + } ICHECK(expr_tensor_map_.count(op->args[i])) << "Can not find argument " << relay::PrettyPrint(op->args[i]); expr_tensor_map_.Set(f_node->params[i], expr_tensor_map_[op->args[i]]); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 6a2a64b262c3..243304f238ce 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -56,8 +57,11 @@ using namespace tvm::runtime; */ struct MSCRBuildConfig { bool prune_graph{false}; + bool use_var_name{false}; int float_precision = 6; std::string sort_by; + std::string target = ""; + std::string graph_name = ""; std::vector input_aliases; std::vector output_aliases; std::unordered_map> input_types; @@ -78,10 +82,16 @@ struct MSCRBuildConfig { while (reader->NextObjectItem(&key)) { if (key == "prune_graph") { reader->Read(&prune_graph); + } else if (key == "use_var_name") { + reader->Read(&use_var_name); } else if (key == "float_precision") { 
reader->Read(&float_precision); } else if (key == "sort_by") { reader->Read(&sort_by); + } else if (key == "target") { + reader->Read(&target); + } else if (key == "graph_name") { + reader->Read(&graph_name); } else if (key == "input_aliases") { reader->Read(&input_aliases); } else if (key == "output_aliases") { @@ -147,6 +157,32 @@ class RelaxFuncAttrGetter : public RelaxExprVisitor { Map attrs_; }; +class RelaxFuncParamsFinder : public RelaxExprVisitor { + public: + /*! + * \brief The constructor of RelaxFuncParamsFinder + * \param ref_module the reference module. + */ + explicit RelaxFuncParamsFinder(const IRModule& ref_module) : RelaxExprVisitor() { + ref_module_ = ref_module; + } + + /*! \brief Find the func params and bind them with arguments*/ + Map FindParams(const Expr& expr) { + RelaxExprVisitor::VisitExpr(expr); + return params_; + } + + void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + + void VisitExpr_(const relax::CallNode* op) final; + + private: + IRModule ref_module_; + Map params_; + Map local_funcs_; +}; + class RelaxGraphBuilder : public RelaxExprVisitor { public: /*! @@ -158,13 +194,16 @@ class RelaxGraphBuilder : public RelaxExprVisitor { explicit RelaxGraphBuilder(const IRModule& ref_module, const String& name, const std::string& options = "") : RelaxExprVisitor() { - name_ = name; ref_module_ = ref_module; if (options.size() > 0) { std::istringstream is(options); dmlc::JSONReader reader(&is); reader.Read(&config_); } + name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; + if (name != "main") { + func_params_ = RelaxFuncParamsFinder(ref_module).FindParams(ref_module->Lookup("main")); + } } /*! \brief Build MSCGraph from relax function*/ @@ -193,6 +232,8 @@ class RelaxGraphBuilder : public RelaxExprVisitor { void VisitBinding_(const relax::VarBindingNode* binding, const relax::DataflowVarNode* val) final; + void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + private: String name_; String scope_name_; @@ -202,6 +243,10 @@ class RelaxGraphBuilder : public RelaxExprVisitor { Map weights_; Map> expr_tensor_map_; std::unordered_map> tensor_input_map_; + std::set ignore_nodes_; + // BYOC maps + Map target_funcs_; + Map func_params_; }; class RelaxWeightsExtractor : public RelaxExprVisitor { @@ -259,7 +304,6 @@ class RelayGraphBuilder : public RelayExprVisitor { explicit RelayGraphBuilder(const IRModule& ref_module, const String& name, const std::string& options = "") : RelayExprVisitor() { - name_ = name; ref_module_ = ref_module; if (options.size() > 0) { std::istringstream is(options); @@ -269,6 +313,7 @@ class RelayGraphBuilder : public RelayExprVisitor { while (!func_scopes_.empty()) { func_scopes_.pop(); } + name_ = config_.graph_name.size() > 0 ? String(config_.graph_name) : name; } /*!
\brief Build MSCGraph from relax function*/ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 9f3c68d77e85..ba434976ffa1 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -76,6 +76,12 @@ void MSCBasePrinter::PrintDoc(const Doc& doc, bool new_line) { PrintTypedDoc(doc_node.value()); } else if (auto doc_node = doc.as()) { PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; @@ -104,21 +110,6 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { void MSCBasePrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } -void MSCBasePrinter::PrintTypedDoc(const IndexDoc& doc) { - PrintDoc(doc->value); - if (doc->indices.size() == 0) { - output_ << "[()]"; - } else { - for (size_t i = 0; i < doc->indices.size(); i++) { - if (i == 0) { - output_ << "["; - } - PrintDoc(doc); - output_ << (i == doc->indices.size() - 1 ? "]" : ", "); - } - } -} - void MSCBasePrinter::PrintTypedDoc(const ListDoc& doc) { output_ << "["; PrintJoinedDocs(doc->elements); diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h index 4900fba363e8..17cb218bfcba 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ b/src/contrib/msc/core/printer/msc_base_printer.h @@ -30,6 +30,7 @@ #include #include "../../../../../src/support/str_escape.h" +#include "msc_doc.h" namespace tvm { namespace contrib { @@ -41,7 +42,6 @@ using namespace tvm::script::printer; * \brief MSCPrinterConfig is base for config class in MSC * \sa Doc */ - struct MSCPrinterConfig { size_t indent{0}; size_t float_precision{6}; @@ -109,9 +109,6 @@ class MSCBasePrinter { /*! \brief Virtual method to print an IdDoc*/ virtual void PrintTypedDoc(const IdDoc& doc); - /*! \brief Virtual method to print an IndexDoc*/ - virtual void PrintTypedDoc(const IndexDoc& doc); - /*! \brief Virtual method to print a ListDoc*/ virtual void PrintTypedDoc(const ListDoc& doc); @@ -127,6 +124,9 @@ class MSCBasePrinter { /*! \brief Virtual method to print a ExprStmtDoc*/ virtual void PrintTypedDoc(const ExprStmtDoc& doc); + /*! \brief Virtual method to print an IndexDoc*/ + virtual void PrintTypedDoc(const IndexDoc& doc) { LOG(FATAL) << "Index is not implemented"; } + /*! \brief Virtual method to print a CallDoc*/ virtual void PrintTypedDoc(const CallDoc& doc) { LOG(FATAL) << "Call is not implemented"; } @@ -170,6 +170,19 @@ class MSCBasePrinter { /*! \brief Virtual method to print a CommentDoc*/ virtual void PrintTypedDoc(const CommentDoc& doc) { LOG(FATAL) << "Comment is not implemented"; } + /*! \brief Virtual method to print a DeclareDoc*/ + virtual void PrintTypedDoc(const DeclareDoc& doc) { LOG(FATAL) << "Declare is not implemented"; } + + /*! \brief Virtual method to print a StrictListDoc*/ + virtual void PrintTypedDoc(const StrictListDoc& doc) { + LOG(FATAL) << "StrictList is not implemented"; + } + + /*! \brief Virtual method to print a PointerDoc*/ + virtual void PrintTypedDoc(const PointerDoc& doc) { + LOG(FATAL) << "PointerDoc is not implemented"; + } + /*! 
\brief Print docs to joined doc */ template void PrintJoinedDocs(const Array& docs, const String& separator = ", ") { diff --git a/src/contrib/msc/core/printer/msc_doc.cc new file mode 100644 index 000000000000..d5c94675266e --- /dev/null +++ b/src/contrib/msc/core/printer/msc_doc.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/printer/msc_doc.cc + */ + +#include "msc_doc.h" + +#include + +namespace tvm { +namespace contrib { +namespace msc { + +DeclareDoc::DeclareDoc(Optional type, ExprDoc variable, Array init_args, + bool use_constructor) { + ObjectPtr n = make_object(); + n->type = type; + n->variable = variable; + n->init_args = init_args; + n->use_constructor = use_constructor; + this->data_ = std::move(n); +} + +StrictListDoc::StrictListDoc(ListDoc list, bool allow_empty) { + ObjectPtr n = make_object(); + n->list = list; + n->allow_empty = allow_empty; + this->data_ = std::move(n); +} + +PointerDoc::PointerDoc(String name) { + ObjectPtr n = make_object(); + n->name = name; + this->data_ = std::move(n); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/msc/core/printer/msc_doc.h new file mode 100644 index 000000000000..440d2079ffca --- /dev/null +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/printer/msc_doc.h + * \brief Extra docs for MSC + */ +#ifndef TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ +#define TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ + +#include + +#include + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace tvm::script::printer; + +/*! + * \brief Doc that declares a var with type. + * + * \sa DeclareDoc + */ +class DeclareDocNode : public ExprDocNode { + public: + /*! \brief The type of the variable */ + Optional type; + /*!
\brief The variable */ + ExprDoc variable{nullptr}; + /*! \brief The init arguments for the variable. */ + Array init_args; + /*! \brief Whether to use constructor (otherwise initializer) */ + bool use_constructor{true}; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("type", &type); + v->Visit("variable", &variable); + v->Visit("init_args", &init_args); + v->Visit("use_constructor", &use_constructor); + } + + static constexpr const char* _type_key = "script.printer.DeclareDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(DeclareDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of DeclareDocNode. + * + * \sa DeclareDocNode + */ +class DeclareDoc : public ExprDoc { + public: + /*! + * \brief Constructor of DeclareDoc. + * \param type The type of the variable. + * \param variable The variable. + * \param init_args The init arguments of the variable. + * \param use_constructor Whether to use constructor (otherwise initializer). + */ + explicit DeclareDoc(Optional type, ExprDoc variable, Array init_args, + bool use_constructor); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DeclareDoc, ExprDoc, DeclareDocNode); +}; + +/*! + * \brief Doc that builds a strict list, which checks for emptiness. + * + * \sa StrictListDoc + */ +class StrictListDocNode : public ExprDocNode { + public: + /*! \brief The inner list doc */ + ListDoc list; + /*! \brief Whether to allow empty */ + bool allow_empty{true}; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("list", &list); + v->Visit("allow_empty", &allow_empty); + } + + static constexpr const char* _type_key = "script.printer.StrictListDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(StrictListDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of StrictListDocNode. + * + * \sa StrictListDocNode + */ +class StrictListDoc : public ExprDoc { + public: + /*! + * \brief Constructor of StrictListDoc. + * \param list The inner list doc. + * \param allow_empty Whether to allow empty. + */ + explicit StrictListDoc(ListDoc list, bool allow_empty); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StrictListDoc, ExprDoc, StrictListDocNode); +}; + +/*! + * \brief Doc that represents a pointer. + * + * \sa PointerDoc + */ +class PointerDocNode : public ExprDocNode { + public: + /*! \brief The name of the identifier */ + String name; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("name", &name); + } + + static constexpr const char* _type_key = "script.printer.PointerDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PointerDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of PointerDocNode. + * + * \sa PointerDocNode + */ +class PointerDoc : public ExprDoc { + public: + /*! + * \brief Constructor of PointerDoc. + * \param name The name of the identifier.
+ */ + explicit PointerDoc(String name); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PointerDoc, ExprDoc, PointerDocNode); +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 6d7f1f368caa..6086e3ffa3c0 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -50,8 +50,32 @@ const ExprDoc DocUtils::ToDoc(const String& val) { return IdDoc(val); } const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, NullOpt); } +const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } + const ExprDoc DocUtils::ToStrDoc(const String& val) { return LiteralDoc::Str(val, NullOpt); } +const PointerDoc DocUtils::ToPtrDoc(const String& val) { return PointerDoc(val); } + +const DeclareDoc DocUtils::ToDeclareDoc(const String& type, const String& variable, size_t len, + bool use_constructor) { + Optional type_doc; + if (type.size() == 0) { + type_doc = NullOpt; + } else { + type_doc = IdDoc(type); + } + if (len == 0) { + return DeclareDoc(type_doc, IdDoc(variable), Array(), use_constructor); + } + Array doc_indices{DocUtils::ToDoc(len)}; + return DeclareDoc(type_doc, IndexDoc(IdDoc(variable), doc_indices), Array(), + use_constructor); +} + +const AttrAccessDoc DocUtils::ToAttrAccessDoc(const String& value, const String& name) { + return AttrAccessDoc(IdDoc(value), name); +} + const Array DocUtils::ToStmts(const Array& docs) { Array stmts; for (const auto& d : docs) { diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h index def2c5560550..acae1dbd4d7f 100644 --- a/src/contrib/msc/core/printer/print_utils.h +++ b/src/contrib/msc/core/printer/print_utils.h @@ -28,6 +28,8 @@ #include +#include "msc_doc.h" + namespace tvm { namespace contrib { namespace msc { @@ -54,7 +56,23 @@ class DocUtils { TVM_DLL static const ExprDoc ToDoc(const char* val); TVM_DLL static const ExprDoc ToDoc(const String& val); TVM_DLL static const ExprDoc ToDoc(bool val); + TVM_DLL static const ExprDoc ToDoc(const ExprDoc& val); TVM_DLL static const ExprDoc ToStrDoc(const String& val); + TVM_DLL static const PointerDoc ToPtrDoc(const String& val); + + /*! + * \brief Change object to DeclareDoc. + * \return The DeclareDoc. + */ + TVM_DLL static const DeclareDoc ToDeclareDoc(const String& type, const String& variable, + size_t len = 0, bool use_constructor = true); + + /*! + * \brief Change object to AttrAccessDoc. + * \return The AttrAccessDoc. + */ + TVM_DLL static const AttrAccessDoc ToAttrAccessDoc(const String& value, const String& name); + /*! * \brief Change object to List of Docs. * \return The List of Docs. @@ -81,12 +99,53 @@ class DocUtils { * \return The ListDoc. */ template - TVM_DLL static const ListDoc ToListDoc(const std::vector& values) { - return ListDoc(ToDocList(values)); + TVM_DLL static const StrictListDoc ToListDoc(const std::vector& values, + bool allow_empty = false) { + if (values.size() > 0 || allow_empty) { + return StrictListDoc(ListDoc(ToDocList(values)), allow_empty); + } + return StrictListDoc(ListDoc(), false); + } + template + TVM_DLL static const StrictListDoc ToListDoc(const Array& values, bool allow_empty = false) { + if (values.size() > 0 || allow_empty) { + return StrictListDoc(ListDoc(ToDocList(values)), allow_empty); + } + return StrictListDoc(ListDoc(), false); + } + + /*! 
+ * \brief Change object to IndexDoc. + * \return The IndexDoc. + */ + template + TVM_DLL static const IndexDoc ToIndexDoc(const String& value, const std::vector& indices) { + Array doc_indices; + for (const auto& i : indices) { + doc_indices.push_back(ToDoc(i)); + } + return IndexDoc(IdDoc(value), doc_indices); } template - TVM_DLL static const ListDoc ToListDoc(const Array& values) { - return ListDoc(ToDocList(values)); + TVM_DLL static const IndexDoc ToIndexDoc(const String& value, const Array& indices) { + Array doc_indices; + for (const auto& i : indices) { + doc_indices.push_back(ToDoc(i)); + } + return IndexDoc(IdDoc(value), doc_indices); + } + + /*! + * \brief Change object to AssignDoc. + * \return The AssignDoc. + */ + template + TVM_DLL static const AssignDoc ToAssignDoc(const String& lhs, const T& rhs, + const String& annotation = "") { + if (annotation.size() == 0) { + return AssignDoc(IdDoc(lhs), ToDoc(rhs), NullOpt); + } + return AssignDoc(IdDoc(lhs), ToDoc(rhs), IdDoc(annotation)); } /*! diff --git a/src/contrib/msc/core/printer/python_printer.cc index cf3a4cf7e25a..db198aaa569b 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -174,6 +174,14 @@ void PythonPrinter::PrintTypedDoc(const CommentDoc& doc) { } } +void PythonPrinter::PrintTypedDoc(const StrictListDoc& doc) { + if (doc->allow_empty || doc->list->elements.size() > 0) { + PrintDoc(doc->list, false); + } else { + output_ << "None"; + } +} + void PythonPrinter::PrintIndentedBlock(const Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { diff --git a/src/contrib/msc/core/printer/python_printer.h index e281a118cfbe..ac053ee276fb 100644 --- a/src/contrib/msc/core/printer/python_printer.h +++ b/src/contrib/msc/core/printer/python_printer.h @@ -19,7 +19,7 @@ /*! * \file src/contrib/msc/core/printer/python_printer.h - * \brief Prototxt Printer. + * \brief Python Printer. */ #ifndef TVM_CONTRIB_MSC_CORE_PRINTER_PYTHON_PRINTER_H_ @@ -36,7 +36,7 @@ namespace msc { using namespace tvm::script::printer; /*! - * \brief PythonPrinter change list of dict to python format + * \brief PythonPrinter change list of docs to python format * \sa Doc */ class PythonPrinter : public MSCBasePrinter { @@ -78,6 +78,9 @@ class PythonPrinter : public MSCBasePrinter { /*! * \brief Print a CommentDoc to python format*/ void PrintTypedDoc(const CommentDoc& doc) final; + /*! \brief Virtual method to print a StrictListDoc*/ + void PrintTypedDoc(const StrictListDoc& doc) final; + private: /*!
\brief Print block with indent*/ void PrintIndentedBlock(const Array& docs); diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index 3c70e1871b7e..78dc9abeb1a7 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -49,7 +49,7 @@ bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { } expr->span = SpanUtils::SetAttr(expr->span, "layout", l_layout.name()); } else if (sinfo.as()) { - ICHECK(!layout.IsLeaf()) << "Expr has tupple struct, but find non-nested layout " << expr; + ICHECK(!layout.IsLeaf()) << "Expr has tuple struct, but find non-nested layout " << expr; String layout_str; Array nested_layouts = layout.NestedArray(); for (size_t i = 0; i < nested_layouts.size(); i++) { diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index a94c846b3f18..fac1bbf89f3a 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -252,6 +252,9 @@ InferLayoutOutput ForwardInferLayoutBinary(const Call& call, if (const auto* t_info = sinfo.as()) { if (t_info->ndim == 0) { input_layouts.push_back(LayoutDecision("")); + } else if (t_info->ndim == 1) { + const auto& ref_layout = output->output_layouts[0].LeafValue()->layout; + input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name())); } else { input_layouts.push_back(output->input_layouts[i]); } @@ -268,6 +271,32 @@ InferLayoutOutput ForwardInferLayoutInplace(const Call& call, return ForwardInferLayoutCommon(call, desired_layouts, var_layout_map); } +InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + if (!input_layout->layout.defined()) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + if (attrs->keepdims) { + return InferLayoutOutput({input_layout}, {input_layout}, Attrs()); + } + if (!attrs->axis.defined()) { + return InferLayoutOutput({input_layout}, {LayoutDecision("")}, Attrs()); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + std::vector axes; + axes.push_back(CommonUtils::GetIndex(Downcast(attrs->axis)->value, input_shape.size())); + LayoutDecision output_layout = LayoutUtils::ReduceLayout(input_layout, axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -504,9 +533,9 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .set_attr("FMSCForwardInferLayout", MSCInferLayoutPool2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutArgMaxMin); TVM_REGISTER_OP("relax.argmin") - .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); + .set_attr("FMSCForwardInferLayout", ForwardInferLayoutArgMaxMin); TVM_REGISTER_OP("relax.max") .set_attr("FMSCForwardInferLayout", ForwardInferLayoutReduceAxis); TVM_REGISTER_OP("relax.min") @@ -624,6 +653,9 @@ InferLayoutOutput BackwardInferLayoutBinary(const Call& call, if (const auto* t_info = sinfo.as()) { if 
(t_info->ndim == 0) { input_layouts.push_back(LayoutDecision("")); + } else if (t_info->ndim == 1) { + const auto& ref_layout = output->output_layouts[0].LeafValue()->layout; + input_layouts.push_back(LayoutDecision(ref_layout[ref_layout.ndim() - 1].name())); } else { input_layouts.push_back(output->input_layouts[i]); } @@ -640,6 +672,29 @@ InferLayoutOutput BackwardInferLayoutInplace(const Call& call, return BackwardInferLayoutCommon(call, desired_layouts, var_layout_map); } +InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + if (!output_layout->layout.defined()) { + return InferLayoutOutput(); + } + const auto* attrs = call->attrs.as(); + if (attrs->keepdims) { + return InferLayoutOutput({output_layout}, {output_layout}, Attrs()); + } + Array empty; + const auto& input_shape = + Downcast(GetStructInfo(call->args[0]))->GetShape().value_or(empty); + if (input_shape.size() == 0) { + return InferLayoutOutput(); + } + std::vector axes; + axes.push_back(CommonUtils::GetIndex(Downcast(attrs->axis)->value, input_shape.size())); + LayoutDecision input_layout = LayoutUtils::ExpandLayout(output_layout, axes); + return InferLayoutOutput({input_layout}, {output_layout}, Attrs()); +} + InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { @@ -853,9 +908,9 @@ TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d") .set_attr("FMSCBackwardInferLayout", MSCInferLayoutPool2d); // reduce axis ops TVM_REGISTER_OP("relax.argmax") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutArgMaxMin); TVM_REGISTER_OP("relax.argmin") - .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); + .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutArgMaxMin); TVM_REGISTER_OP("relax.max") .set_attr("FMSCBackwardInferLayout", BackwardInferLayoutReduceAxis); TVM_REGISTER_OP("relax.min") @@ -963,17 +1018,24 @@ class LayoutInfer : public ExprVisitor { if (expr->IsInstance()) { continue; } + if (expr->IsInstance()) { + continue; + } if (!expr->IsInstance()) { continue; } const Call& call = Downcast(expr); + if (const auto* v_node = call->op.as()) { + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + BackwardInferFunc(func, call); + continue; + } else if (call->op->IsInstance() && local_funcs_.count(call->op)) { + BackwardInferFunc(local_funcs_[call->op], call); + continue; + } size_t infered_num = 0; for (const auto& arg : call->args) { - if (arg->IsInstance() && var_map_.count(Downcast(arg))) { - if (LayoutUtils::LayoutInfered(var_map_[Downcast(arg)]) > 0) { - infered_num++; - } - } else if (LayoutUtils::LayoutInfered(arg)) { + if (IsArgInfered(arg)) { infered_num++; } } @@ -988,7 +1050,7 @@ class LayoutInfer : public ExprVisitor { // Infer by op_node Op op = Downcast(GetRef(op_node)); InferLayoutOutput infered_layout; - const auto msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); + const auto& msc_infer_map = Op::GetAttrMap("FMSCBackwardInferLayout"); try { if (msc_infer_map.count(op)) { FRelaxInferLayout f = msc_infer_map[op]; @@ -1003,7 +1065,7 @@ class LayoutInfer : public ExprVisitor { } try { if (infered_layout.defined()) { - SetInputLayouts(infered_layout->input_layouts, call); + SetInputLayouts(call, infered_layout->input_layouts); } } catch (runtime::InternalError& 
err) { LOG(WARNING) << "Failed to backward set inputs layout for " << call << " : " @@ -1012,47 +1074,16 @@ class LayoutInfer : public ExprVisitor { } } - void SetInputLayouts(const Array& input_layouts, const Call& call) { - if (input_layouts.size() == call->args.size()) { - for (size_t i = 0; i < input_layouts.size(); i++) { - if (call->args[i]->IsInstance()) { - const auto& var = Downcast(call->args[i]); - var_layout_map_[var] = input_layouts[i]; - if (var_map_.count(var)) { - if (LayoutUtils::SetLayout(var_map_[var], input_layouts[i])) { - infered_ = true; - } - } else if (LayoutUtils::SetLayout(var, input_layouts[i])) { - infered_ = true; - } - } else if (LayoutUtils::SetLayout(call->args[i], input_layouts[i])) { - infered_ = true; - } - } - } - } - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { ExprVisitor::VisitBinding_(binding, call_node); const auto& call = GetRef(call_node); if (const auto* v_node = call->op.as()) { - // infer global func and set var layouts const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - Infer(func); - for (size_t i = 0; i < func->params.size(); i++) { - if (var_layout_map_.count(func->params[i]) && - LayoutUtils::SetLayout(call->args[i], var_layout_map_[func->params[i]])) { - infered_ = true; - } - } - if (const auto* b_node = func->body.as()) { - var_layout_map_[binding->var] = GetNLayout(var_layout_map_, b_node->body); - if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) { - infered_ = true; - } - } else { - LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; - } + RecordExpr(binding->var, call); + ForwardInferFunc(func, call, binding->var); + } else if (call->op->IsInstance() && local_funcs_.count(call->op)) { + RecordExpr(binding->var, call); + ForwardInferFunc(local_funcs_[call->op], call, binding->var); } else { // infer call bool infer_outputs = true; @@ -1072,8 +1103,8 @@ class LayoutInfer : public ExprVisitor { // infer layouts Op op = Downcast(GetRef(op_node)); InferLayoutOutput infered_layout; - const auto msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); - const auto relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); + const auto& msc_infer_map = Op::GetAttrMap("FMSCForwardInferLayout"); + const auto& relax_infer_map = Op::GetAttrMap("FRelaxInferLayout"); bool set_inputs = true; try { if (msc_infer_map.count(op)) { @@ -1095,10 +1126,7 @@ class LayoutInfer : public ExprVisitor { } if (infered_layout.defined() && infered_layout->output_layouts.size() == 1) { try { - var_layout_map_[binding->var] = infered_layout->output_layouts[0]; - if (LayoutUtils::SetLayout(call, var_layout_map_[binding->var])) { - infered_ = true; - } + SetExprLayout(binding->var, infered_layout->output_layouts[0]); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to forward set output layout for " << binding->var << " : " << binding->value << ", reason: " << err.message(); @@ -1106,7 +1134,7 @@ class LayoutInfer : public ExprVisitor { } if (set_inputs && infered_layout.defined()) { try { - SetInputLayouts(infered_layout->input_layouts, call); + SetInputLayouts(call, infered_layout->input_layouts); } catch (runtime::InternalError& err) { LOG(WARNING) << "Failed to forward set inputs layout for " << call << " : " << err.message(); @@ -1116,49 +1144,133 @@ class LayoutInfer : public ExprVisitor { } } + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { + local_funcs_.Set(binding->var, GetRef(val)); + } + void VisitBinding_(const 
VarBindingNode* binding, const TupleNode* val) final { ExprVisitor::VisitBinding_(binding, val); + RecordExpr(binding->var, GetRef(val)); if (IsNestedTensor(binding->var)) { Array input_layouts; for (const auto& field : val->fields) { input_layouts.push_back(InferLayoutDecision(field, var_layout_map_)); } - var_layout_map_[binding->var] = input_layouts; - if (LayoutUtils::SetLayout(GetRef(val), NLayout(input_layouts))) { - infered_ = true; - } + SetExprLayout(binding->var, input_layouts); } - RecordExpr(binding->var, GetRef(val)); } void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { ExprVisitor::VisitBinding_(binding, val); + RecordExpr(binding->var, GetRef(val)); const auto& out_layout = InferLayoutDecisionAt(GetRef(val)->tuple, var_layout_map_, val->index); - var_layout_map_[binding->var] = out_layout; - if (LayoutUtils::SetLayout(GetRef(val), out_layout)) { - infered_ = true; - } - RecordExpr(binding->var, GetRef(val)); + SetExprLayout(binding->var, out_layout); } void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) final { ExprVisitor::VisitBinding_(binding, val); - const NLayout& out_layout = LayoutDecision("O"); - var_layout_map_[binding->var] = out_layout; - if (LayoutUtils::SetLayout(GetRef(val), out_layout)) { - infered_ = true; - } + RecordExpr(binding->var, GetRef(val)); + SetExprLayout(binding->var, LayoutDecision("O")); } bool infered() { return infered_; } private: + bool IsArgInfered(const Expr& arg) { + if (arg->IsInstance() && var_map_.count(Downcast(arg))) { + if (LayoutUtils::LayoutInfered(var_map_[Downcast(arg)]) > 0) { + return true; + } + } else if (const auto* tuple_node = arg.as()) { + for (const auto& field : tuple_node->fields) { + if (!IsArgInfered(field)) { + return false; + } + } + return true; + } else if (LayoutUtils::LayoutInfered(arg)) { + return true; + } + return false; + } + + void SetExprLayout(const Expr& expr, const NLayout& layout) { + if (expr->IsInstance()) { + const auto& var = Downcast(expr); + var_layout_map_[var] = layout; + if (LayoutUtils::SetLayout(var, layout)) { + infered_ = true; + } + if (var_map_.count(var) && LayoutUtils::SetLayout(var_map_[var], layout)) { + infered_ = true; + } + } else if (LayoutUtils::SetLayout(expr, layout)) { + infered_ = true; + } + } + + void SetInputLayouts(const Call& call, const Array& input_layouts) { + if (input_layouts.size() == call->args.size()) { + for (size_t i = 0; i < input_layouts.size(); i++) { + SetExprLayout(call->args[i], input_layouts[i]); + } + } + } + + void ForwardInferFunc(const Function& func, const Call& call, const Var& ret) { + for (size_t i = 0; i < call->args.size(); i++) { + if (call->args[i]->IsInstance() && + var_layout_map_.count(Downcast(call->args[i]))) { + SetExprLayout(func->params[i], var_layout_map_[Downcast(call->args[i])]); + } + } + ForwardInfer(func); + for (size_t i = 0; i < func->params.size(); i++) { + if (var_layout_map_.count(func->params[i])) { + SetExprLayout(call->args[i], var_layout_map_[func->params[i]]); + } + } + if (const auto* b_node = func->body.as()) { + if (b_node->body->IsInstance() && + var_layout_map_.count(Downcast(b_node->body))) { + SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); + } + } else { + LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; + } + } + + void BackwardInferFunc(const Function& func, const Call& call) { + for (size_t i = 0; i < func->params.size(); i++) { + if (var_layout_map_.count(func->params[i])) { + const auto& param_layout = 
var_layout_map_[func->params[i]];
+        SetExprLayout(call->args[i], param_layout);
+        if (call->args[i]->IsInstance() && var_map_.count(Downcast(call->args[i]))) {
+          const auto& producer = var_map_[Downcast(call->args[i])];
+          if (producer->IsInstance() &&
+              local_funcs_.count(Downcast(producer)->op)) {
+            const auto& caller = local_funcs_[Downcast(producer)->op];
+            if (const auto* b_node = caller->body.as()) {
+              if (b_node->body->IsInstance() &&
+                  var_map_.count(Downcast(b_node->body))) {
+                SetExprLayout(b_node->body, param_layout);
+              }
+            } else {
+              LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body;
+            }
+          }
+        }
+      }
+    }
+  }
+
   IRModule ref_module_;
   bool infered_;
   Map var_map_;
   Array ordered_exprs_;
   std::unordered_map var_layout_map_;
+  Map local_funcs_;
 };  // class LayoutInfer

 class LayoutChecker : public ExprVisitor {
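[Editor's note: ForwardInferFunc and BackwardInferFunc above thread layouts across call boundaries: argument layouts seed the callee's parameters, the body is inferred, and any layouts decided inside flow back to the call site and the bound return var. A minimal standalone sketch of the forward round trip, using plain std::map stand-ins rather than the TVM data structures (all names here are illustrative):]

#include <iostream>
#include <map>
#include <string>
#include <vector>

using LayoutMap = std::map<std::string, std::string>;

struct Func {
  std::vector<std::string> params;  // parameter names
  std::string body_out;             // name of the value the body returns
};

// Seed param layouts from the call args, "infer" the body, then copy any
// newly decided layouts back onto the args and the return value.
void ForwardInferFunc(const Func& func, const std::vector<std::string>& args,
                      const std::string& ret, LayoutMap* layouts) {
  for (size_t i = 0; i < args.size(); i++) {
    if (layouts->count(args[i])) (*layouts)[func.params[i]] = (*layouts)[args[i]];
  }
  // ... body inference would run here; assume it decides the output layout ...
  (*layouts)[func.body_out] = "NCHW";
  for (size_t i = 0; i < func.params.size(); i++) {
    if (layouts->count(func.params[i])) (*layouts)[args[i]] = (*layouts)[func.params[i]];
  }
  if (layouts->count(func.body_out)) (*layouts)[ret] = (*layouts)[func.body_out];
}

int main() {
  Func f{{"p0"}, "out"};
  LayoutMap layouts{{"x", "NCHW"}};
  ForwardInferFunc(f, {"x"}, "y", &layouts);
  std::cout << layouts["y"] << "\n";  // NCHW
}

diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc
index 8a5d8fec5d42..68213d02b724 100644
--- a/src/contrib/msc/core/transform/set_expr_name.cc
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -41,7 +41,8 @@ namespace relax {
  */
 class RelaxExprNameSetter : public ExprVisitor {
  public:
-  explicit RelaxExprNameSetter(const IRModule& ref_module) : ref_module_(ref_module) {}
+  explicit RelaxExprNameSetter(const IRModule& ref_module, const String& target)
+      : ref_module_(ref_module), target_{target} {}

   void VisitBindingBlock(const BindingBlock& block) final {
     String block_name = SpanUtils::GetAttr(block->span, "name");
@@ -109,6 +110,14 @@ class RelaxExprNameSetter : public ExprVisitor {
     expr_names_.Set(binding->var, unique_name);
   }

+  void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) {
+    ExprVisitor::VisitBinding_(binding, val);
+    const auto& name_opt = val->GetAttr(attr::kComposite);
+    if (name_opt.defined()) {
+      local_funcs_.Set(binding->var, GetRef(val));
+    }
+  }
+
   void VisitBinding_(const VarBindingNode* binding, const CallNode* val) {
     ExprVisitor::VisitBinding_(binding, val);
     String name_hint, optype;
@@ -120,38 +129,52 @@ class RelaxExprNameSetter : public ExprVisitor {
     } else if (const auto* v_node = val->op.as()) {
       const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint));
       ExprVisitor::VisitExpr(func);
-      const auto& name_opt = func->GetAttr(attr::kComposite);
-      ICHECK(name_opt.defined()) << "Unexpected global func without composite";
-      name_hint = name_opt.value();
-      optype = name_hint;
-    }
-    // set name
-    const String& unique_name = GetUniqueName(GetRef(val), name_hint);
-    if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
-      val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
-    }
-    // set constant consumer && shared_ref
-    Array input_types;
-    try {
-      input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
-    } catch (runtime::InternalError& err) {
-      LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(val) << " : " << err.message();
-      throw err;
+      optype = GetFuncType(func);
+      if (optype == "extern_func") {
+        name_hint = v_node->name_hint;
+      } else {
+        name_hint = optype;
+      }
+    } else if (local_funcs_.count(val->op)) {
+      optype = GetFuncType(local_funcs_[val->op]);
+      ExprVisitor::VisitExpr(local_funcs_[val->op]);
+      if (optype == "extern_func") {
+        name_hint = Downcast(val->op)->name_hint();
+      } else {
+        name_hint = optype;
+      }
     }
-    for (size_t i = 0; i < input_types.size(); i++) {
-      if (input_types[i] == "input") {
-        continue;
+    if (name_hint.size() > 0) {
+      // set name
+      const String& unique_name =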
GetUniqueName(GetRef(val), name_hint); + if (unique_name != SpanUtils::GetAttr(val->span, "name")) { + val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + } + // set constant consumer && shared_ref + Array input_types; + try { + input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true); + } catch (runtime::InternalError& err) { + LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(val) << " : " + << err.message(); + throw err; } - if (const auto* c_node = val->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); - if (constant_consumers_.count(const_name)) { - val->span = SpanUtils::SetAttr(val->span, "shared_ref", constant_consumers_[const_name]); - } else { - constant_consumers_.Set(const_name, unique_name); + for (size_t i = 0; i < input_types.size(); i++) { + if (input_types[i] == "input") { + continue; + } + if (const auto* c_node = val->args[i].as()) { + const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); + if (constant_consumers_.count(const_name)) { + val->span = + SpanUtils::SetAttr(val->span, "shared_ref", constant_consumers_[const_name]); + } else { + constant_consumers_.Set(const_name, unique_name); + } } } + expr_names_.Set(binding->var, unique_name); } - expr_names_.Set(binding->var, unique_name); } private: @@ -179,24 +202,40 @@ class RelaxExprNameSetter : public ExprVisitor { return expr_name; } + const String GetFuncType(const Function& func) { + String optype; + const auto& name_opt = func->GetAttr(attr::kComposite); + if (name_opt.defined()) { + optype = name_opt.value(); + if (target_.size() > 0) { + optype = StringUtils::Replace(optype, target_ + ".", ""); + } + } else { + optype = "extern_func"; + } + return optype; + } + Map setted_names_; Map constant_consumers_; std::set setted_blocks_; Array block_stack_; Map expr_names_; + Map local_funcs_; IRModule ref_module_; + String target_; }; // class ExprNameSetter -void SetRelaxExprName(const IRModule& ref_module, const Expr& e) { - RelaxExprNameSetter(ref_module).VisitExpr(e); +void SetRelaxExprName(const IRModule& ref_module, const Expr& e, const String& target) { + RelaxExprNameSetter(ref_module, target).VisitExpr(e); } namespace transform { -Pass SetRelaxExprName(const String& entry_name) { +Pass SetRelaxExprName(const String& entry_name, const String& target) { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { - relax::SetRelaxExprName(m, m->Lookup(entry_name)); + relax::SetRelaxExprName(m, m->Lookup(entry_name), target); return m; }; return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); @@ -261,35 +300,42 @@ class RelayExprNameSetter : public ExprVisitor { optype = StringUtils::Replace(op_node->name, "relay.", ""); } else if (const auto* v_node = op->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - ExprVisitor::VisitExpr(func); const auto& name_opt = func->GetAttr(attr::kComposite); - ICHECK(name_opt.defined()) << "Unexpected global func without composite"; - optype = name_opt.value(); - name_hint = optype; - } - // set name - const String& unique_name = GetUniqueName(GetRef(op), name_hint); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); - } - // set constant consumer && shared_ref - Array input_types; - try { - input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false); - } catch (runtime::InternalError& err) { - LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(op) << " : " 
<< err.message();
-      throw err;
+      if (name_opt.defined()) {
+        optype = name_opt.value();
+        name_hint = optype;
+        ExprVisitor::VisitExpr(func);
+      } else {
+        optype = "extern_func";
+        name_hint = v_node->name_hint;
+      }
     }
-    for (size_t i = 0; i < input_types.size(); i++) {
-      if (input_types[i] == "input") {
-        continue;
+    if (name_hint.size() > 0) {
+      // set name
+      const String& unique_name = GetUniqueName(GetRef(op), name_hint);
+      if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
+        op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
       }
-      if (const auto* c_node = op->args[i].as()) {
-        const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
-        if (constant_consumers_.count(const_name)) {
-          op->span = SpanUtils::SetAttr(op->span, "shared_ref", constant_consumers_[const_name]);
-        } else {
-          constant_consumers_.Set(const_name, unique_name);
+      // set constant consumer && shared_ref
+      Array input_types;
+      try {
+        input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false);
+      } catch (runtime::InternalError& err) {
+        LOG(WARNING) << "Failed to GetInputTypes for " << GetRef(op) << " : "
+                     << err.message();
+        throw err;
+      }
+      for (size_t i = 0; i < input_types.size(); i++) {
+        if (input_types[i] == "input") {
+          continue;
+        }
+        if (const auto* c_node = op->args[i].as()) {
+          const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
+          if (constant_consumers_.count(const_name)) {
+            op->span = SpanUtils::SetAttr(op->span, "shared_ref", constant_consumers_[const_name]);
+          } else {
+            constant_consumers_.Set(const_name, unique_name);
+          }
         }
       }
     }
@@ -329,6 +375,69 @@ void SetRelayExprName(const IRModule& ref_module, const Expr& e) {
   RelayExprNameSetter(ref_module).VisitExpr(e);
 }

+/*!
+ * \brief Name binder for Relay
+ */
+class RelayExprNameBinder : public ExprVisitor {
+ public:
+  explicit RelayExprNameBinder(const String& name_key, const String& seperator)
+      : name_key_(name_key), seperator_(seperator) {}
+
+  void VisitExpr_(const ConstantNode* op) final {
+    if (op->span.defined()) {
+      BindName(GetRef(op));
+    }
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->span.defined()) {
+      BindName(GetRef(op));
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+ private:
+  void BindName(const Expr& expr) {
+    const auto& name = expr->span->source_name->name;
+    String valid_name;
+    if (name_key_.size() == 0) {
+      valid_name = name;
+      expr->span = Span(SourceName::Get(""), expr->span->line, expr->span->end_line,
+                        expr->span->column, expr->span->end_column);
+    } else {
+      String right = std::get<1>(StringUtils::SplitOnce(name, name_key_));
+      if (right.size() > 0) {
+        valid_name = std::get<0>(StringUtils::SplitOnce(name, seperator_));
+        if (valid_name.size() > 0) {
+          const auto& new_source = StringUtils::Replace(name, name_key_ + valid_name, "");
+          expr->span = Span(SourceName::Get(new_source), expr->span->line, expr->span->end_line,
+                            expr->span->column, expr->span->end_column);
+        }
+      }
+    }
+    if (valid_name.size() > 0) {
+      if (setted_names_.count(valid_name)) {
+        int cnt = 1;
+        while (setted_names_.count(valid_name + "_" + std::to_string(cnt)) &&
+               setted_names_[valid_name + "_" + std::to_string(cnt)] != expr) {
+          cnt++;
+        }
+        valid_name = valid_name + "_" + std::to_string(cnt);
+      }
+      setted_names_.Set(valid_name, expr);
+      expr->span = SpanUtils::SetAttr(expr->span, "name", valid_name);
+    }
+  }
+
+  Map setted_names_;
+  String name_key_;
+  String seperator_;
+};  // class ExprNameBinder
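[Editor's note: BindName resolves collisions with a numeric-suffix scheme, and the `!= expr` guard in the loop lets a revisited expression keep the name it already claimed. A standalone sketch of the suffix counter (plain std containers, not the TVM API):]

#include <iostream>
#include <set>
#include <string>

// Return `name` if unused, otherwise the first free "name_<n>".
std::string UniqueName(std::string name, std::set<std::string>* used) {
  if (used->count(name)) {
    int cnt = 1;
    while (used->count(name + "_" + std::to_string(cnt))) cnt++;
    name = name + "_" + std::to_string(cnt);
  }
  used->insert(name);
  return name;
}

int main() {
  std::set<std::string> used;
  std::cout << UniqueName("conv", &used) << "\n";  // conv
  std::cout << UniqueName("conv", &used) << "\n";  // conv_1
  std::cout << UniqueName("conv", &used) << "\n";  // conv_2
}

+void BindRelayExprName(const Expr& e, const String&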
seperator) {
+  RelayExprNameBinder(name_key, seperator).VisitExpr(e);
+}
+
 namespace transform {

 Pass SetRelayExprName(const String& entry_name) {
@@ -342,6 +451,17 @@ Pass SetRelayExprName(const String& entry_name) {

 TVM_REGISTER_GLOBAL("relay._transform.SetRelayExprName").set_body_typed(SetRelayExprName);

+Pass BindRelayExprName(const String& name_key, const String& seperator, const String& entry_name) {
+  runtime::TypedPackedFunc pass_func = [=](IRModule m,
+                                           PassContext pc) {
+    relay::BindRelayExprName(m->Lookup(entry_name), name_key, seperator);
+    return m;
+  };
+  return CreateModulePass(pass_func, 0, "BindRelayExprName", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.BindRelayExprName").set_body_typed(BindRelayExprName);
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index 908c4d95077d..66f26bd1ebee 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -47,6 +47,49 @@ std::vector CommonUtils::GetIndices(const std::vector& indices, siz
   return v_indices;
 }

+bool StringUtils::Contains(const String& src_string, const String& sub_string) {
+  if (src_string.size() == 0) {
+    return false;
+  }
+  if (sub_string.size() == 0) {
+    return false;
+  }
+
+  const std::string& src_cstring = src_string;
+  const std::string& sub_cstring = sub_string;
+  size_t pos = src_cstring.find(sub_cstring);
+  return pos != std::string::npos;
+}
+
+bool StringUtils::StartsWith(const String& src_string, const String& sub_string) {
+  if (src_string.size() == 0) {
+    return false;
+  }
+  if (sub_string.size() == 0) {
+    return false;
+  }
+  const std::string& src_cstring = src_string;
+  const std::string& sub_cstring = sub_string;
+  size_t pos = src_cstring.find(sub_cstring);
+  return pos == 0;
+}
+
+bool StringUtils::EndsWith(const String& src_string, const String& sub_string) {
+  if (src_string.size() == 0) {
+    return false;
+  }
+  if (sub_string.size() == 0) {
+    return false;
+  }
+  const std::string& src_cstring = src_string;
+  const std::string& sub_cstring = sub_string;
+  size_t pos = src_cstring.rfind(sub_cstring);
+  if (pos == std::string::npos) {
+    return false;
+  }
+  return pos == src_cstring.size() - sub_cstring.size();
+}
+
 const Array StringUtils::Split(const String& src_string, const String& sep) {
   Array sub_strings;
   if (src_string.size() == 0) {
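[Editor's note: the three helpers above hinge on std::string's find/rfind conventions, where "not found" is std::string::npos rather than a negative value. A quick standalone check of those conventions against plain std::string (illustrative, not part of the patch):]

#include <cassert>
#include <string>

int main() {
  std::string s = "relax.op.add";
  // Contains: find() returns npos when the substring is absent,
  // so compare against npos rather than against a negative int.
  assert(s.find("op") != std::string::npos);
  assert(s.find("torch") == std::string::npos);
  // StartsWith: a prefix is a match at position 0.
  assert(s.find("relax.") == 0);
  // EndsWith: the last occurrence must sit exactly sub.size() from the end.
  size_t pos = s.rfind(".add");
  assert(pos != std::string::npos && pos == s.size() - 4);
  return 0;
}

@@ -192,7 +235,8 @@ const Span SpanUtils::SetAttr(const Span& span, const String& key, const String&
   const String& source_str = span->source_name->name;
   String left = std::get<0>(StringUtils::SplitOnce(source_str, tokens[0]));
   String right = std::get<1>(StringUtils::SplitOnce(source_str, tokens[1]));
-  if (left.size() > 0) {
+  if (StringUtils::Contains(source_str, tokens[0]) &&
+      StringUtils::Contains(source_str, tokens[1])) {
     new_source = left + tokens[0] + value + tokens[1] + right;
   } else {
     new_source = source_str + tokens[0] + value + tokens[1];
@@ -256,28 +300,18 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs
     input_types.push_back("gamma");
     input_types.push_back("beta");
   } else if (optype == "msc.linear") {
-    if (as_relax) {
-      input_types.push_back("weight");
-      input_types.push_back("input");
-    } else {
-      input_types.push_back("input");
-      input_types.push_back("weight");
-    }
+    input_types.push_back("input");
+    input_types.push_back("weight");
   } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") {
     input_types.push_back("input");
     input_types.push_back("weight");
     input_types.push_back("bias");
-    if (as_relax) {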
+    if (as_relax && inputs_num > 3) {
       input_types.push_back("expand_bias");
     }
   } else if (optype == "msc.linear_bias") {
-    if (as_relax) {
-      input_types.push_back("weight");
-      input_types.push_back("input");
-    } else {
-      input_types.push_back("input");
-      input_types.push_back("weight");
-    }
+    input_types.push_back("input");
+    input_types.push_back("weight");
     input_types.push_back("bias");
   } else if (optype == "msc.embedding" && inputs_num == 2) {
     input_types.push_back("input");
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 82a77d895b3e..de03d1764291 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -59,6 +59,24 @@ class CommonUtils {
  */
 class StringUtils {
  public:
+  /*!
+   * \brief Check if the String contains a substring.
+   * \return Whether the substring is contained.
+   */
+  TVM_DLL static bool Contains(const String& src_string, const String& sub_string);
+
+  /*!
+   * \brief Check if the String starts with a substring.
+   * \return Whether the string starts with the substring.
+   */
+  TVM_DLL static bool StartsWith(const String& src_string, const String& sub_string);
+
+  /*!
+   * \brief Check if the String ends with a substring.
+   * \return Whether the string ends with the substring.
+   */
+  TVM_DLL static bool EndsWith(const String& src_string, const String& sub_string);
+
   /*!
    * \brief Split the String into sub Strings.
    * \return The SubStrings.
@@ -159,7 +177,11 @@ class ArrayUtils {
   TVM_DLL static const Array Cast(const Array& src_array) {
     Array new_array;
     for (const auto& s : src_array) {
-      new_array.push_back(Downcast(s));
+      if (s->IsInstance()) {
+        new_array.push_back(T(-1));
+      } else {
+        new_array.push_back(Downcast(s));
+      }
     }
     return new_array;
   }
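[Editor's note: the Cast fallback above maps any dimension that is not a concrete integer to the sentinel -1 instead of failing the downcast. A standalone sketch of the same idea, with std::optional standing in for a possibly-unknown extent (illustrative only, not the TVM types); consumers that read -1 as "unknown", such as reshape-style shapes, can then use the result unchanged:]

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Dimensions without a concrete value become the sentinel -1.
std::vector<int64_t> CastShape(const std::vector<std::optional<int64_t>>& dims) {
  std::vector<int64_t> shape;
  for (const auto& d : dims) {
    shape.push_back(d.has_value() ? *d : -1);  // unknown extent -> -1
  }
  return shape;
}

int main() {
  // (batch, 3, H, 224) with a dynamic batch and height.
  for (int64_t d : CastShape({std::nullopt, 3, std::nullopt, 224})) {
    std::cout << d << " ";  // -1 3 -1 224
  }
  std::cout << "\n";
}

diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc
index c77e67816fb2..daf6f48633e8 100644
--- a/src/contrib/msc/framework/torch/codegen.cc
+++ b/src/contrib/msc/framework/torch/codegen.cc
@@ -42,12 +42,7 @@ void TorchCodeGen::CodeGenGraph() {
   stack_.func_def("__init__", "torch.nn.Module");
   stack_.func_arg("self", "torch.nn.Module");
   stack_.func_start();
-  stack_.call_start("super")
-      .call_arg(graph()->name)
-      .call_arg("self")
-      .call_end()
-      .inplace_start("__init__")
-      .inplace_end();
+  stack_.func_call("super").call_arg(graph()->name).call_arg("self").method_call("__init__");
   for (const auto& n : graph()->node_names) {
     const auto& node = graph()->FindNode(n);
     if (node->optype == "input") {
@@ -63,7 +58,7 @@
   stack_.func_arg("self", "torch.nn.Module");
   for (const auto& i : graph()->GetInputs()) {
     const auto& pair = graph()->FindProducerAndIdx(i);
-    stack_.func_arg(IdxOutput(pair.first, pair.second), "torch.Tensor");
+    stack_.func_arg(IdxOutputBase(pair.first, pair.second), "torch.Tensor");
   }
   stack_.func_start();
   for (const auto& n : graph()->node_names) {
@@ -76,12 +71,12 @@
   Array idx_outputs;
   for (const auto& o : graph()->GetOutputs()) {
     const auto& pair = graph()->FindProducerAndIdx(o);
-    idx_outputs.push_back(IdxOutput(pair.first, pair.second));
+    idx_outputs.push_back(IdxOutputBase(pair.first, pair.second));
   }
   if (idx_outputs.size() == 1) {
     stack_.assign("outputs", idx_outputs[0]);
   } else {
-    stack_.assign_list("outputs", idx_outputs);
+    stack_.assign("outputs", DocUtils::ToListDoc(idx_outputs));
   }
   stack_.func_end("outputs");
   stack_.class_end();
 }

 void TorchCodeGen::CodeGenInference() {
   stack_.comment("Build Model")
-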
.call_start(graph()->name) - .call_end("model") + .func_call(graph()->name, "model") .comment("Load weights") - .call_start("torch.load") - .call_str_arg(graph()->name + ".pth") - .call_end("weights") - .call_start("model.load_state_dict") - .call_arg("weights") - .call_end(); + .func_call("torch.load", "weights") + .call_arg(DocUtils::ToStrDoc(graph()->name + ".pth")) + .func_call("model.load_state_dict") + .call_arg("weights"); if (config()->test_device == "gpu") { - stack_.call_start("model.to").call_start("torch.device").call_arg("cuda").call_end().call_end(); + stack_.func_call("model.to").func_call("torch.device").call_arg("cuda").pop_nest(); } for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); - stack_.call_start("torch.from_numpy") - .call_arg("inputs[\"" + i->alias + "\"]") - .call_end(IdxNode(producer)); + stack_.func_call("torch.from_numpy", IdxNodeBase(producer)) + .call_arg("inputs[\"" + i->alias + "\"]"); } - stack_.call_start("model"); + stack_.func_call("model", "outputs"); for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); - stack_.call_arg(IdxNode(producer)); + stack_.call_arg(IdxNodeBase(producer)); if (config()->test_device == "gpu") { - stack_.inplace_start("to") - .call_start("torch.device") - .call_arg("cuda") - .call_end() - .inplace_end(); + stack_.method_call("to").func_call("torch.device").call_arg("cuda"); } } - stack_.call_end("outputs"); } const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index cbd1ada68a1f..ff6f18424875 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -42,15 +42,15 @@ const Array TorchOpCode::GetDocs() { void TorchOpCode::CodeGenInit() { if (module_name().size() > 0) { - stack_.op_start().op_end(); + stack_.op_call(); } else { stack_.comment("passby: implement by " + func_name()); } } -void TorchOpCode::CodeGenForward() { stack_.op_start().op_inputs_arg(false).op_end(); } +void TorchOpCode::CodeGenForward() { stack_.op_call().op_inputs_arg(false); } -const std::vector TorchOpCode::GetPadding(const String& key) { +const StrictListDoc TorchOpCode::GetPadding(const String& key) { std::vector padding, src_padding; ICHECK(node()->GetAttr(key, &src_padding)); if (node()->optype == "nn.conv1d" || node()->optype == "msc.conv1d_bias") { @@ -70,8 +70,10 @@ const std::vector TorchOpCode::GetPadding(const String& key) { } else { LOG_FATAL << "nn.conv2d/pool2d with unexpected padding " << node(); } + } else { + LOG_FATAL << "Unexpected padding node" << node(); } - return padding; + return DocUtils::ToListDoc(padding); } #define TORCH_OP_CODEGEN_METHODS(TypeName) \ @@ -83,7 +85,7 @@ class TorchAdaptivePoolCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchAdaptivePoolCodeGen); protected: - void CodeGenInit() final { stack_.op_start().op_list_arg("output_size").op_end(); } + void CodeGenInit() final { stack_.op_call().op_list_arg("output_size"); } }; class TorchAstypeCodeGen : public TorchOpCode { @@ -91,10 +93,7 @@ class TorchAstypeCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.assign(IdxNode(), IdxInput()) - .inplace_start("to") - .call_dtype_arg(node()->OutputAt(0)->dtype) - .inplace_end(); + stack_.assign(IdxNode(), IdxInput()).method_call("to").op_dtype_arg(node()->OutputAt(0)->dtype); } }; @@ -104,13 +103,12 @@ class 
TorchAttentionCodeGen : public TorchOpCode { protected: void CodeGenForward() final { std::string causal_mask; - stack_.op_start().op_inputs_arg(false); + stack_.op_call().op_inputs_arg(false); if (node()->GetAttr("causal_mask", &causal_mask)) { if (causal_mask.size() > 0) { stack_.call_arg(true, "is_causal"); } } - stack_.op_end(); } }; @@ -121,7 +119,7 @@ class TorchAxesCodeGen : public TorchOpCode { void CodeGenInit() final { if (module_name().size() > 0) { const String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_start().op_list_arg(key, "").op_end(); + stack_.op_call().op_list_arg(key, ""); } else { TorchOpCode::CodeGenInit(); } @@ -132,7 +130,7 @@ class TorchAxesCodeGen : public TorchOpCode { TorchOpCode::CodeGenForward(); } else { const String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_start().op_input_arg().op_list_arg(key, "").op_end(); + stack_.op_call().op_input_arg().op_list_arg(key, ""); } } }; @@ -143,7 +141,7 @@ class TorchAxisCodeGen : public TorchOpCode { protected: void CodeGenInit() final { if (module_name().size() > 0) { - stack_.op_start().op_arg("axis", "dim").op_end(); + stack_.op_call().op_arg("axis", "dim"); } else { TorchOpCode::CodeGenInit(); } @@ -153,7 +151,7 @@ class TorchAxisCodeGen : public TorchOpCode { if (module_name().size() > 0) { TorchOpCode::CodeGenForward(); } else { - stack_.op_start().op_input_arg().op_arg("axis", "dim").op_end(); + stack_.op_call().op_input_arg().op_arg("axis", "dim"); } } }; @@ -166,10 +164,7 @@ class TorchBatchNormCodeGen : public TorchOpCode { ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) << "Only support center and scale batchnorm, get " << node(); const auto& gamma = node()->WeightAt("gamma"); - stack_.op_start() - .call_arg(gamma->DimAt(0), "num_features") - .op_arg("epsilon", "eps") - .op_end(); + stack_.op_call().call_arg(gamma->DimAt(0), "num_features").op_arg("epsilon", "eps"); } }; @@ -178,10 +173,7 @@ class TorchBroadcastToCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.assign(IdxNode(), IdxInput()) - .inplace_start("expand") - .call_list_arg(node()->GetTypeArrayAttr("shape"), "", false, true) - .inplace_end(); + stack_.assign(IdxNode(), IdxInput()).method_call("expand").op_list_arg("shape", ""); } }; @@ -190,7 +182,7 @@ class TorchClipCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start().op_input_arg().op_arg("min").op_arg("max").op_end(); + stack_.op_call().op_input_arg().op_arg("min").op_arg("max"); } }; @@ -208,12 +200,10 @@ class TorchConstantCodeGen : public TorchOpCode { stack_.assign(module_ref(), node()->GetTypeAttr("scalar")); } } else { - stack_.call_start("torch.Tensor") - .call_list_arg(node()->OutputAt(0)->shape, "", false, false) - .call_end("data") - .op_start() - .call_arg("data") - .op_end(); + stack_.func_call("torch.Tensor", "data") + .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape)) + .op_call() + .call_arg("data"); } } @@ -235,16 +225,15 @@ class TorchConvCodeGen : public TorchOpCode { } kernel_size.push_back(weight->DimAt(i)->value); } - stack_.op_start() + stack_.op_call() .call_arg(weight->DimAt("I"), "in_channels") .call_arg(weight->DimAt("O"), "out_channels") - .call_list_arg(kernel_size, "kernel_size") + .call_arg(DocUtils::ToListDoc(kernel_size), "kernel_size") .op_list_arg("strides", "stride") - .call_list_arg(GetPadding(), "padding") + .call_arg(GetPadding(), "padding") .op_list_arg("dilation") .op_arg("groups") - .call_arg(use_bias_, "bias") - 
.op_end(); + .call_arg(use_bias_, "bias"); } private: @@ -256,11 +245,10 @@ class TorchCumsumCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start() + stack_.op_call() .op_input_arg() .op_arg("axis", "dim") - .call_dtype_arg(node()->OutputAt(0)->dtype, "dtype") - .op_end(); + .op_dtype_arg(node()->OutputAt(0)->dtype, "dtype"); } }; @@ -270,10 +258,9 @@ class TorchEmbeddingCodeGen : public TorchOpCode { protected: void CodeGenInit() final { const auto& weight = node()->WeightAt("weight"); - stack_.op_start() + stack_.op_call() .call_arg(weight->DimAt("W"), "num_embeddings") - .call_arg(weight->DimAt("E"), "embedding_dim") - .op_end(); + .call_arg(weight->DimAt("E"), "embedding_dim"); } }; @@ -289,7 +276,7 @@ class TorchExpandDimsCodeGen : public TorchOpCode { if (i < axes.size() - 1) { idx_out = idx_out + "_" + std::to_string(i); } - stack_.op_start().call_arg(idx_input).call_arg(axes[i], "dim").op_end(); + stack_.op_call().call_arg(idx_input).call_arg(axes[i], "dim"); idx_input = idx_out; } } @@ -300,11 +287,10 @@ class TorchFullCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start() + stack_.op_call() .op_list_arg("shape", "size") .op_input_arg(0, "fill_value") - .call_dtype_arg(node()->OutputAt(0)->dtype, "dtype") - .op_end(); + .op_dtype_arg(node()->OutputAt(0)->dtype, "dtype"); } }; @@ -325,11 +311,10 @@ class TorchGroupNormCodeGen : public TorchOpCode { ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) << "Only support center and scale batchnorm, get " << node(); int channel_axis = node()->GetTypeAttr("channel_axis"); - stack_.op_start() + stack_.op_call() .op_arg("num_groups") .call_arg(node()->InputAt(0)->DimAt(channel_axis), "num_channels") - .op_arg("epsilon", "eps") - .op_end(); + .op_arg("epsilon", "eps"); } }; @@ -346,10 +331,9 @@ class TorchLayerNormCodeGen : public TorchOpCode { for (const auto& a : axes) { normalized_shape.push_back(node()->InputAt(0)->DimAt(a)); } - stack_.op_start() - .call_list_arg(normalized_shape, "normalized_shape") - .op_arg("epsilon", "eps") - .op_end(); + stack_.op_call() + .call_arg(DocUtils::ToListDoc(normalized_shape), "normalized_shape") + .op_arg("epsilon", "eps"); } }; @@ -361,11 +345,10 @@ class TorchLinearCodeGen : public TorchOpCode { protected: void CodeGenInit() final { const auto& weight = node()->WeightAt("weight"); - stack_.op_start() + stack_.op_call() .call_arg(weight->DimAt("I"), "in_features") .call_arg(weight->DimAt("O"), "out_features") - .call_arg(use_bias_, "bias") - .op_end(); + .call_arg(use_bias_, "bias"); } private: @@ -377,11 +360,7 @@ class TorchNllLossCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start() - .op_inputs_arg(false) - .op_str_arg("reduction") - .op_arg("ignore_index") - .op_end(); + stack_.op_call().op_inputs_arg(false).op_str_arg("reduction").op_arg("ignore_index"); } }; @@ -390,15 +369,29 @@ class TorchPoolCodeGen : public TorchOpCode { protected: void CodeGenInit() final { - stack_.op_start() + stack_.op_call() .op_list_arg("pool_size", "kernel_size") .op_list_arg("strides", "stride") - .call_list_arg(GetPadding(), "padding") + .call_arg(GetPadding(), "padding") .op_arg("ceil_mode"); if (node()->optype == "nn.max_pool2d") { stack_.op_list_arg("dilation"); } - stack_.op_end(); + } +}; + +class TorchPermuteDimsCodeGen : public TorchOpCode { + TORCH_OP_CODEGEN_METHODS(TorchPermuteDimsCodeGen) + + protected: + void CodeGenForward() final { + std::vector axes; + if 
(!node()->GetAttr("axes", &axes)) { + for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { + axes.push_back(i - 1); + } + } + stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(axes)); } }; @@ -409,13 +402,13 @@ class TorchReduceAxisCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); if (as_list_) { stack_.op_list_arg("axis", "dim"); } else { stack_.op_arg("axis", "dim"); } - stack_.op_arg("keepdims", "keepdim").op_end(); + stack_.op_arg("keepdims", "keepdim"); } private: @@ -438,9 +431,8 @@ class TorchRepeatCodeGen : public TorchOpCode { } } stack_.assign(IdxNode(), IdxInput()) - .inplace_start("repeat") - .call_list_arg(repeats, "", false, true) - .inplace_end(); + .method_call("repeat") + .call_arg(DocUtils::ToListDoc(repeats), ""); } }; @@ -457,7 +449,7 @@ class TorchReshapeCodeGen : public TorchOpCode { shape[batch_dim] = -1; } } - stack_.op_start().op_input_arg().call_list_arg(shape).op_end(); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(shape)); } }; @@ -473,11 +465,8 @@ class TorchResize2dCodeGen : public TorchOpCode { } else { LOG(FATAL) << "Unexpected resize2d method " << method; } - stack_.op_start() - .op_input_arg() - .op_list_arg("size") - .call_str_arg(v_method, "mode") - .op_end(); + stack_.op_call().op_input_arg().op_list_arg("size").call_arg(DocUtils::ToStrDoc(v_method), + "mode"); } }; @@ -487,9 +476,9 @@ class TorchShapeCodeGen : public TorchOpCode { protected: void CodeGenForward() final { if (node()->inputs.size() == 0) { - stack_.op_start().op_list_arg("shape", "").op_end(); + stack_.op_call().op_list_arg("shape", ""); } else { - stack_.assign(IdxNode(), IdxInput()).inplace_start("size").inplace_end(); + stack_.assign(IdxNode(), IdxInput()).method_call("size"); } } }; @@ -503,13 +492,14 @@ class TorchSplitCodeGen : public TorchOpCode { protected: void CodeGenForward() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); std::vector indices; int axis = node()->GetTypeAttr("axis"); for (size_t i = 0; i < node()->outputs.size(); i++) { indices.push_back(node()->OutputAt(i)->DimAt(axis)->value); } - stack_.call_list_arg(indices, "split_size_or_sections").op_arg("axis", "dim").op_end(); + stack_.call_arg(DocUtils::ToListDoc(indices), "split_size_or_sections") + .op_arg("axis", "dim"); } }; @@ -536,7 +526,7 @@ class TorchStridedSliceCodeGen : public TorchOpCode { slice.push_back(":"); } } - stack_.assign_index(IdxNode(), IdxInput(), slice); + stack_.assign(IdxNode(), DocUtils::ToIndexDoc(IdxInput(), slice)); } }; @@ -544,16 +534,14 @@ class TorchTriCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen) protected: - void CodeGenForward() final { - stack_.op_start().op_input_arg().op_arg("k", "diagonal").op_end(); - } + void CodeGenForward() final { stack_.op_call().op_input_arg().op_arg("k", "diagonal"); } }; class TorchTupleCodeGen : public TorchOpCode { TORCH_OP_CODEGEN_METHODS(TorchTupleCodeGen) protected: - void CodeGenForward() final { stack_.op_start().op_inputs_arg().op_end(); } + void CodeGenForward() final { stack_.op_call().op_inputs_arg(); } }; const std::shared_ptr>> GetTorchOpCodes() { @@ -623,7 +611,6 @@ const std::shared_ptr>> std::make_shared("nn.LogSoftmax", "functional.log_softmax")); map->emplace("nn.softmax", std::make_shared("nn.Softmax", "functional.softmax")); - map->emplace("permute_dims", std::make_shared("", "torch.permute")); map->emplace("squeeze", std::make_shared("", 
"torch.squeeze")); // math ops @@ -632,6 +619,7 @@ const std::shared_ptr>> map->emplace("clip", std::make_shared("", "torch.clamp")); map->emplace("cumsum", std::make_shared("", "torch.cumsum")); map->emplace("expand_dims", std::make_shared("", "torch.unsqueeze")); + map->emplace("permute_dims", std::make_shared("", "torch.permute")); map->emplace("repeat", std::make_shared("", "repeat")); map->emplace("reshape", std::make_shared("", "torch.reshape")); map->emplace("split", std::make_shared("", "torch.split")); diff --git a/src/contrib/msc/framework/torch/torch_opcode.h b/src/contrib/msc/framework/torch/torch_opcode.h index e35ad8c5ca53..4f29a3c64c5b 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.h +++ b/src/contrib/msc/framework/torch/torch_opcode.h @@ -96,7 +96,7 @@ class TorchOpCode : public BaseOpCode { virtual void CodeGenForward(); /*! \brief Get the padding from op*/ - const std::vector GetPadding(const String& key = "padding"); + const StrictListDoc GetPadding(const String& key = "padding"); /*! \brief Get the is_init_ of codegen*/ bool is_init() { return is_init_; } diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index f68cd638d39a..7f0270d6b4ef 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -36,31 +36,29 @@ void RelaxCodeGen::CodeGenGraph() { Array idx_inputs; for (const auto& i : graph()->GetInputs()) { const auto& pair = graph()->FindProducerAndIdx(i); - const auto& idx_input = IdxOutput(pair.first, pair.second); + const auto& idx_input = IdxOutputBase(pair.first, pair.second); stack_.func_arg(idx_input, "relax.Var"); idx_inputs.push_back(idx_input); } - stack_.func_start().assign_list("inputs", idx_inputs); + stack_.func_start().assign("inputs", DocUtils::ToListDoc(idx_inputs, true)); // define weights stack_.comment("Define the weights and constant"); for (const auto& n : graph()->node_names) { const auto& node = graph()->FindNode(n); for (const auto& pair : node->weights) { - const auto& idx_weight = IdxWeight(node, pair.first); - stack_.call_start("relax.Var") - .call_str_arg(pair.second->name) - .call_inplace_start("relax.TensorStructInfo") - .call_list_arg(pair.second->shape, "", true) - .call_str_arg(pair.second->DTypeName()) - .call_inplace_end() - .call_end(idx_weight) - .call_start("inputs.append") - .call_arg(idx_weight) - .call_end(); + const auto& idx_weight = IdxWeightBase(node, pair.first); + stack_.func_call("relax.Var", idx_weight) + .call_arg(DocUtils::ToStrDoc(pair.second->name)) + .func_call("relax.TensorStructInfo") + .call_arg(DocUtils::ToListDoc(pair.second->shape, true), "") + .call_arg(DocUtils::ToStrDoc(pair.second->DTypeName())) + .pop_nest() + .func_call("inputs.append") + .call_arg(idx_weight); } if (node->optype == "constant") { CodeGenNode(node); - stack_.call_start("inputs.append").call_arg(IdxNode(node)).call_end(); + stack_.func_call("inputs.append").call_arg(IdxNodeBase(node)); } } stack_.comment("Define the module"); @@ -92,36 +90,34 @@ void RelaxCodeGen::CodeGenGraph() { stack_.comment("Emit the outputs"); Array idx_exits; for (const auto& e : graph()->GetExits()) { - const auto& idx_exit = IdxNode(e, false); - stack_.call_start("block_builder.emit_output").call_arg(idx_exit).call_end(idx_exit); + const auto& idx_exit = IdxNodeBase(e, false); + stack_.func_call("block_builder.emit_output", idx_exit).call_arg(idx_exit); idx_exits.push_back(idx_exit); } - stack_.scope_end().call_start("block_builder.emit_func_output"); + 
stack_.scope_end().func_call("block_builder.emit_func_output"); if (idx_exits.size() == 1) { stack_.call_arg(idx_exits[0]); } else { - stack_.call_list_arg(idx_exits); + stack_.call_arg(DocUtils::ToListDoc(idx_exits)); } - stack_.call_end().scope_end().assign("mod", "block_builder.get()").func_end("mod"); + stack_.scope_end().assign("mod", "block_builder.get()").func_end("mod"); } void RelaxCodeGen::CodeGenInference() { for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); - stack_.call_start("relax.Var") - .call_str_arg(i->alias) - .call_inplace_start("relax.TensorStructInfo") - .call_list_arg(i->shape) - .call_str_arg(i->DTypeName()) - .call_inplace_end() - .call_end(IdxNode(producer)); + stack_.func_call("relax.Var", IdxNodeBase(producer)) + .call_arg(DocUtils::ToStrDoc(i->alias)) + .func_call("relax.TensorStructInfo") + .call_arg(DocUtils::ToListDoc(i->shape)) + .call_arg(DocUtils::ToStrDoc(i->DTypeName())) + .pop_nest(); } - stack_.comment("Build Module").call_start(graph()->name); + stack_.comment("Build Module").func_call(graph()->name, "mod"); for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); - stack_.call_arg(IdxNode(producer)); + stack_.call_arg(IdxNodeBase(producer)); } - stack_.call_end("mod"); String target, device; if (config()->test_device == "cpu") { target = "llvm"; @@ -132,33 +128,30 @@ void RelaxCodeGen::CodeGenInference() { } stack_.comment("Load weights") .scope_start("open(\"" + graph()->name + "_params.bin\", \"rb\")", "f") - .call_start("tvm.runtime.load_param_dict") + .func_call("tvm.runtime.load_param_dict", "params") .call_arg("f.read()") - .call_end("params") .scope_end() - .call_start("tvm.relax.transform.BindParams") - .call_str_arg("main") + .func_call("tvm.relax.transform.BindParams", "bind_params") + .call_arg(DocUtils::ToStrDoc("main")) .call_arg("params") - .call_end("bind_params") - .call_start("bind_params") + .func_call("bind_params", "mod") .call_arg("mod") - .call_end("mod") - .call_start("tvm.target.Target") - .call_str_arg(target) - .call_end("target") - .call_start("relax.build") + .func_call("tvm.target.Target", "target") + .call_arg(DocUtils::ToStrDoc(target)) + .func_call("tvm.relax.transform.LegalizeOps()", "mod") + .call_arg("mod") + .scope_start("tvm.transform.PassContext(opt_level=3)") + .func_call("relax.build", "ex") .call_arg("mod") .call_arg("target") - .call_end("ex") - .call_start("relax.VirtualMachine") + .func_call("relax.VirtualMachine", "vm") .call_arg("ex") .call_arg(device) - .call_end("vm") - .call_start("vm[\"main\"]"); + .scope_end() + .func_call("vm[\"main\"]", "outputs"); for (const auto& i : graph()->GetInputs()) { stack_.call_arg("inputs[\"" + i->alias + "\"]"); } - stack_.call_end("outputs"); } const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index fe5c0c7591d7..53a18b8cb221 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -47,19 +47,18 @@ const Array RelaxOpCode::GetDocs() { } void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { - stack_.call_start("block_builder.emit").call_arg(ret); + stack_.func_call("block_builder.emit", ret).call_arg(ret); if (name.size() > 0) { - stack_.call_str_arg(name, "name_hint"); + stack_.call_arg(DocUtils::ToStrDoc(name), "name_hint"); } - stack_.call_end(ret); } -const std::string 
RelaxOpCode::GetOutDtype(const String& key) { +const ExprDoc RelaxOpCode::GetOutDtype(const String& key) { std::string out_dtype; if (!node()->GetAttr(key, &out_dtype) && config()->from_relay) { - return node()->OutputAt(0)->DTypeName(); + return DocUtils::ToStrDoc(node()->OutputAt(0)->DTypeName()); } - return out_dtype; + return DocUtils::ToStrDoc(out_dtype); } const std::vector RelaxOpCode::GetAxes(const String& key) { @@ -77,83 +76,119 @@ const std::vector RelaxOpCode::GetAxes(const String& key) { class RelaxAdaptivePool2dCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAdaptivePool2dCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() + stack_.op_call() .op_input_arg() .op_list_arg("output_size") .op_str_arg("layout") - .op_str_arg("out_layout") - .op_end(); + .op_str_arg("out_layout"); } }; class RelaxAstypeCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAstypeCodeGen) + protected: - void CodeGenBuild() final { stack_.op_start().op_input_arg().op_str_arg("dtype").op_end(); } + void CodeGenBuild() final { stack_.op_call().op_input_arg().op_str_arg("dtype"); } }; class RelaxAttentionCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAttentionCodeGen) + protected: void CodeGenBuild() final { for (size_t i = 0; i < 3; i++) { const String& axes_key = i == 0 ? "axes" : "axes_" + std::to_string(i); - stack_.op_start("relax.op.permute_dims") + stack_.op_call("relax.op.permute_dims", IdxInput(i)) .op_input_arg(i) - .op_list_arg(axes_key, "axes") - .op_end(IdxInput(i)); + .op_list_arg(axes_key, "axes"); } - stack_.op_start() - .op_inputs_arg(false) - .op_arg("scale") - .op_str_arg("causal_mask") - .op_end(); + stack_.op_call().op_inputs_arg(false).op_arg("scale").op_str_arg("causal_mask"); } }; class RelaxAxisCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAxisCodeGen) + protected: void CodeGenBuild() final { std::vector axes = GetAxes("axis"); - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); if (axes.size() > 0) { stack_.call_arg(axes[0], "axis"); } - stack_.op_end(); } }; class RelaxAxesCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxAxesCodeGen) + protected: void CodeGenBuild() final { const String& key = node()->HasAttr("axes") ? 
"axes" : "axis"; - stack_.op_start().op_input_arg().call_list_arg(GetAxes(key), key).op_end(); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(GetAxes(key)), key); } }; class RelaxBatchMatmulCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxBatchMatmulCodeGen) + protected: void CodeGenBuild() final { bool transpose_a = node()->GetTypeAttr("transpose_a"); bool transpose_b = node()->GetTypeAttr("transpose_b"); if (!transpose_a && !transpose_b) { - stack_.op_start().op_inputs_arg(false).op_str_arg("out_dtype").op_end(); + stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); + } else if (transpose_a && !transpose_b) { + std::vector axes; + for (size_t i = 0; i < node()->InputAt(0)->Ndim() - 2; i++) { + axes.push_back(i); + } + axes.push_back(node()->InputAt(0)->Ndim() - 1); + axes.push_back(node()->InputAt(0)->Ndim() - 2); + stack_.op_call("relax.op.permute_dims", IdxInput(0)) + .op_input_arg() + .call_arg(DocUtils::ToListDoc(axes)); + BuilderEmit(IdxInput(0)); + stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); + } else if (!transpose_a && transpose_b) { + std::vector axes; + for (size_t i = 0; i < node()->InputAt(1)->Ndim() - 2; i++) { + axes.push_back(i); + } + axes.push_back(node()->InputAt(1)->Ndim() - 1); + axes.push_back(node()->InputAt(1)->Ndim() - 2); + stack_.op_call("relax.op.permute_dims", IdxInput(1)) + .op_input_arg(1) + .call_arg(DocUtils::ToListDoc(axes)); + BuilderEmit(IdxInput(1)); + stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); } else { - LOG(FATAL) << "Unexpected nn.batch_matmul " << node(); + for (size_t idx = 0; idx < 2; idx++) { + std::vector axes; + for (size_t i = 0; i < node()->InputAt(idx)->Ndim() - 2; i++) { + axes.push_back(i); + } + axes.push_back(node()->InputAt(idx)->Ndim() - 1); + axes.push_back(node()->InputAt(idx)->Ndim() - 2); + stack_.op_call("relax.op.permute_dims", IdxInput(idx)) + .op_input_arg(idx) + .call_arg(DocUtils::ToListDoc(axes)); + BuilderEmit(IdxInput(idx)); + } + stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); } } }; class RelaxBatchNormCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxBatchNormCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() + stack_.op_call() .op_input_arg() .op_weight_arg("gamma") .op_weight_arg("beta") @@ -163,16 +198,16 @@ class RelaxBatchNormCodeGen : public RelaxOpCode { .op_arg("epsilon") .op_arg("center") .op_arg("scale") - .op_arg("momentum") - .op_end(); + .op_arg("momentum"); } }; class RelaxBiasAddCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxBiasAddCodeGen) + protected: void CodeGenBuild() final { - int axis = node()->GetTypeAttr("axis"); + int axis = CommonUtils::GetIndex(node()->GetTypeAttr("axis"), node()->OutputAt(0)->Ndim()); Array expand_shape; for (size_t i = 0; i < node()->InputAt(0)->Ndim(); i++) { if (i == static_cast(axis)) { @@ -181,47 +216,53 @@ class RelaxBiasAddCodeGen : public RelaxOpCode { expand_shape.push_back(Integer(1)); } } - stack_.op_start("relax.op.reshape") + stack_.op_call("relax.op.reshape", IdxInput(1)) .op_input_arg(1) - .call_list_arg(expand_shape, "shape") - .call_end(IdxInput(1)); + .call_arg(DocUtils::ToListDoc(expand_shape), "shape"); BuilderEmit(IdxInput(1)); - stack_.op_start().op_inputs_arg(false).op_end(); + stack_.op_call().op_inputs_arg(false); } }; class RelaxBroadcastToCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxBroadcastToCodeGen) + protected: - void CodeGenBuild() final { 
stack_.op_start().op_input_arg().op_list_arg("shape").op_end(); } + void CodeGenBuild() final { stack_.op_call().op_input_arg().op_list_arg("shape"); } }; class RelaxClipCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxClipCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); if (config()->from_relay) { stack_.op_arg("a_min", "min").op_arg("a_max", "max"); } else { stack_.op_arg("min").op_arg("max"); } - stack_.op_end(); } }; +class RelaxConcatCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxConcatCodeGen) + + protected: + void CodeGenBuild() final { stack_.op_call().op_inputs_arg().op_arg("axis"); } +}; + class RelaxConstantCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxConstantCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() - .call_str_arg(node()->name) - .call_inplace_start("relax.TensorStructInfo") - .call_list_arg(node()->OutputAt(0)->shape, "", true) - .call_str_arg(node()->OutputAt(0)->DTypeName()) - .call_inplace_end() - .call_end() - .op_end(); + stack_.op_call() + .op_name_arg("") + .func_call("relax.TensorStructInfo") + .call_arg(DocUtils::ToListDoc(node()->OutputAt(0)->shape, true), "") + .call_arg(DocUtils::ToStrDoc(node()->OutputAt(0)->DTypeName())) + .pop_nest(); } }; @@ -232,7 +273,7 @@ class RelaxConvCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - stack_.op_start() + stack_.op_call() .op_input_arg() .op_weight_arg("weight") .op_list_arg("strides") @@ -242,8 +283,7 @@ class RelaxConvCodeGen : public RelaxOpCode { .op_str_arg("data_layout") .op_str_arg("kernel_layout") .op_str_arg("out_layout") - .call_str_arg(GetOutDtype(), "out_dtype") - .op_end(); + .call_arg(GetOutDtype(), "out_dtype"); if (use_bias_) { std::string out_layout_str; if (!node()->GetAttr("out_layout", &out_layout_str)) { @@ -260,15 +300,11 @@ class RelaxConvCodeGen : public RelaxOpCode { } } BuilderEmit(IdxNode()); - stack_.call_start("relax.op.reshape") - .call_arg(IdxWeight("bias", true)) - .call_list_arg(expand_shape, "shape") - .call_end("expand_bias"); + stack_.func_call("relax.op.reshape", "expand_bias") + .op_weight_arg("bias") + .call_arg(DocUtils::ToListDoc(expand_shape), "shape"); BuilderEmit("expand_bias"); - stack_.call_start("relax.op.add") - .call_arg(IdxNode()) - .call_arg("expand_bias") - .call_end(IdxNode()); + stack_.func_call("relax.op.add", IdxNode()).call_arg(IdxNode()).call_arg("expand_bias"); } } @@ -278,22 +314,47 @@ class RelaxConvCodeGen : public RelaxOpCode { class RelaxCreateCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxCreateCodeGen) + protected: - void CodeGenBuild() final { - stack_.op_start().op_list_arg("shape").op_str_arg("dtype").op_end(); - } + void CodeGenBuild() final { stack_.op_call().op_list_arg("shape").op_str_arg("dtype"); } +}; + +class RelaxCreateLikeCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxCreateLikeCodeGen) + + protected: + void CodeGenBuild() final { stack_.op_call().op_input_arg().op_str_arg("dtype"); } }; class RelaxCumsumCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxCumsumCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg().op_arg("axis").op_str_arg("dtype").op_end(); + stack_.op_call().op_input_arg().op_arg("axis").op_str_arg("dtype"); + } +}; + +class RelaxEinsumCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxEinsumCodeGen) + + protected: + void CodeGenBuild() final { + const String& key = 
config()->from_relay ? "equation" : "subscripts"; + const auto& producer = node()->ProducerOf(0); + stack_.op_call(); + if (node()->inputs.size() == 1 && producer->optype == "tuple") { + stack_.op_input_arg(); + } else { + stack_.op_inputs_arg(); + } + stack_.op_str_arg(key, "subscripts"); } }; class RelaxStridedSliceCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxStridedSliceCodeGen) + protected: void CodeGenBuild() final { std::vector axes; @@ -302,13 +363,12 @@ class RelaxStridedSliceCodeGen : public RelaxOpCode { axes.push_back(i); } } - stack_.op_start() + stack_.op_call() .op_input_arg() - .call_list_arg(axes, "axes") + .call_arg(DocUtils::ToListDoc(axes), "axes") .op_list_arg("begin") .op_list_arg("end") - .op_list_arg("strides") - .op_end(); + .op_list_arg("strides"); } }; @@ -319,119 +379,141 @@ class RelaxEmbeddingCodeGen : public RelaxOpCode { void CodeGenBuild() final { const auto& input = node()->InputAt(0); if (input->DTypeName() != "int32") { - stack_.op_start("relax.op.astype").op_input_arg().call_str_arg("int32").op_end(IdxInput()); + stack_.op_call("relax.op.astype", IdxInput()) + .op_input_arg() + .call_arg(DocUtils::ToStrDoc("int32")); BuilderEmit(IdxInput()); } if (input->Ndim() > 1) { - stack_.op_start("relax.op.reshape") + stack_.op_call("relax.op.reshape", IdxInput()) .op_input_arg() - .call_list_arg(std::vector{-1}, "shape") - .op_end(IdxInput()); + .call_arg(DocUtils::ToListDoc(std::vector{-1}), "shape"); BuilderEmit(IdxInput()); } - stack_.op_start().op_weight_arg("weight").op_input_arg().op_arg("axis").op_end(); + stack_.op_call().op_weight_arg("weight").op_input_arg().op_arg("axis"); if (input->Ndim() > 1) { BuilderEmit(IdxNode()); - stack_.op_start("relax.op.reshape") + stack_.op_call("relax.op.reshape") .op_output_arg() - .call_list_arg(node()->OutputAt(0)->shape) - .op_end(); + .call_arg(DocUtils::ToListDoc(node()->OutputAt(0)->shape)); } } }; class RelaxFullCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxFullCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() - .op_list_arg("shape") - .op_input_arg(0, "fill_value") - .op_str_arg("dtype") - .op_end(); + stack_.op_call().op_list_arg("shape").op_input_arg(0, "fill_value").op_str_arg("dtype"); } }; class RelaxGetItemCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxGetItemCodeGen) + protected: void CodeGenBuild() final { const auto& producer = node()->ProducerOf(0); - stack_.op_start().call_arg(IdxNode(producer)).op_arg("index").call_end(IdxNode()); + stack_.op_call("msc::auto", IdxNode()).call_arg(IdxNodeBase(producer)).op_arg("index"); } }; class RelaxGroupNormCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxGroupNormCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta").op_arg( + stack_.op_call().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta").op_arg( "num_groups"); if (config()->from_relay) { std::vector axes; for (size_t i = 2; i < node()->InputAt(0)->Ndim(); i++) { axes.push_back(i); } - stack_.op_arg("axis", "channel_axis").call_list_arg(axes, "axes"); + stack_.op_arg("axis", "channel_axis").call_arg(DocUtils::ToListDoc(axes), "axes"); } else { stack_.op_arg("channel_axis").op_list_arg("axes"); } - stack_.op_arg("epsilon").op_arg("center").op_arg("scale").op_end(); + stack_.op_arg("epsilon").op_arg("center").op_arg("scale"); } }; class RelaxLayerNormCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxLayerNormCodeGen) + protected: void 
CodeGenBuild() final { - stack_.op_start().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta"); + stack_.op_call().op_input_arg().op_weight_arg("gamma").op_weight_arg("beta"); if (config()->from_relay) { stack_.op_arg("axis", "axes"); } else { stack_.op_list_arg("axes"); } - stack_.op_arg("epsilon").op_arg("center").op_arg("scale").op_end(); + stack_.op_arg("epsilon").op_arg("center").op_arg("scale"); } }; class RelaxLinearCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxLinearCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() - .op_input_arg() - .op_weight_arg("weight") - .op_weight_arg("bias") - .call_str_arg(GetOutDtype(), "out_dtype") - .op_end(); + stack_.op_call(); + if (node()->inputs.size() == 1) { + stack_.op_input_arg().op_weight_arg("weight").op_weight_arg("bias"); + } else { + stack_.op_inputs_arg(false); + } + stack_.call_arg(GetOutDtype(), "out_dtype"); } }; class RelaxMatmulCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxMatmulCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_inputs_arg(false).call_str_arg(GetOutDtype(), "out_dtype").op_end(); + stack_.op_call().op_inputs_arg(false).call_arg(GetOutDtype(), "out_dtype"); } }; class RelaxNllLossCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxNllLossCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() - .op_inputs_arg(false) - .op_str_arg("reduction") - .op_arg("ignore_index") - .op_end(); + stack_.op_call().op_inputs_arg(false).op_str_arg("reduction").op_arg("ignore_index"); + } +}; + +class RelaxPadCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxPadCodeGen) + + protected: + void CodeGenBuild() final { + Array pad_width; + const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); + ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); + for (size_t i = 0; i < attr_pad_width.size(); i += 2) { + const String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + + std::to_string(attr_pad_width[i + 1]) + "]"; + pad_width.push_back(cur_pad); + } + stack_.op_call() + .op_input_arg() + .op_list_arg("pad_width") + .op_input_arg(1, "pad_value") + .op_str_arg("pad_mode"); } }; class RelaxPool2dCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxPool2dCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start() + stack_.op_call() .op_input_arg() .op_list_arg("pool_size") .op_list_arg("strides") @@ -439,8 +521,22 @@ class RelaxPool2dCodeGen : public RelaxOpCode { .op_list_arg("dilation") .op_arg("ceil_mode") .op_str_arg("layout") - .op_str_arg("out_layout") - .op_end(); + .op_str_arg("out_layout"); + } +}; + +class RelaxPermuteDimsCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxPermuteDimsCodeGen) + + protected: + void CodeGenBuild() final { + std::vector axes; + if (!node()->GetAttr("axes", &axes)) { + for (size_t i = node()->InputAt(0)->Ndim(); i > 0; i--) { + axes.push_back(i - 1); + } + } + stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(axes), "axes"); } }; @@ -451,14 +547,14 @@ class RelaxReduceAxisCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); std::vector axes = GetAxes("axis"); if (as_list_) { - stack_.call_list_arg(axes, "axis"); + stack_.call_arg(DocUtils::ToListDoc(axes), "axis"); } else if (axes.size() > 0) { stack_.call_arg(axes[0], "axis"); } - stack_.op_arg("keepdims").op_end(); + stack_.op_arg("keepdims"); } 
private: @@ -467,23 +563,24 @@ class RelaxReduceAxisCodeGen : public RelaxOpCode { class RelaxRepeatCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxRepeatCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg().op_arg("repeats").op_arg("axis").op_end(); + stack_.op_call().op_input_arg().op_arg("repeats").op_arg("axis"); } }; class RelaxReshapeCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxReshapeCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); if (config()->from_relay) { stack_.op_list_arg("newshape", "shape"); } else { stack_.op_list_arg("shape"); } - stack_.op_end(); } }; @@ -498,12 +595,12 @@ class RelaxResize2dCodeGen : public RelaxOpCode { for (const auto& r : roi) { roi_list.push_back("float(" + std::to_string(r) + ")"); } - stack_.op_start() + stack_.op_call() .op_input_arg() - .call_inplace_start("relax.ShapeExpr") + .func_call("relax.ShapeExpr") .op_list_arg("size", "values") - .call_inplace_end() - .call_list_arg(roi_list) + .pop_nest() + .call_arg(DocUtils::ToListDoc(roi_list)) .op_str_arg("layout") .op_str_arg("method") .op_str_arg("coordinate_transformation_mode") @@ -511,60 +608,75 @@ class RelaxResize2dCodeGen : public RelaxOpCode { .op_arg("cubic_alpha") .op_arg("cubic_exclude") .op_arg("extrapolation_value") - .op_str_arg("out_dtype") - .op_end(); + .op_str_arg("out_dtype"); } }; class RelaxShapeCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxShapeCodeGen) + protected: - void CodeGenBuild() final { stack_.op_start().op_list_arg("shape", "values").op_end(); } + void CodeGenBuild() final { stack_.op_call().op_list_arg("shape", "values"); } }; class RelaxSimpleCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxSimpleCodeGen) + protected: - void CodeGenBuild() final { stack_.op_start().op_inputs_arg(false).op_end(); } + void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false); } }; class RelaxSplitCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxSplitCodeGen) + protected: void CodeGenBuild() final { - stack_.op_start().op_input_arg(); + stack_.op_call().op_input_arg(); int sections; if (node()->GetAttr("indices_or_sections", §ions)) { stack_.op_arg("indices_or_sections"); } else { stack_.op_list_arg("indices_or_sections"); } - stack_.op_arg("axis").op_end(); + stack_.op_arg("axis"); } }; class RelaxTakeCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxTakeCodeGen) + protected: - void CodeGenBuild() final { stack_.op_start().op_inputs_arg(false).op_arg("axis").op_end(); } + void CodeGenBuild() final { stack_.op_call().op_inputs_arg(false).op_arg("axis"); } +}; + +class RelaxTileCodeGen : public RelaxOpCode { + RELAX_OP_CODEGEN_METHODS(RelaxTileCodeGen) + + protected: + void CodeGenBuild() final { + const String& key = config()->from_relay ? "reps" : "repeats"; + stack_.op_call().op_input_arg().op_list_arg(key, "repeats"); + } }; class RelaxTupleCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxTupleCodeGen) + protected: - void CodeGenBuild() final { stack_.op_start().op_inputs_arg().op_end(); } + void CodeGenBuild() final { stack_.op_call().op_inputs_arg(); } }; class RelaxTriCodeGen : public RelaxOpCode { RELAX_OP_CODEGEN_METHODS(RelaxTriCodeGen) + protected: void CodeGenBuild() final { if (node()->optype == "trilu") { const String& func_name = node()->GetTypeAttr("upper") ? 
"relax.op.triu" : "relax.op.tril"; - stack_.op_start(func_name).op_input_arg().op_arg("k").op_end(); + stack_.op_call(func_name).op_input_arg().op_arg("k"); } else { - stack_.op_start().op_input_arg().op_arg("k").op_end(); + stack_.op_call().op_input_arg().op_arg("k"); } } }; @@ -589,18 +701,23 @@ const std::shared_ptr>> map->emplace("cos", std::make_shared("relax.op.cos")); map->emplace("cosh", std::make_shared("relax.op.cosh")); map->emplace("divide", std::make_shared("relax.op.divide")); - map->emplace("exp", std::make_shared("relax.op.exp")); map->emplace("equal", std::make_shared("relax.op.equal")); + map->emplace("erf", std::make_shared("relax.op.erf")); + map->emplace("exp", std::make_shared("relax.op.exp")); map->emplace("floor", std::make_shared("relax.op.floor")); map->emplace("floor_divide", std::make_shared("relax.op.floor_divide")); map->emplace("greater", std::make_shared("relax.op.greater")); map->emplace("greater_equal", std::make_shared("relax.op.greater_equal")); + map->emplace("isfinite", std::make_shared("relax.op.isfinite")); + map->emplace("isinf", std::make_shared("relax.op.isinf")); + map->emplace("isnan", std::make_shared("relax.op.isnan")); map->emplace("less", std::make_shared("relax.op.less")); map->emplace("less_equal", std::make_shared("relax.op.less_equal")); map->emplace("log", std::make_shared("relax.op.log")); map->emplace("logical_and", std::make_shared("relax.op.logical_and")); map->emplace("logical_or", std::make_shared("relax.op.logical_or")); map->emplace("logical_xor", std::make_shared("relax.op.logical_xor")); + map->emplace("logical_not", std::make_shared("relax.op.logical_not")); map->emplace("maximum", std::make_shared("relax.op.maximum")); map->emplace("minimum", std::make_shared("relax.op.minimum")); map->emplace("multiply", std::make_shared("relax.op.multiply")); @@ -618,6 +735,7 @@ const std::shared_ptr>> map->emplace("subtract", std::make_shared("relax.op.subtract")); map->emplace("tan", std::make_shared("relax.op.tan")); map->emplace("tanh", std::make_shared("relax.op.tanh")); + map->emplace("where", std::make_shared("relax.op.where")); // reduce axis ops map->emplace("argmax", std::make_shared("relax.op.argmax", false)); @@ -633,32 +751,38 @@ const std::shared_ptr>> map->emplace("nn.log_softmax", std::make_shared("relax.op.nn.log_softmax")); map->emplace("nn.softmax", std::make_shared("relax.op.nn.softmax")); map->emplace("expand_dims", std::make_shared("relax.op.expand_dims")); - map->emplace("permute_dims", std::make_shared("relax.op.permute_dims")); map->emplace("squeeze", std::make_shared("relax.op.squeeze")); - map->emplace("transpose", std::make_shared("relax.op.permute_dims")); // math ops map->emplace("astype", std::make_shared("relax.op.astype")); map->emplace("broadcast_to", std::make_shared("relax.op.broadcast_to")); map->emplace("cast", std::make_shared("relax.op.astype")); map->emplace("clip", std::make_shared("relax.op.clip")); + map->emplace("concat", std::make_shared("relax.op.concat")); + map->emplace("concatenate", std::make_shared("relax.op.concat")); map->emplace("cumsum", std::make_shared("relax.op.cumsum")); + map->emplace("einsum", std::make_shared("relax.op.einsum")); map->emplace("matmul", std::make_shared("relax.op.linear_algebra.matmul")); + map->emplace("permute_dims", std::make_shared("relax.op.permute_dims")); map->emplace("repeat", std::make_shared("relax.op.repeat")); map->emplace("reshape", std::make_shared("relax.op.reshape")); map->emplace("split", std::make_shared("relax.op.split")); 
map->emplace("strided_slice", std::make_shared("relax.op.strided_slice")); map->emplace("take", std::make_shared("relax.op.take")); + map->emplace("tile", std::make_shared("relax.op.tile")); + map->emplace("transpose", std::make_shared("relax.op.permute_dims")); // create ops map->emplace("constant", std::make_shared("relax.Var")); map->emplace("full", std::make_shared("relax.op.full")); map->emplace("ones", std::make_shared("relax.op.ones")); + map->emplace("ones_like", std::make_shared("relax.op.ones_like")); map->emplace("tril", std::make_shared("relax.op.tril")); map->emplace("triu", std::make_shared("relax.op.triu")); map->emplace("trilu", std::make_shared("")); map->emplace("zeros", std::make_shared("relax.op.zeros")); + map->emplace("zeros_like", std::make_shared("relax.op.zeros_like")); // nn ops map->emplace("nn.adaptive_avg_pool2d", @@ -670,11 +794,13 @@ const std::shared_ptr>> map->emplace("nn.bias_add", std::make_shared("relax.op.add")); map->emplace("nn.conv1d", std::make_shared("relax.op.nn.conv1d", false)); map->emplace("nn.conv2d", std::make_shared("relax.op.nn.conv2d", false)); + map->emplace("nn.dense", std::make_shared("relax.op.linear_algebra.linear")); map->emplace("nn.gelu", std::make_shared("relax.op.nn.gelu")); map->emplace("nn.group_norm", std::make_shared("relax.op.nn.group_norm")); map->emplace("nn.layer_norm", std::make_shared("relax.op.nn.layer_norm")); map->emplace("nn.max_pool2d", std::make_shared("relax.op.nn.max_pool2d")); map->emplace("nn.nll_loss", std::make_shared("relax.op.nn.nll_loss")); + map->emplace("nn.pad", std::make_shared("relax.op.nn.pad")); map->emplace("nn.relu", std::make_shared("relax.op.nn.relu")); map->emplace("nn.silu", std::make_shared("relax.op.nn.silu")); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h index 9efc61626e78..b7471072f587 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.h +++ b/src/contrib/msc/framework/tvm/relax_opcode.h @@ -64,7 +64,7 @@ class RelaxOpCode : public BaseOpCode { void BuilderEmit(const String& ret, const String& name = ""); /*! \brief Get the out_dtype attribute*/ - const std::string GetOutDtype(const String& key = "out_dtype"); + const ExprDoc GetOutDtype(const String& key = "out_dtype"); /*! 
diff --git a/src/contrib/msc/framework/tvm/relax_opcode.h b/src/contrib/msc/framework/tvm/relax_opcode.h
index 9efc61626e78..b7471072f587 100644
--- a/src/contrib/msc/framework/tvm/relax_opcode.h
+++ b/src/contrib/msc/framework/tvm/relax_opcode.h
@@ -64,7 +64,7 @@ class RelaxOpCode : public BaseOpCode {
   void BuilderEmit(const String& ret, const String& name = "");

   /*! \brief Get the out_dtype attribute*/
-  const std::string GetOutDtype(const String& key = "out_dtype");
+  const ExprDoc GetOutDtype(const String& key = "out_dtype");

   /*! \brief Get the axes attribute*/
   const std::vector<int> GetAxes(const String& key = "axes");
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 6e410e584c5b..aea8e97ddcb3 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -1416,8 +1416,8 @@ def forward(self, data):

     expected = {
         "inputs": [{"name": "inp_0", "shape": [10, 10], "dtype": "float32", "layout": ""}],
-        "outputs": [{"name": "full", "shape": [10, 10], "dtype": "float32", "layout": ""}],
-        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
+        "outputs": [{"name": "const", "shape": [10, 10], "dtype": "float32", "layout": ""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
     }

     verify_model(InplaceFill(), [([10, 10], "float32")], expected)
@@ -1537,8 +1537,8 @@ def forward(self, x):

     expected = {
         "inputs": [{"name": "inp_0", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
-        "outputs": [{"name": "full", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
-        "nodes": {"total": 3, "input": 1, "constant": 1, "full": 1},
+        "outputs": [{"name": "const", "shape": [1, 2, 3], "dtype": "float32", "layout": ""}],
+        "nodes": {"total": 2, "input": 1, "constant": 1},
     }

     input_info = [([1, 2, 3], "float32")]
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index 72f58c40f755..fcf12947f01f 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -27,11 +27,11 @@
 from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen


-def verify_model(torch_model, input_info):
+def verify_model(torch_model, input_info, opt_config=None):
     graph_model = fx.symbolic_trace(torch_model)
     with torch.no_grad():
         expected = from_fx(graph_model, input_info)
-    graph, weights = translate.from_relax(expected)
+    graph, weights = translate.from_relax(expected, opt_config=opt_config)
     mod = tvm_codegen.to_relax(graph, weights, codegen_config={"explicit_name": False})
     tvm.ir.assert_structural_equal(mod, expected)

@@ -753,7 +753,7 @@ def forward(self, data):
         data.fill_(1.5)
         return data

-    verify_model(InplaceFill(), [([10, 10], "float32")])
+    verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0})


 def test_arange():
@@ -833,7 +833,7 @@ def forward(self, x):
         return x.new_ones(1, 2, 3)

     input_info = [([1, 2, 3], "float32")]
-    verify_model(NewOnes(), input_info)
+    verify_model(NewOnes(), input_info, opt_config={"opt_level": 0})


 def test_expand():
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py
index 70f5e8cda548..45441b69d72d 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -158,9 +158,9 @@ def forward(self, x, y):
             return torch.matmul(x, y)

     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Dense1(), input_info)
-    verify_model(Dense2(), input_info)
-    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
+    verify_model(Dense1(), input_info, build_target="llvm")
+    verify_model(Dense2(), input_info, build_target="llvm")
+    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")], build_target="llvm")


 def test_bmm():
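The translate tests pin `opt_level` to 0 for the inplace-fill and `new_ones` cases so the constant-producing subgraph is not folded away during translation, keeping the regenerated module structurally equal to the traced one; the graph-build expectations above, by contrast, now reflect the folded `const` output. A hedged sketch of how the updated helper path is exercised (the `translate` import path is assumed from the test file's imports; the `InplaceFill` module is the one from the tests):

```python
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.frontend import translate

class InplaceFill(torch.nn.Module):
    def forward(self, data):
        data.fill_(1.5)
        return data

input_info = [([10, 10], "float32")]
with torch.no_grad():
    mod = from_fx(fx.symbolic_trace(InplaceFill()), input_info)
# opt_level 0 keeps the filled constant from being folded before translation.
graph, weights = translate.from_relax(mod, opt_config={"opt_level": 0})
```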