diff --git a/CMakeLists.txt b/CMakeLists.txt index 29d2ace5e54f..6b0160c44892 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,8 @@ tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF) tvm_option(USE_VITIS_AI "Build with VITIS-AI Codegen support" OFF) tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) +tvm_option(USE_MRVL "Build with MRVL TVM support" OFF) +tvm_option(USE_MRVL_RUNTIME "Build with MRVL runtime support" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -455,6 +457,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) +include(cmake/modules/contrib/Mrvl.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/RustExt.cmake) diff --git a/Jenkinsfile b/Jenkinsfile index a782204d6307..4c0fb8940098 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -275,7 +275,7 @@ stage('Build') { ) // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: "Rust build and test") + //sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: "Rust build and test") junit "build/pytest-results/*.xml" } } diff --git a/cmake/config.cmake b/cmake/config.cmake index 62eeb34fead7..06b58eaf9fdc 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -248,6 +248,12 @@ set(USE_TENSORRT_RUNTIME OFF) # Whether use VITIS-AI codegen set(USE_VITIS_AI OFF) +# Whether use MRVL codegen +set(USE_MRVL OFF) + +# Whether use MRVL runtime +set(USE_MRVL_RUNTIME OFF) + # Build Verilator codegen and runtime set(USE_VERILATOR OFF) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index bf548b232512..ab634a502cc4 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -83,6 +83,8 @@ function(add_lib_info src_file) TVM_INFO_USE_TARGET_ONNX="${USE_TARGET_ONNX}" TVM_INFO_USE_ARM_COMPUTE_LIB="${USE_ARM_COMPUTE_LIB}" TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR="${USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR}" + TVM_INFO_USE_MRVL="${USE_MRVL}" + TVM_INFO_USE_MRVL_RUNTIME="${USE_MRVL_RUNTIME}" TVM_INFO_INDEX_DEFAULT_I64="${INDEX_DEFAULT_I64}" TVM_CXX_COMPILER_PATH="${CMAKE_CXX_COMPILER}" ) diff --git a/cmake/modules/contrib/Mrvl.cmake b/cmake/modules/contrib/Mrvl.cmake new file mode 100644 index 000000000000..200566714e72 --- /dev/null +++ b/cmake/modules/contrib/Mrvl.cmake @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +include(ExternalProject) + +if(USE_MRVL) + if(MRVL_COMPILER_LIB_PATH) + add_definitions(-DTVM_USE_MRVL_COMPILER_LIB=1) + # copy 3 pre-built static lib files of Marvell compiler-backend + # under the MRVL_COMPILER_LIB_PATH directory + file(COPY ${MRVL_COMPILER_LIB_PATH}/libmrvlcompiler.a + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + file(COPY ${MRVL_COMPILER_LIB_PATH}/libml.a + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + file(COPY ${MRVL_COMPILER_LIB_PATH}/libnum.a + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + file(COPY ${MRVL_COMPILER_LIB_PATH}/libisa.a + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + file(GLOB MRVL_RUNTIME_LIB + ${CMAKE_CURRENT_BINARY_DIR}/libmrvlcompiler.a + ${CMAKE_CURRENT_BINARY_DIR}/libml.a + ${CMAKE_CURRENT_BINARY_DIR}/libisa.a + ${CMAKE_CURRENT_BINARY_DIR}/libnum.a + ) + # FIXME: list(APPEND TVM_LINKER_LIBS ${MRVL_LIB}) + message(STATUS "Build with 4 Mrvl lib *.a files: ${MRVL_RUNTIME_LIB}") + list(APPEND TVM_RUNTIME_LINKER_LIBS ${MRVL_RUNTIME_LIB}) + endif(MRVL_COMPILER_LIB_PATH) + + # Mrvl Module + message(STATUS "Build with Mrvl support") + add_definitions(-DTVM_USE_MRVL=1) + # FIXME: find_livrary(MRVL_LIB Mrvl) + # FIXME: find_livrary(MRVL_RUNTIME_LIB Mrvl_runtime) + file(GLOB RUNTIME_MRVL_SRCS + src/runtime/contrib/mrvl/mrvl_runtime.cc + ) + list(APPEND RUNTIME_SRCS ${RUNTIME_MRVL_SRCS}) + + file(GLOB COMPILER_MRVL_SRCS + src/relay/backend/contrib/mrvl/graph_executor_codegen_mrvl.cc + src/relay/backend/contrib/mrvl/codegen.cc + src/relay/backend/contrib/mrvl/drop_noop_transpose.cc + ) + list(APPEND COMPILER_SRCS ${COMPILER_MRVL_SRCS}) + + if(NOT USE_MRVL_RUNTIME) + list(APPEND COMPILER_SRCS ${RUNTIME_MRVL_MODULE}) + endif() +endif(USE_MRVL) + +if(USE_MRVL_RUNTIME) + # Set flag to detect Marvell runtime support. + add_definitions(-DTVM_RUNTIME_MRVL) +endif(USE_MRVL_RUNTIME) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8937bb7b1016..d4a4173640cc 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -155,6 +155,15 @@ class RelayExprNode : public BaseExprNode { * \return The checked_type */ inline const Type& checked_type() const; + + /*! + * \brief members to identify an expr node + */ + static int64_t _global_en_id; + mutable int64_t en_id; + RelayExprNode() { en_id = _global_en_id++; } + inline int64_t get_en_id() const { return en_id; } + /*! * \brief Check if the inferred(checked) type of the Expr * is backed by a TTypeNode and return it. diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 683170026451..0ed673f11262 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -196,6 +196,18 @@ class Op : public RelayExpr { */ TVM_DLL static const Op& Get(const String& op_name); + /*! + * \brief list all registered op names + */ + TVM_DLL static void ListAllOpNames(); + + /*! + * \brief get the name of an op, if it is registered. + * \param op Obj of an op + * \return op name in String, if it is registered. + */ + TVM_DLL static String GetOpName(const Op& op); + /*! \brief specify container node */ using ContainerType = OpNode; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 04dd9223719e..acb4ab883a82 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -107,6 +107,7 @@ class TupleNode : public ExprNode { tvm::Array fields; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("en_id", &en_id); v->Visit("fields", &fields); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -315,6 +316,7 @@ class CallNode : public ExprNode { tvm::Array type_args; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("en_id", &en_id); v->Visit("op", &op); v->Visit("args", &args); v->Visit("attrs", &attrs); diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index 848af1e4ee4e..10838d49788f 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -29,6 +29,7 @@ from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai +from tvm.relay.op.contrib.mrvl import partition_for_mrvl from .common import TVMCException @@ -76,6 +77,10 @@ "config_key": "relay.ext.vitis_ai.options", "pass_pipeline": partition_for_vitis_ai, }, + "mrvl": { + "config_key": "relay.ext.mrvl.options", + "pass_pipeline": partition_for_mrvl, + }, } diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 5f4a134270ac..f5340670413e 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -164,6 +164,7 @@ def __init__( libmod_name, params, function_metadata, + external_graph_json_str=None, ): assert isinstance(graph_json_str, string_types) fcreate = get_global_func("tvm.graph_executor_factory.create") @@ -177,6 +178,7 @@ def __init__( self.executor = executor self.module = fcreate(graph_json_str, libmod, libmod_name, *args) self.graph_json = graph_json_str + self.external_graph_json = external_graph_json_str self.lib = libmod self.libmod_name = libmod_name self.params = params @@ -198,5 +200,8 @@ def get_graph_json(self): def get_executor_config(self): return self.graph_json + def get_external_graph_json(self): + return self.external_graph_json + def get_lib(self): return self.lib diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 09b847a3ba91..eec88a8e14c3 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -95,6 +95,7 @@ class BuildModule(object): def __init__(self): self.mod = _build_module._BuildModule() self._get_graph_json = self.mod["get_graph_json"] + self._get_external_graph_json = self.mod["get_external_graph_json"] self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] @@ -193,8 +194,11 @@ def build( mod = self.get_module() params = self.get_params() executor_config = self.get_graph_json() if str(executor) == "graph" else None + external_executor_config = ( + self.get_external_graph_json() if str(executor) == "graph" else None + ) - return executor_config, mod, params + return executor_config, mod, params, external_executor_config def optimize(self, mod, target=None, params=None): """ @@ -238,6 +242,10 @@ def get_graph_json(self): """Return the json file of the built program.""" return self._get_graph_json() + def get_external_graph_json(self): + """Return the external json file of the built program.""" + return self._get_external_graph_json() + def get_module(self): """Return the built module.""" return self._get_module() @@ -446,7 +454,7 @@ def build( with tophub_context: bld_mod = BuildModule() - graph_json, runtime_mod, params = bld_mod.build( + graph_json, runtime_mod, params, external_graph_json = bld_mod.build( mod=ir_mod, target=target, params=params, @@ -472,7 +480,15 @@ def build( ) elif str(executor) == "graph": executor_factory = _executor_factory.GraphExecutorFactoryModule( - ir_mod, target, executor, graph_json, runtime_mod, mod_name, params, func_metadata + ir_mod, + target, + executor, + graph_json, + runtime_mod, + mod_name, + params, + func_metadata, + external_graph_json, ) else: assert False, "Executor " + executor + " not supported" diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 1dd6da6c2747..f914ff569dad 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -25,3 +25,4 @@ from .ethosn import * from .tensorrt import * from .cutlass import * +from .mrvl import * diff --git a/python/tvm/relay/op/contrib/mrvl.py b/python/tvm/relay/op/contrib/mrvl.py new file mode 100644 index 000000000000..b55e6c7c124b --- /dev/null +++ b/python/tvm/relay/op/contrib/mrvl.py @@ -0,0 +1,1758 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +""" +file mrvl.py +Marvell MLIP specific API +""" + + +import re +import base64 +import json +import yaml + +import tvm +from tvm import relay +from tvm.relay.transform import _ffi_api + +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.expr import ( + Call, + Let, + Var, + GlobalVar, + If, + Tuple, + TupleGetItem, + RefCreate, + RefWrite, + RefRead, +) +from tvm.relay.function import Function + +from ...dataflow_pattern import ( + wildcard, + is_op, + is_constant, + is_tuple_get_item, + is_var, +) +from .register import register_pattern_table +from ..strategy.generic import is_depthwise_conv2d + + +def clear_ext_json_flag(): + """clear_ext_json_flag + + Returns + ------- + ret: none + """ + ext_json = tvm.get_global_func("relay.mrvl.clear_ext_json_flag") + ext_json() + + +def is_mrvl_runtime_enabled(): + """Check if the Mrvl graph executor is present. + + Returns + ------- + ret: bool + True if present, False if not. + """ + check_enabled = tvm.get_global_func("relay.op.is_mrvl_runtime_enabled", True) + if check_enabled: + return check_enabled() + return False + + +def mrvl_register_op_attr_funcs_for_convert_layout(): + """ FIXME """ + # NOTE: for max_pool2d, global_max_pool2d, avg_pool2d, and global_avg_pool2d, + # we can rely on registered convert layout functions defined in + # the tvm/python/tvm/relay/op/nn/_nn.py file + + # reset first in order to register & use a new nn.conv2d convert layout function + relay.op.get("nn.conv2d").reset_attr("FTVMConvertOpLayout") + + @tvm.ir.register_op_attr("nn.conv2d", "FTVMConvertOpLayout") + def convert_conv2d(attrs, inputs, tinfos, desired_layouts): + new_attrs = dict(attrs) + # original input data shape is in NCHW format + # data_info_const = tinfos[0] + # original kernel shape is in OIHW format + weight_info_const = tinfos[1] + # output channels + new_attrs["channels"] = weight_info_const.shape[0] + + # convert shapes for input data, kernel, and output to use NHWC, OHWI, + # and NHWC, respectively + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + # allow us to set input tensor's data_layout == output tensor's out_layout + new_attrs["data_layout"] = desired_data_layout + new_attrs["kernel_layout"] = desired_kernel_layout + new_attrs["out_layout"] = desired_data_layout + return relay.nn.conv2d(*inputs, **new_attrs) + + return convert_conv2d + + +def partition_for_mrvl( + mod, + params=None, + tvm_custom_dict=None, + gen_non_mrvl_subgraph=True, + flow_pass=1, + **opts, +): + """Partition the graph greedily offloading supported + operators to Mrvl + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + mod_mrvl : annotated and partitioned module - part 1, the mrvl sub graph + mod_other : annotated and partitioned module - part 2, if any, the rest sub graph + params : TBA + opt_level : TBA + disabled_pass_list : TBA + mod : TBA + mrvl_layers_in_mrvl_subgraph : TBA + """ + clear_ext_json_flag() + + # permanently use Mrvl defined convert layout functions + mrvl_register_op_attr_funcs_for_convert_layout() + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + # tvm.transform.Sequential()'s default opt_level is 2 + opt_level = 3 + disabled_pass_list = ["AlterOpLayout"] + seq = tvm.transform.Sequential( + passes=[ + # available but not used tvm passes are: + # - 0, " FoldExplicitPadding", {"InferType"} // extra leading space? + # - 0, "SimplifyInference", {"InferType"} + # - 1, "FuseOps", {"InferType"} + # - 1, "Legalize", {"InferType"} + # - 1, "RewriteAnnotatedOps", {"InferType"} + # - 3, "AlterOpLayout", {"InferType"} + # - 3, "AutoSchedulerLayoutRewrite", {"InferType"} + # - 3, "BackwardFoldScaleAxis", {"InferType"} + # - 3, "CanonicalizeCast", {"InferType"} + # - 3, "CanonicalizeOps", {"InferType"} + # - 3, "DefuseOps", {"InferType"} + # - 3, "EliminateCommonSubexpr", {"InferType"} + # - 3, "ForwardFoldScaleAxis", {"InferType"} + # - 4, "CombineParallelBatchMatmul", {"InferType"} + # - 4, "CombineParallelConv2d", {"InferType"} + # - 4, "CombineParallelDense", {"InferType"} + # - 4, "CombineParallelOpBatch", {"InferType"} + # - 4, "FastMath", {"InferType"} + # trigger tvm existing relay pass, which contains sub-passes: type_infer.cc + # - (0, "InferType", {}); + relay.transform.InferType(), + # tvm.transform.PrintIR("after InferType"), # ~/a + # implement mrvl own pass (opt_level=0) for nn.dropout + MrvlRemoveDropoutPass(), + # tvm.transform.PrintIR("after MrvlRemoveDropout"), # ~/b + # trigger tvm existing relay pass, which contains sub-passes: + # relay/backend/vm/removed_unused_funcs.cc + # - (1, "RemoveUnusedFunctions", {}); + relay.transform.RemoveUnusedFunctions(), + # tvm.transform.PrintIR("after RemoveUnusedFunctions"), # ~/c + # trigger tvm existing relay ConvertLayout pass: convert_layout.cc + # - (3, "CanonicalizeOps", {"InferType"}) + # - (3, "ConvertLayout", {"InferType", "CanonicalizeOps"}) + # - we can describe mrvl-specific format + # - we can also implement mrvl per-relay-op conversion functions + # - we can hook them to relay-op framework using Python @ decorator + relay.transform.ConvertLayout( + {"nn.conv2d": ["NHWC", "OHWI"], "nn.max_pool2d": ["NHWC"]} + ), + # tvm.transform.PrintIR("after ConvertLayout"), # ~/d + # trigger tvm existing relay pass, which contains sub-passes: fold_constant.cc + # - (2, "FoldConstant", {}) + relay.transform.FoldConstant(), + # tvm.transform.PrintIR("after FoldConstant"), # ~/e + # trigger tvm existing relay pass, which contains sub-passes: simplify_expr.cc + # - (0, "SimplifyExpr", {"InferType"}) + # - ConcretizeZerosLikeRewrite, ConcretizeOnesLikeRewrite, ConcretizeFullLikeRewrite, + # - ConcretizeReshapeLikeRewrite, ConcretizeCollapseSumLikeRewrite, + # ConcretizeBroadcastToLikeRewrite, + # - EliminateIdentityRewrite, SimplifyReshape, SimplifyTranspose, + # - SimplifyCast, # - FullElementwise, + relay.transform.SimplifyExpr(), + # tvm.transform.PrintIR("after SimplifyExpr"), # ~/e + # implement mrvl-specific drop-noop-transpose pass: drop_noop_transpose.cc + # - (0, "DropNoopTranspose", {"InferType"}) + # - we can implement mrvl C++ pass + # - we can hook it to relay-pass framework: + # + first using C++ TVM_REGISTER_GLOBAL("relay._transform.DropNoopTranspose"). + # set_body_typed(DropNoopTranspose); + # + then using Python @ decorator below + _ffi_api.DropNoopTranspose(), + relay.transform.InferType(), + # tvm.transform.PrintIR("after DropNoopTranspose"), # ~/e + # trigger tvm existing relay pass, which contains sub-passes: merge_composite.cc + # - (0, "MergeComposite", {}) + # - we can also implement mrvl specific composite patterns + # - we can hook them to relay-merge-composite framework using Python @ decorator + relay.transform.MergeComposite(mrvl_pattern_table()), + # tvm.transform.PrintIR("after MergeComposite"), # ~/f + # trigger tvm existing relay pass, which contains sub-passes: annotate_target.cc + # - 0, "AnnotateTargetFunc", {"InferType"} + relay.transform.AnnotateTarget("mrvl", False), + # tvm.transform.PrintIR("after AnnotateTarget mrvl"), # ~/g + # this call (partition_graph.cc) can trigger @register_func("relay.ext.mrvl.optimize"), + # if defined + # - (0, "FlattenNestedTuples", {}), (0, "RemoveDefaultAnnotations", {}), + # and (0, "PartitionGraph", {}) + # - mangle module name: "tvmgen_" + "mrvl_main_" with a post-fix <#> + relay.transform.PartitionGraph(""), + # tvm.transform.PrintIR("after PartitionGraph"), # ~/h + # trigger tvm existing relay pass, which contains sub-passes: type_infer.cc + # - (0, "InferType", {}); + relay.transform.InferType(), + # tvm.transform.PrintIR("final IR"), # ~/h + ] + ) + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass_list): + # triggers tvm/ir/transform.py: transform.pass => __call__() => + # _ffi_transform_api.RunPass(self, mod) + # - src/tvm/ir/transform.cc: IRMod <= transform.RunPass() + mod = seq(mod) + mutator = MrvlIRGraphUtils() + # print("3.a mutator (inst of MrvlIRGraphUtils): {}".format(mutator), flush=True) + mod_mrvl, mod_other, mrvl_layers_in_mrvl_subgraph = mutator.compute_two_subgraphs( + mod, gen_non_mrvl_subgraph=gen_non_mrvl_subgraph, flow_pass=flow_pass + ) + + # annotated and partitioned mod_mrvl + return ( + mod_mrvl, + mod_other, + params, + opt_level, + disabled_pass_list, + mod, + mrvl_layers_in_mrvl_subgraph, + ) + + +def defuse_mrvl_layers_in_mrvl_subgraph(mod, defuse_mrvl_layers_list): + """given a Mrvl subgraph, user can decide to use only a subset of the Mrvl subgraph; and + this can be done by: (a) use a graph viewer to see structure of the Mrvl subgraph + including names of consecutive Mrvl layers; and (b) to identify what set of Mrvl + layer names to be cut (e.g., by treating them as defuse nodes) + """ + mutator = MrvlIRGraphUtils() + # print("3.b mutator (inst of MrvlIRGraphUtils): {}".format(mutator), flush=True) + mod_mrvl, mod_other, mrvl_layers_in_mrvl_subgraph = mutator.compute_two_subgraphs( + mod, + defuse_mrvl_layers_list=defuse_mrvl_layers_list, + gen_non_mrvl_subgraph=True, + flow_pass=2, + ) + return mod_mrvl, mod_other, mrvl_layers_in_mrvl_subgraph + + +def dump_json_meta_data_files(external_graph_json, const_params, filename_prefix="metadata"): + """Generate two meta data json file and return their filenames + + Parameters + ---------- + external_graph_json : str + The json string that can be accepted by graph executor. + It is generated from the GetExternalJSON() function + const_params: constant params + filename_prefix : Optional json filename prefix + + Returns + ------- + node_json_filename : json filename for nodes and etc. + const_json_filename : meta data json filename for parameters + """ + relay_json_obj = yaml.load( + """\n%(json_str)s + """ + % {"json_str": external_graph_json} + ) + node_json_filename = "{}-byoc.json".format(filename_prefix) + with open(node_json_filename, "w+") as json_f: + json.dump(relay_json_obj, json_f, indent=2) + # with open(node_json_filename, "r") as inp_f: + # node_json_obj = json.load(inp_f) + + # const params have been erased from graph_json and moved to + # metadata module + const_json_filename = "{}-byoc-const.json".format(filename_prefix) + with open("{}".format(const_json_filename), "w+") as json_c: + json_c.write("{\n") + first_const = True + for const_key, const_value in const_params.items(): + if ("mrvl" not in const_key) or ("const" not in const_key): + continue + if first_const: + json_c.write(' "{}": {}\n'.format(const_key, "{")) + else: + json_c.write(' {},\n "{}": {}\n'.format("}", const_key, "{")) + shape_str = str(const_value.shape) + shape_str = shape_str.replace("(", "[") + shape_str = shape_str.replace(")", "]") + # need to take care of special case: composite FC with batch 1 and a scalar add() bias + # - e.g.: its shape: (32,) needs to be converted to [32,] and then to [1,32] + shape_re = "[[](?P[1-9][0-9]+),[]]" + match_obj = re.match(shape_re, shape_str) + if match_obj: + shape_str = "[1, {}]".format(match_obj.group("scalar_dim_val")) + json_c.write(' "shape": {},\n'.format(shape_str)) + json_c.write(' "dtype": "{}",\n'.format(const_value.dtype)) + json_c.write( + ' "data_base64": "{}"\n'.format( + base64.b64encode(const_value.asnumpy()).decode("utf-8") + ) + ) + first_const = False + json_c.write(" }\n}\n") + + # with open(const_json_filename, "r") as inp_f: + # const_json_obj = json.load(inp_f) + + return node_json_filename, const_json_filename + + +def convert_consts_json_meta_data_to_string( + const_params, +): + """Generate two meta data json file and return their filenames + + Parameters + ---------- + const_params: constant params + + Returns + ------- + const_json_string : meta data of params in json string + """ + # const params have been erased from graph_json and moved to + # metadata module + json_str = "{\n" + first_const = True + for const_key, const_value in const_params.items(): + if ("mrvl" not in const_key) or ("const" not in const_key): + continue + if first_const: + json_str = json_str + ' "{}": {}\n'.format(const_key, "{") + else: + json_str = json_str + ' {},\n "{}": {}\n'.format("}", const_key, "{") + shape_str = str(const_value.shape) + shape_str = shape_str.replace("(", "[") + shape_str = shape_str.replace(")", "]") + # need to take care of special case: composite FC with batch 1 and a scalar add() bias + # - e.g.: its shape: (32,) needs to be converted to [32,] and then to [1,32] + shape_re = "[[](?P[1-9][0-9]+),[]]" + match_obj = re.match(shape_re, shape_str) + if match_obj: + shape_str = "[1, {}]".format(match_obj.group("scalar_dim_val")) + json_str = json_str + ' "shape": {},\n'.format(shape_str) + json_str = json_str + ' "dtype": "{}",\n'.format(const_value.dtype) + json_str = json_str + ' "data_base64": "{}"\n'.format( + base64.b64encode(const_value.asnumpy()).decode("utf-8") + ) + first_const = False + json_str = json_str + " }\n}\n" + + return json_str + + +@register_pattern_table("mrvl") +def mrvl_pattern_table(): + """Get the Mrvl pattern table.""" + + def conv2d_nhwc2nhwc_pattern(): + """Create a convolution-2d pattern. + review tvm/tests/python/relay/test_dataflow_pattern.py for examples + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the convolution-2d pattern. + """ + pattern = is_op("nn.pad")(wildcard()) | wildcard() + pattern = is_op("nn.conv2d")(pattern, is_constant()) + pattern = pattern.optional( + lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant())) + ) + + # conv + [add] + [relu] + pattern1 = pattern.optional(is_op("nn.relu")) + + # conv + [add] + batch_norm + %.0 + [relu] + pattern2 = is_op("nn.batch_norm")(pattern, wildcard(), wildcard(), wildcard(), wildcard()) + pattern2 = is_tuple_get_item(pattern2, 0) + pattern2 = pattern2.optional(is_op("nn.relu")) + + return pattern1 | pattern2 + + def sum2d_pattern(): + """Create a sum2d pattern. + review tvm/tests/python/relay/test_dataflow_pattern.py for examples + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the sum2d pattern. + """ + # do these in check_sum2d + # - need to further checking if the call_func of args[0] is not nn.conv2d nor nn.dense + # - need to further checking if dimension of input or output tensor is 4 + pattern = is_op("add")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("nn.relu")) + return pattern + + def fc_pattern(): + """Create a fc (fully-connected) pattern. + review tvm/tests/python/relay/test_dataflow_pattern.py for examples + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the fc pattern. + """ + pattern = is_op("nn.dense")(wildcard(), is_constant()) + pattern = pattern.optional( + lambda x: (is_op("nn.bias_add")(x, is_constant()) | is_op("add")(x, is_constant())) + ) + pattern = pattern.optional(is_op("nn.relu")) + return pattern + + def maxpool2d_pattern(): + """Create a maxpool2d pattern. + review tvm/tests/python/relay/test_dataflow_pattern.py for examples + + Returns + ------- + pattern : dataflow_pattern.AltPattern + Denotes the maxpool2d pattern. + """ + pattern = is_op("nn.max_pool2d")(wildcard()) + return pattern + + def layout_transform_pattern(): + # pattern = is_op("layout_transform")(wildcard().match(GlobalVar), wildcard(), + # wildcard()).has_attr( + # {"src_layout": "NCHW", "dst_layout": "NHWC"}) + pattern = is_op("layout_transform")(is_var(), wildcard(), wildcard()).has_attr( + {"src_layout": "NCHW", "dst_layout": "NHWC"} + ) + return pattern + + def check_conv2d(extract): + """Check conv pattern is supported by Mrvl.""" + call = extract + # loop over fused Mrvl conv2d sub graph to find the conv2d op + # - it is okay if we also fused nn.pad because, in conv2d_nhwc2nhwc(), + # we do checks starting from conv2d op + # - in case of nn.batch_norm, a tuple-get-item node exists inside + # the fused conv2d sub graph + while isinstance(call, TupleGetItem) or (call.op.name != "nn.conv2d"): + if isinstance(call, TupleGetItem): + call = call.tuple_value + else: + call = call.args[0] + return conv2d_nhwc2nhwc(call) + + def check_fc(extract): + """Check fc pattern is supported by Mrvl.""" + call = extract + while call.op.name != "nn.dense": + call = call.args[0] + return fc_ni2no(call) + + def check_maxpool2d(extract): + """Check maxpool2d pattern is supported by Mrvl.""" + call = extract + while call.op.name != "nn.max_pool2d": + call = call.args[0] + return maxpool2d_nhwc2nhwc(call) + + def check_layout_transform(extract): + call = extract + while call.op.name != "layout_transform": + call = call.args[0] + return layout_transform_nchw2nhwc(call) + + def check_sum2d(extract): + """Check maxpool pattern is supported by Mrvl.""" + call = extract + while call.op.name != "add": + call = call.args[0] + return sum2d(call) + + return [ + ("mrvl.conv2d_nhwc2nhwc", conv2d_nhwc2nhwc_pattern(), check_conv2d), + ("mrvl.fc_ni2no", fc_pattern(), check_fc), + ("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d), + ("mrvl.sum2d", sum2d_pattern(), check_sum2d), + ("mrvl.layout_transform_nchw2nhwc", layout_transform_pattern(), check_layout_transform), + ] + + +def _register_external_op_helper(op_name, supported=True): + """The helper function to indicate that a given operator can be supported by Mrvl. + + Paramters + --------- + op_name : Str + The name of operator that will be registered. + + Returns + ------- + f : callable + A function that returns if the operator is supported by DNNL. + """ + + @tvm.ir.register_op_attr(op_name, "target.mrvl") + def _func_wrapper(expr): + return supported + + return _func_wrapper + + +_register_external_op_helper("nn.batch_flatten") +_register_external_op_helper("reshape") +_register_external_op_helper("transpose") + + +# register a helper function to indicate that the given operator can be supported by Mrvl. +@tvm.ir.register_op_attr("nn.conv2d", "target.mrvl") +def conv2d_nhwc2nhwc(expr): + """Check if the external Mrvl codegen for conv2d_nhwc2nhwc should be used.""" + attrs, args = expr.attrs, expr.args + if attrs.data_layout != "NHWC": + return False + if attrs.out_dtype != "float32" and attrs.out_dtype != "": + return False + data_type = args[0].checked_type + if ( + (len(data_type.shape) != 4) + or (data_type.shape[0] != 1) + or (data_type.dtype not in ["float32"]) + ): + return False + kernel_typ = args[1].checked_type + if (len(kernel_typ.shape) != 4) or (kernel_typ.dtype not in ["float32"]): + return False + is_depthwise = is_depthwise_conv2d( + data_type.shape, + attrs["data_layout"], + kernel_typ.shape, + attrs["kernel_layout"], + attrs["groups"], + ) + if is_depthwise: + return depthwise_conv2d_nhwc2nhwc(attrs, args) + # Mrvl doesn't support grouped convolution + if attrs.groups != 1 and not is_depthwise: + return False + return True + + +# register a helper function to indicate that the given operator can be supported by Mrvl. +@tvm.ir.register_op_attr("add", "target.mrvl") +def sum2d(expr): + """Check if the external Mrvl codegen for sum2d should be used.""" + arg0 = expr.args[0] + + # - need to further checking if the call_func of arg0 is not nn.conv2d nor nn.dense + if ( + isinstance(arg0, Call) + and isinstance(arg0.op, tvm.ir.Op) + and arg0.op.name in ["nn.conv2d", "nn.dense"] + ): + return False + + # - need to further checking if dimension of input or output tensor is 4 + data_type = arg0.checked_type + if ( + (len(data_type.shape) != 4) + or (data_type.shape[0] != 1) + or (data_type.dtype not in ["float32"]) + ): + return False + return True + + +# TODO(ccjoechou): register a helper function to indicate that the given operator +# can be supported by Mrvl. +def depthwise_conv2d_nhwc2nhwc(attrs, args): + """Check if the external Mrvl codegen for depthwise convolution should be used. + + Note + ---- + Relay does not have a depthwise conv2d_nhwc2nhwc operator whilst Mrvl does. We simply + separate the checks for depthwise for clarity. + """ + kernel_typ = args[1].checked_type + # Only supports 3x3, 5x5 depthwise + if ( + kernel_typ.shape[0] not in [3, 5] + or kernel_typ.shape[1] not in [3, 5] + or kernel_typ.shape[0] != kernel_typ.shape[1] + ): + return False + # Stride must be (1, 1) or (2, 2) + if (attrs.strides[0], attrs.strides[1]) not in [(1, 1), (2, 2)]: + return False + return True + + +# register a helper function to indicate that the given operator can be supported by Mrvl. +@tvm.ir.register_op_attr("nn.dense", "target.mrvl") +def fc_ni2no(expr): + """Check if the external Mrvl codegen for fc_ni2no should be used.""" + attrs, args = expr.attrs, expr.args + data_type = args[0].checked_type + if data_type.dtype not in ["float32"]: + return False + kernel_typ = args[1].checked_type + if (len(kernel_typ.shape) != 2) or (kernel_typ.dtype not in ["float32"]): + return False + if attrs.out_dtype != "float32" and attrs.out_dtype != "": + return False + return True + + +# register a helper function to indicate that the given operator can be supported by Mrvl. +@tvm.ir.register_op_attr("nn.max_pool2d", "target.mrvl") +def maxpool2d_nhwc2nhwc(expr): + """Check if the external Mrvl codegen for maxpool2d_nhwc2nhwc should be used.""" + attrs, args = expr.attrs, expr.args + if attrs.layout != "NHWC": + return False + data_type = args[0].checked_type + if data_type.dtype not in ["float32"]: + return False + return True + + +# register a helper function to indicate that the given operator can be supported by Mrvl. +@tvm.ir.register_op_attr("layout_transform", "target.mrvl") +def layout_transform_nchw2nhwc(expr): + """ FIXME """ + attrs, args = expr.attrs, expr.args + if attrs.src_layout != "NCHW": + return False + if attrs.dst_layout != "NHWC": + return False + data_type = args[0].checked_type + if data_type.dtype not in ["float32"]: + return False + return True + + +class RemoveDropout(ExprMutator): + """Removes all nn.dropout from an expr.""" + + def visit_tuple_getitem(self, op): + visit = super().visit_tuple_getitem(op) + if visit.index != 0: + return visit + if ( + isinstance(visit.tuple_value, Call) + and visit.tuple_value.op.name == "nn.dropout" + and visit.index == 0 + ): + # skip nn.dropout call and return arg0 instead + return visit.tuple_value.args[0] + return visit + + +@relay.transform.function_pass(opt_level=0) +class MrvlRemoveDropoutPass: + def transform_function(self, func, mod, _): + return RemoveDropout().visit(func) + + +class MrvlLayers(ExprMutator): + """experimental class: + do post-order DFS traverse analysis based on the value of the !mrvl_color attribute + to decide whether a Mrvl layer/node has an output for the IR sub graph of consecutive + Mrvl layers + """ + + def __init__( + self, + mutate_style="compute-mrvl-color", + mrvl_layer_names=None, + defuse_mrvl_layers_list=None, + mrvl_layers_consecutive=None, + mrvl_layers_outputs=None, + debug=False, + ): + ExprMutator.__init__(self) + self._debug = debug + self._compute_mrvl_color = False + self._get_mrvl_subgraph = False + if mutate_style in ["compute-mrvl-color"]: + self._compute_mrvl_color = True + # dictionary for consecutive Mrvl layers + self._mrvl_layers_consecutive = {} + # dictionary for non-consecutive Mrvl layers, which need to be defused + self._mrvl_layers_to_defuse = {} + if defuse_mrvl_layers_list is not None: + # user has provided initial names of Mrvl layers + # to be de-fused based on previous run + for name in defuse_mrvl_layers_list: + if name not in mrvl_layer_names: + raise RuntimeError( + "TVM-Mrvl-BYOC: defuse name ({}) isn't in Mrvl subgraph ({})".format( + name, mrvl_layer_names + ) + ) + self._mrvl_layers_to_defuse[name] = True + elif mutate_style in ["get-mrvl-subgraph"]: + self._get_mrvl_subgraph = True + assert defuse_mrvl_layers_list is None + self._mrvl_layers_consecutive = mrvl_layers_consecutive + self._outputs_mrvl_name = mrvl_layers_outputs + self._outputs_call = [] + self._inputs = [] + else: + raise RuntimeError("TVM-Mrvl-BYOC: unsupported mutate style: {}".format(visit_style)) + + def dump_debug_text_info(self, n, label): + astext_list = n.astext(False).splitlines() + if astext_list[-1:][0] in [" */"]: + str_list = astext_list[-5:-4][0].split(") /*") + else: + str_list = astext_list[-1:][0].split(") /*") + print("{}: {})".format(label, str_list[0]), flush=True) + + def post_order_analysis(self, call, name, layer_type): + """do post-order DFS traverse analysis: using the mrvl_color attribute where: + if mrvl_color == True: this Mrvl layer call is in the group (or inside the subgraph) + of consecutive Mrvl layers + """ + if self._debug: + call_astext_list = call.astext(False).splitlines() + if call_astext_list[-1:][0] in [" */"]: + call_func_str_list = call_astext_list[-5:-4][0].split(") /*") + else: + call_func_str_list = call_astext_list[-1:][0].split(") /*") + print("Debug: post-order {} {})".format(layer_type, call_func_str_list[0]), flush=True) + + assert hasattr(call, "mrvl_color") + if call.mrvl_color: + if name in self._mrvl_layers_to_defuse: + # allow use-provided names of to be defused Mrvl layers + call.mrvl_color = False + else: + self._mrvl_layers_consecutive[name] = True + elif (not call.mrvl_color) and isinstance(call.op, GlobalVar): + self._mrvl_layers_to_defuse[name] = True + + def visit_function(self, fn): + """override base class ExprMutator's visit_function() so that + (1) we can add & use the mrvl_color attribute to determine whether + the Function obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + if self._compute_mrvl_color: + params_mrvl_color = True + if self._get_mrvl_subgraph: + params_has_none = False + new_params = [] + for x in fn.params: + new_param = self.visit(x) + new_params.append(new_param) + if self._compute_mrvl_color: + assert hasattr(new_param, "mrvl_color") + if not new_param.mrvl_color: + params_mrvl_color = False + if self._get_mrvl_subgraph: + if new_param is None: + params_has_none = True + + new_body = self.visit(fn.body) + if self._get_mrvl_subgraph: + if (new_body is None) or params_has_none: + if self._debug: + self.dump_debug_text_info(fn, "drop fn") + return None + new_fn = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs) + if self._compute_mrvl_color: + assert hasattr(new_body, "mrvl_color") + new_fn.mrvl_color = params_mrvl_color and new_body.mrvl_color + return new_fn + + def visit_let(self, let): + """override base class ExprMutator's visit_let() so that + (1) we can add & use the mrvl_color attribute to determine whether + the Let obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + new_var = self.visit(let.var) + new_value = self.visit(let.value) + new_body = self.visit(let.body) + if self._get_mrvl_subgraph: + if (new_var is None) or (new_value is None) or (new_body is None): + if self._debug: + self.dump_debug_text_info(let, "drop let") + return None + new_let = Let(new_var, new_value, new_body) + if self._compute_mrvl_color: + assert hasattr(new_var, "mrvl_color") + assert hasattr(new_value, "mrvl_color") + assert hasattr(new_body, "mrvl_color") + new_let.mrvl_color = new_var.mrvl_color and new_value.mrvl_color and new_body.mrvl_color + return new_let + + def visit_call(self, call): + """override base class ExprMutator's visit_call() so that + (1) we can add & use the mrvl_color attribute to determine whether + the Call obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + name = None + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + else: + name = call.op.name + + if self._compute_mrvl_color: + if isinstance(call.op, GlobalVar): + layer_type = "mrvl-layer: " + else: + assert isinstance(call.op, tvm.ir.Op) + layer_type = "non-mrvl-layer:" + args_mrvl_color = True + + new_fn = self.visit(call.op) + if self._get_mrvl_subgraph: + args_has_none = False + new_args = [] + for idx, arg in enumerate(call.args): + if self._compute_mrvl_color and self._debug and (idx > 0): + print("Debug: post-order: visit call-arg{} @{}".format(idx, name), flush=True) + new_arg = self.visit(arg) + new_args.append(new_arg) + if self._compute_mrvl_color: + assert hasattr(new_arg, "mrvl_color") + if not new_arg.mrvl_color: + args_mrvl_color = False + if self._get_mrvl_subgraph: + if new_arg is None: + args_has_none = True + + if self._get_mrvl_subgraph: + if name not in self._mrvl_layers_consecutive: + if self._debug: + self.dump_debug_text_info(call, "drop call") + return None + + new_call = Call(new_fn, new_args, call.attrs, call.type_args, call.span) + + if self._compute_mrvl_color: + assert hasattr(new_fn, "mrvl_color") + new_call.mrvl_color = args_mrvl_color and new_fn.mrvl_color + self.post_order_analysis(new_call, name, layer_type) + if self._get_mrvl_subgraph: + assert not args_has_none + if name in self._outputs_mrvl_name: + if self._debug: + print("add outputs: {}".format(name), flush=True) + self._outputs_call.append(new_call) + return new_call + + def visit_var(self, var): + """override base class ExprMutator's visit_var() so that + (1) we can add & use the mrvl_color attribute, or + (2) return only Mrvl subgraph + """ + if self._compute_mrvl_color: + var.mrvl_color = True + if self._get_mrvl_subgraph: + if self._debug: + self.dump_debug_text_info(var, "add inputs: var") + self._inputs.append(var) + return var + + def visit_global_id(self, global_var): + """override base class ExprMutator's visit_global_id() so that + we can add & use the mrvl_color attribute + """ + if self._compute_mrvl_color: + global_var.mrvl_color = True + return global_var + + def visit_if(self, ite): + """override base class ExprMutator's visit_if() so that + (1) we can add & use the mrvl_color attribute to determine whether + the If obj is inside the group (or the sub graph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + new_cond = self.visit(ite.cond) + new_true_branch = self.visit(ite.true_branch) + new_false_branch = self.visit(ite.false_branch) + if self._get_mrvl_subgraph: + if (new_cond is None) or (new_true_branch is None) or (new_false_branch is None): + if self._debug: + self.dump_debug_text_info(ite, "drop ite") + return None + new_if = If(new_cond, new_true_branch, new_false_branch) + if self._compute_mrvl_color: + assert hasattr(new_cond, "mrvl_color") + assert hasattr(new_true_branch, "mrvl_color") + assert hasattr(new_false_branch, "mrvl_color") + new_if.mrvl_color = ( + new_cond.mrvl_color and new_true_branch.mrvl_color and new_false_branch.mrvl_color + ) + return new_if + + def visit_tuple(self, tup): + """override base class ExprMutator's visit_tuple() so that + (1) we can add & use the mrvl_color attribute to determine whether + the Tuple obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + if self._compute_mrvl_color: + fields_mrvl_color = True + if self._get_mrvl_subgraph: + fields_has_none = False + new_fields = [] + for field in tup.fields: + new_field = self.visit(field) + new_fields.append(new_field) + if self._compute_mrvl_color: + assert hasattr(new_field, "mrvl_color") + if not new_field.mrvl_color: + fields_mrvl_color = False + if self._get_mrvl_subgraph: + if new_field is None: + fields_has_none = True + + if self._get_mrvl_subgraph: + if fields_has_none: + if self._debug: + self.dump_debug_text_info(tup, "drop tup") + return None + new_tup = Tuple(new_fields, tup.span) + if self._compute_mrvl_color: + new_tup.mrvl_color = fields_mrvl_color + return new_tup + + def visit_tuple_getitem(self, op): + """override base class ExprMutator's visit_tuple_getitem() so that + (1) we can add & use the mrvl_color attribute to determine whether + the op or TupleGetItem obj is inside the group (or subgraph) + of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + tuple_value = self.visit(op.tuple_value) + if self._get_mrvl_subgraph: + if tuple_value is None: + if self._debug: + self.dump_debug_text_info(op, "drop op") + return None + if not tuple_value.same_as(op.tuple_value): + new_tup_get_item = TupleGetItem(tuple_value, op.index) + if self._compute_mrvl_color: + assert hasattr(tuple_value, "mrvl_color") + new_tup_get_item.mrvl_color = tuple_value.mrvl_color + return new_tup_get_item + + # usually we do not get here, but, if we do, we can only + # add the mrvl_color attribute to the original IR graph + if self._compute_mrvl_color: + if not hasattr(op, "mrvl_color"): + op.mrvl_color = True + return op + + def visit_global_var(self, gvar): + """override base class ExprMutator's visit_global_var() so that + we can add & use the mrvl_color attribute + """ + if self._compute_mrvl_color: + gvar.mrvl_color = True + return gvar + + def visit_op(self, op): + """override base class ExprMutator's visit_op() so that + we can add & use the mrvl_color attribute + """ + if self._compute_mrvl_color: + # - all Mrvl layers are GlobalVar objs + # - all ops are non Mrvl layers + op.mrvl_color = False + return op + + def visit_constant(self, const): + """override base class ExprMutator's visit_constant() so that + we can add & use the mrvl_color attribute + """ + if self._compute_mrvl_color: + const.mrvl_color = True + return const + + def visit_ref_create(self, r): + """override base class ExprMutator's visit_ref_create() so that + (1) we can add & use the mrvl_color attribute to determine whether + the RefCreate obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + new_value = self.visit(r.value) + if self._get_mrvl_subgraph: + if new_value is None: + if self._debug: + self.dump_debug_text_info(r, "drop ref_create") + return None + new_refcreate = RefCreate(new_value) + if self._compute_mrvl_color: + assert hasattr(new_value, "mrvl_color") + new_refcreate.mrvl_color = new_value.mrvl_color + return new_refcreate + + def visit_ref_write(self, r): + """override base class ExprMutator's visit_ref_create() so that + (1) we can add & use the mrvl_color attribute to determine whether + the RefWrite obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + new_ref = self.visit(r.ref) + new_value = self.visit(r.value) + if self._get_mrvl_subgraph: + if (new_ref is None) or (new_value is None): + if self._debug: + self.dump_debug_text_info(r, "drop ref_create") + return None + new_refwrite = RefWrite(new_ref, new_value) + if self._compute_mrvl_color: + assert hasattr(new_ref, "mrvl_color") + assert hasattr(new_value, "mrvl_color") + new_refwrite.mrvl_color = new_ref.mrvl_color and new_value.mrvl_color + return new_refwrite + + def visit_ref_read(self, r): + """override base class ExprMutator's visit_ref_create() so that + (1) we can add & use the mrvl_color attribute to determine whether + the RefRead obj is inside the group (or subgraph) of consecutive Mrvl layers, or + (2) return only Mrvl subgraph + """ + new_ref = self.visit(r.ref) + if self._get_mrvl_subgraph: + if new_ref is None: + if self._debug: + self.dump_debug_text_info(r, "drop ref_create") + return None + new_refread = RefRead(new_ref) + if self._compute_mrvl_color: + assert hasattr(new_ref, "mrvl_color") + new_refread.mrvl_color = new_ref.mrvl_color + return new_refread + + def compute_main_func_mrvl_color(self, main_func): + """initiate post-order DFS traverse from each output tensore of + the main_func argument, i.e., mod["main"], + in order to find and return the group (or the Mrvl sub graph) + of consecutive Mrvl layers, as well as, Mrvl layers, which + need to be defused back to their original operators + """ + assert main_func + assert self._compute_mrvl_color + if self._debug: + print("mod[main] => {}".format(main_func.astext(False)), flush=True) + return self.visit(main_func) + + def get_consecutive_layers(self): + """return names of Mrvl layers inside the Mrvl subgraph + and names of Mrvl layers outside the Mrvl subgraph + """ + assert self._compute_mrvl_color + return self._mrvl_layers_consecutive, self._mrvl_layers_to_defuse + + def get_main_func_mrvl_subgraph(self, main_func): + """return only Mrvl subgraph""" + assert self._get_mrvl_subgraph + new_main_func = self.visit(main_func) + if new_main_func is not None: + if self._debug: + print("return new_main_func: {})".format(new_main_func.astext(False)), flush=True) + return new_main_func + + # we need to instantiate a new output or output tuple + if self._debug: + print( + "got new_main_func==None and need to construct a tuple outputs (size={})".format( + len(self._outputs_call) + ), + flush=True, + ) + assert len(self._outputs_call) > 0 + if len(self._outputs_call) == 1: + if self._debug: + print("take the only output call") + new_main_func = Function(list(self._inputs), self._outputs_call[0]) + else: + if self._debug: + print("tuple generated") + new_out_tup = Tuple(self._outputs_call) + new_main_func = Function(list(self._inputs), new_out_tup) + if self._debug: + print("new main func generated") + return new_main_func + + +# TODO(ccjoechou): Need to find all the possible cut points so that many corner +# cases can be identifed and fixed. +class RestOfMrvlLayers(ExprMutator): + """experimental class: + Figures out restof subgraph based on the input nodes id + and returns restof subgraph of a given model. + """ + + def __init__( + self, mrvl_layers_consecutive=None, rest_of_subgraph_inputs_en_id=None, debug=False + ): + ExprMutator.__init__(self) + self._debug = debug + self._first_function_visit = True + self._mrvl_layers_consecutive = mrvl_layers_consecutive + self._inputs_restof_subgraph_en_id = rest_of_subgraph_inputs_en_id + self._outputs_call = None + self._inputs = [] + self._inputs_call_names = {} + + def dump_debug_text_info(self, n, label): + astext_list = n.astext(False).splitlines() + if astext_list[-1:][0] in [" */"]: + str_list = astext_list[-5:-4][0].split(") /*") + else: + str_list = astext_list[-1:][0].split(") /*") + print("{}: {})".format(label, str_list[0]), flush=True) + + def visit_function(self, fn): + """override base class ExprMutator's visit function""" + # if self._compute_mrvl_color: params_mrvl_color = True + new_params = [] + save_first_func = None + if self._first_function_visit and len(self._mrvl_layers_consecutive) > 0: + self._first_function_visit = False + save_first_func = fn + else: + for x in fn.params: + new_param = self.visit(x) + new_params.append(new_param) + new_body = self.visit(fn.body) + if save_first_func: + assert new_body + self._outputs_call = new_body + return None + new_fn = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs) + return new_fn + + def visit_let(self, let): + """override base class ExprMutator's visit function""" + new_var = self.visit(let.var) + new_value = self.visit(let.value) + new_body = self.visit(let.body) + if (new_var is None) or (new_value is None) or (new_body is None): + if self._debug: + self.dump_debug_text_info(let, "drop let") + return None + new_let = Let(new_var, new_value, new_body) + return new_let + + def visit_call(self, call): + """override base class ExprMutator's visit function""" + old_call_en_id = tvm.relay._ffi_api.get_en_id(call) + name = None + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + else: + name = call.op.name + new_fn = self.visit(call.op) + new_args = [] + for idx, arg in enumerate(call.args): + new_arg = self.visit(arg) + new_args.append(new_arg) + if new_arg: + if ( + isinstance(new_arg, Var) + and old_call_en_id in self._inputs_restof_subgraph_en_id + ): + assert name not in self._mrvl_layers_consecutive + if new_arg.name_hint not in self._inputs_call_names: + self._inputs.append(new_arg) + self._inputs_call_names[new_arg.name_hint] = new_arg + else: + if old_call_en_id in self._inputs_restof_subgraph_en_id: + assert name not in self._mrvl_layers_consecutive + if arg.op.name_hint not in self._inputs_call_names: + var = Var(arg.op.name_hint, arg.op.checked_type.ret_type) + self._inputs.append(var) + self._inputs_call_names[var.name_hint] = var + new_args[idx] = var + else: + new_args[idx] = self._inputs_call_names[arg.op.name_hint] + if name in self._mrvl_layers_consecutive: + if self._debug: + self.dump_debug_text_info(call, "drop call") + return None + return Call(new_fn, new_args, call.attrs, call.type_args, call.span) + + def visit_if(self, ite): + """override base class ExprMutator's visit function""" + new_cond = self.visit(ite.cond) + new_true_branch = self.visit(ite.true_branch) + new_false_branch = self.visit(ite.false_branch) + if (new_cond is None) or (new_true_branch is None) or (new_false_branch is None): + if self._debug: + self.dump_debug_text_info(ite, "drop ite") + return None + new_if = If(new_cond, new_true_branch, new_false_branch) + return new_if + + def visit_tuple(self, tup): + """override base class ExprMutator's visit function""" + fields_has_none = False + new_fields = [] + old_tup_en_id = tvm.relay._ffi_api.get_en_id(tup) + for idx, field in enumerate(tup.fields): + new_field = self.visit(field) + new_fields.append(new_field) + if new_field: + continue + if old_tup_en_id in self._inputs_restof_subgraph_en_id: + if field.op.name_hint not in self._inputs_call_names: + var = Var(field.op.name_hint, field.op.checked_type.ret_type) + self._inputs.append(var) + self._inputs_call_names[var.name_hint] = var + new_fields[idx] = var + else: + new_fields[idx] = self._inputs_call_names[field.op.name_hint] + else: + fields_has_none = True + if fields_has_none: + if self._debug: + self.dump_debug_text_info(tup, "drop tup") + return None + new_tup = Tuple(new_fields, tup.span) + return new_tup + + def get_restof_subgraph(self, main_func): + """return rest of subgraph""" + new_main_func = self.visit(main_func) + if new_main_func is not None: + if self._debug: + print("return new_main_func: {})".format(new_main_func.astext(False)), flush=True) + return new_main_func + # we need to instantiate a new output or output tuple + if self._debug: + print( + "got new_main_func==None and need to construct a tuple outputs (size={})".format( + len(self._outputs_call) + ), + flush=True, + ) + assert self._outputs_call + new_main_func = Function(self._inputs, self._outputs_call) + if self._debug: + print("new main func generated") + return new_main_func + + +class MrvlLayersGetOutputs(ExprVisitor): + """ FIXME """ + + def __init__(self, mrvl_consecutive_layers, mrvl_layers_to_defuse, debug=False): + """ FIXME """ + ExprVisitor.__init__(self) + self._debug = debug + self._mrvl_consecutive_layers = mrvl_consecutive_layers + self._mrvl_layers_to_defuse = mrvl_layers_to_defuse + self._outputs = {} + + def dump_debug_text_info(self, n, label): + """ FIXME """ + astext_list = n.astext(False).splitlines() + if astext_list[-1:][0] in [" */"]: + str_list = astext_list[-5:-4][0].split(") /*") + else: + str_list = astext_list[-1:][0].split(") /*") + print("{}: {})".format(label, str_list[0]), flush=True) + + def add_to_outputs_if_consecutive_mrvl_layer(self, n): + """ FIXME """ + # tuple is not a consecutive Mrvl layer + if ( + isinstance(n, Call) + and isinstance(n.op, GlobalVar) + and (n.op.name_hint in self._mrvl_consecutive_layers) + ): + self._outputs[n.op.name_hint] = True + if self._debug: + print("add outputs: {}".format(n.op.name_hint), flush=True) + + def visit_call(self, call): + """ FIXME """ + if self._debug: + self.dump_debug_text_info(call, "call") + call_mrvl_color_false = True + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + assert (name in self._mrvl_consecutive_layers) or (name in self._mrvl_layers_to_defuse) + if self._debug: + print("mrvl_layer: {}".format(name), flush=True) + if name in self._mrvl_consecutive_layers: + call_mrvl_color_false = False + else: + assert isinstance(call.op, tvm.ir.Op) + name = call.op.name + if self._debug: + print("non-mrvl-layer: {}".format(name), flush=True) + + if call_mrvl_color_false: + for arg in call.args: + if self._debug: + self.dump_debug_text_info(arg, "arg") + # add all consecutive Mrvl layers to outputs + self.add_to_outputs_if_consecutive_mrvl_layer(arg) + + super().visit_call(call) + + def visit_tuple(self, tup): + """ FIXME """ + # tuple is not a consecutive Mrvl layer + for field in tup.fields: + if self._debug: + self.dump_debug_text_info(field, "field") + # add all consecutive Mrvl layers to outputs + self.add_to_outputs_if_consecutive_mrvl_layer(field) + + super().visit_tuple(tup) + + def visit_tuple_getitem(self, t): + """ FIXME """ + # tuple_getitem is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(t.tuple_value) + super().visit_tuple_getitem(t) + + def visit_let(self, let): + """ FIXME """ + # let is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(let.var) + self.add_to_outputs_if_consecutive_mrvl_layer(let.value) + self.add_to_outputs_if_consecutive_mrvl_layer(let.body) + super().visit_let(let) + + def visit_if(self, i): + """ FIXME """ + # if is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(i.cond) + self.add_to_outputs_if_consecutive_mrvl_layer(i.true_branch) + self.add_to_outputs_if_consecutive_mrvl_layer(i.false_branch) + super().visit_if(i) + + def visit_ref_create(self, r): + """ FIXME """ + # ref_create is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(r.value) + super().visit_ref_create(r) + + def visit_ref_read(self, r): + """ FIXME """ + # ref_read is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(r.ref) + super().visit_ref_read(r) + + def visit_ref_write(self, r): + """ FIXME """ + # ref_ref_write is not a consecutive Mrvl layer + self.add_to_outputs_if_consecutive_mrvl_layer(t.tuple_value) + super().visit_ref_write(r) + + def run(self, main_func): + """ FIXME """ + self.visit(main_func) + # at least one output + outputs_keys = self._outputs.keys() + # in a model containing all Mrvl layers, this can be [] + assert len(outputs_keys) >= 0 + return outputs_keys + + +# TODO(ccjoechou): Need to find all the possible cut points so that many corner +# cases can be identifed and fixed. +class RestMrvlLayersGetInputs(ExprVisitor): + """ FIXME """ + + def __init__(self, mrvl_consecutive_layers, mrvl_layers_to_defuse, debug=False): + """ FIXME """ + ExprVisitor.__init__(self) + self._debug = debug + self._mrvl_consecutive_layers = mrvl_consecutive_layers + self._mrvl_layers_to_defuse = mrvl_layers_to_defuse + self._inputs = {} + + def dump_debug_text_info(self, n, label): + """ FIXME """ + astext_list = n.astext(False).splitlines() + if astext_list[-1:][0] in [" */"]: + str_list = astext_list[-5:-4][0].split(") /*") + else: + str_list = astext_list[-1:][0].split(") /*") + print("{}: {})".format(label, str_list[0]), flush=True) + + def add_to_inputs_if_not_consecutive_mrvl_layer(self, n): + """ FIXME """ + # tuple is not a consecutive Mrvl layer + callnode_name = self.get_callnode_name(n) + if callnode_name is None: + return + en_id = tvm.relay._ffi_api.get_en_id(n) + if en_id not in self._inputs: + self._inputs[en_id] = callnode_name + if self._debug: + print("add inputs: {}".format(callnode_name), flush=True) + + def get_callnode_name(self, call): + """ FIXME """ + if isinstance(call, Call): + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + if self._debug: + print("layer: {}".format(name), flush=True) + else: + assert isinstance(call.op, tvm.ir.Op) + name = call.op.name + if self._debug: + print("non-mrvl-layer: {}".format(name), flush=True) + elif isinstance(call, Tuple): + name = "Tup_node" + else: + name = None + return name + + def visit_call(self, call): + """ FIXME """ + call_mrvl_color_false = False + callnode_name = self.get_callnode_name(call) + if callnode_name in self._mrvl_consecutive_layers: + return + for arg in call.args: + if isinstance(arg, Var): + call_mrvl_color_false = True + # This callnode has a direct var input + break + arg_name = self.get_callnode_name(arg) + if arg_name in self._mrvl_consecutive_layers: + call_mrvl_color_false = True + break + if call_mrvl_color_false: + self.add_to_inputs_if_not_consecutive_mrvl_layer(call) + super().visit_call(call) + + def visit_tuple(self, tup): + """ FIXME """ + # tuple is not a consecutive Mrvl layer + call_mrvl_color_false = False + for field in tup.fields: + # add all consecutive Mrvl layers to outputs + if isinstance(field, Var): + # This callnode has a direct var input + call_mrvl_color_false = True + break + arg_name = self.get_callnode_name(field) + if arg_name in self._mrvl_consecutive_layers: + call_mrvl_color_false = True + break + if call_mrvl_color_false: + self.add_to_inputs_if_not_consecutive_mrvl_layer(tup) + super().visit_tuple(tup) + + def visit_tuple_getitem(self, t): + """ FIXME """ + # tuple_getitem is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(t) + super().visit_tuple_getitem(t) + + def visit_let(self, let): + """ FIXME """ + # let is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(let.var) + self.add_to_inputs_if_not_consecutive_mrvl_layer(let.value) + self.add_to_inputs_if_not_consecutive_mrvl_layer(let.body) + super().visit_let(let) + + def visit_if(self, i): + """ FIXME """ + # if is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(i.cond) + self.add_to_inputs_if_not_consecutive_mrvl_layer(i.true_branch) + self.add_to_inputs_if_not_consecutive_mrvl_layer(i.false_branch) + super().visit_if(i) + + def visit_ref_create(self, r): + """ FIXME """ + # ref_create is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(r.value) + super().visit_ref_create(r) + + def visit_ref_read(self, r): + """ FIXME """ + # ref_read is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(r.ref) + super().visit_ref_read(r) + + def visit_ref_write(self, r): + """ FIXME """ + # ref_ref_write is not a consecutive Mrvl layer + self.add_to_inputs_if_not_consecutive_mrvl_layer(r.tuple_value) + super().visit_ref_write(r) + + def run(self, main_func): + """ FIXME """ + self.visit(main_func) + # at least one output + inputs_en_id = self._inputs.keys() + if self._debug: + for key, value in self._inputs.items(): + print("{}.{}".format(value, key)) + # in a model containing all Mrvl layers, this can be [] + return inputs_en_id + + +class MrvlSubgraphToRevert(ExprMutator): + """Reverts subgraphs, which are listed in the subgraphs_to_revert list, + back to TVM operators instead of using an external codegen (in Mrvl layers). + """ + + def __init__(self, subgraphs_to_revert, mod): + ExprMutator.__init__(self) + self._subgraphs_to_revert = subgraphs_to_revert + self._mod = mod + + def visit_call(self, call): + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + if name in self._subgraphs_to_revert: + # "Inline" the subgraph back into new main function. + func = self._mod[name] + var_map = {} + for arg, param in zip(call.args, func.params): + var_map[param] = super().visit(arg) + new_body = relay.bind(func.body, var_map) + # return the original TVM function body, instead of Mrvl Layer + return new_body + # if call is not "def @main(...) { ... }" + if name != "main": + args = [] + for arg in call.args: + args.append(super().visit(arg)) + return call.op(*args) + return super().visit_call(call) + + +def revert_mrvl_mod_to_orig(mod_mrvl, mrvl_layers_in_mrvl_subgraph, debug=False): + """revert Mrvl subgraph mod and return its (original) TVM IR and params + mod_mrvl: Mrvl subgraph IR with parameters - in fused Mrvl layers + mrvl_layers_in_mrvl_subgraph: list of Mrvl layer composite function names, which + are going to be reverted + """ + + def run_opt_pass(mod, passes): + passes = passes if isinstance(passes, list) else [passes] + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + return mod + + if debug: + print("Debug: mod_mrvl:\n{}\n\n)".format(mod_mrvl.astext(False)), flush=True) + mod_new = tvm.IRModule(mod_mrvl.functions, mod_mrvl.type_definitions) + mod_new["main"] = MrvlSubgraphToRevert(mrvl_layers_in_mrvl_subgraph, mod_mrvl).visit( + mod_mrvl["main"] + ) + mod_new = relay.transform.RemoveUnusedFunctions()(mod_new) + mod_new = relay.transform.InferType()(mod_new) + if debug: + print("Debug: mod_new (defused level1):\n{}\n\n)".format(mod_new.astext(False)), flush=True) + + mod_new = run_opt_pass(mod_new, relay.transform.DefuseOps()) + if debug: + print("Debug: mod_new (defused level2):\n{}\n\n)".format(mod_new.astext(False)), flush=True) + + # need to reset back to use the default FTVConvertOpLayout function + mod_new = run_opt_pass( + mod_new, + relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"], "nn.max_pool2d": ["NCHW"]}), + ) + if debug: + print( + "Debug: mod_new (convert layout):\n{}\n\n)".format(mod_new.astext(False)), + flush=True, + ) + + mod_new = run_opt_pass(mod_new, relay.transform.SimplifyExpr()) + if debug: + print("Debug: mod_new (simplified):\n{}\n\n)".format(mod_new.astext(False)), flush=True) + mod_new = run_opt_pass(mod_new, relay.transform._ffi_api.DropNoopTranspose()) + if debug: + print( + "Debug: mod_new (drop noop transpose):\n{}\n\n)".format(mod_new.astext(False)), + flush=True, + ) + mod_new = run_opt_pass(mod_new, relay.transform.InferType()) + if debug: + print("Debug: mod_new (infertype):\n{}\n\n)".format(mod_new.astext(False)), flush=True) + return mod_new + + +class MrvlIRGraphUtils: + """Mrvl IR graph analysis utilities""" + + def __init__(self, debug=False): + self._debug = debug + + def get_mrvl_layers_and_main_func(self, mod): + """get all mrvl layers (which are annotated mrvl target sub graphs)""" + main_func = None + mrvl_layers = [] + for annotated_subgraph in mod.get_global_vars(): + name = annotated_subgraph.name_hint + if name in ["main"]: + main_func = mod[name] + if (not mod[name].attrs) or (mod[name].attrs["Compiler"] != "mrvl"): + continue + mrvl_layers.append(name) + return mrvl_layers, main_func + + def dump_main_func(self, mod, prefix_str="mod[main]"): + """dump def @main of mod""" + main_func = None + for func_var in mod.get_global_vars(): + name = func_var.name_hint + if name in ["main"]: + main_func = mod[name] + assert main_func + print("{} => {}".format(prefix_str, main_func.astext(False)), flush=True) + + def compute_two_subgraphs( + self, mod, defuse_mrvl_layers_list=None, gen_non_mrvl_subgraph=False, flow_pass=1 + ): + """produce a Mrvl-layer sub graph and a graph where non-consecutive Mrvl layers + are de-fused back to TVM operators + """ + + # find call.op names for Mrvl layers + mrvl_layers, main_func = self.get_mrvl_layers_and_main_func(mod) + assert main_func + + # find consecutive Mrvl layers and Mrvl layers, which need to be defused + mutator = MrvlLayers( + mutate_style="compute-mrvl-color", + mrvl_layer_names=mrvl_layers, + defuse_mrvl_layers_list=defuse_mrvl_layers_list, + debug=self._debug, + ) + mod["main"] = mutator.compute_main_func_mrvl_color(main_func) + mrvl_layers_consecutive, mrvl_layers_to_defuse = mutator.get_consecutive_layers() + mrvl_layers_consecutive_keys = mrvl_layers_consecutive.keys() + mrvl_layers_to_defuse_keys = mrvl_layers_to_defuse.keys() + if self._debug: + print("\nDebug: flow_pass: {}".format(flow_pass), flush=True) + print( + "\nDebug: to {} mrvl layers - {})".format(len(mrvl_layers), mrvl_layers), flush=True + ) + print( + "\nDebug: to {} mrvl consecutive layers - {})".format( + len(mrvl_layers_consecutive_keys), mrvl_layers_consecutive_keys + ), + flush=True, + ) + print( + "\nDebug: to {} mrvl to defuse layers - {})".format( + len(mrvl_layers_to_defuse_keys), mrvl_layers_to_defuse_keys + ), + flush=True, + ) + print( + "\nDebug: given defuse_mrvl_layers_list - {})".format(defuse_mrvl_layers_list), + flush=True, + ) + assert len(mrvl_layers) == ( + len(mrvl_layers_consecutive_keys) + len(mrvl_layers_to_defuse_keys) + ) + + # FIXME: do post-order DFS traverse analysis based on the value of the !mrvl_color attribute + # to decide whether to exclude non-Mrvl-layer operators from IR graph + # - ran into TVM error: Check failed: (n.defined()) is false: Found null + # pointer node while traversing AST. The previous pass may have generated invalid data + # figure out outputs, which is a list of Mrvl layers + mrvl_layers_outputs = MrvlLayersGetOutputs( + mrvl_layers_consecutive, mrvl_layers_to_defuse, debug=self._debug + ).run(mod["main"]) + # generate a subgraph for consecutive Mrvl layers + mod_mrvl = tvm.IRModule(mod.functions, mod.type_definitions) + mutator2 = MrvlLayers( + mutate_style="get-mrvl-subgraph", + mrvl_layers_consecutive=mrvl_layers_consecutive, + mrvl_layers_outputs=mrvl_layers_outputs, + debug=self._debug, + ) + # print("type(mutator2): {}".format(type(mutator2).__name__), flush=True) + mod_mrvl["main"] = mutator2.get_main_func_mrvl_subgraph(mod["main"]) + mod_mrvl = relay.transform.InferType()(mod_mrvl) + if self._debug: + print("Debug: mod_mrvl: {})".format(mod_mrvl.astext(False)), flush=True) + rest_of_subgraph_inputs_en_id = RestMrvlLayersGetInputs( + mrvl_layers_consecutive, mrvl_layers_to_defuse, debug=self._debug + ).run(mod["main"]) + flag_flowpass1 = ( + gen_non_mrvl_subgraph and flow_pass == 1 and len(rest_of_subgraph_inputs_en_id) > 0 + ) + flag_flowpass2 = flow_pass == 2 and len(rest_of_subgraph_inputs_en_id) > 0 + if flag_flowpass1 or flag_flowpass2: + mod_restofsubgraph = mod + mutator3 = RestOfMrvlLayers( + mrvl_layers_consecutive=mrvl_layers_consecutive, + rest_of_subgraph_inputs_en_id=rest_of_subgraph_inputs_en_id, + debug=self._debug, + ) + # print("type(mutator3): {}".format(type(mutator3).__name__), flush=True) + mod_restofsubgraph["main"] = mutator3.get_restof_subgraph(mod["main"]) + if len(mrvl_layers_to_defuse_keys) > 0: + # revert Mrvl layers, which are not in consecutive Mrvl layers, + # back as in-line functions + mod_restofsubgraph = revert_mrvl_mod_to_orig( + mod_restofsubgraph, mrvl_layers_to_defuse_keys + ) + else: + mod_restofsubgraph = None + return mod_mrvl, mod_restofsubgraph, mrvl_layers_consecutive_keys diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index 24d9061a213f..9c31053554f3 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -81,6 +81,8 @@ use-arm-compute-lib-graph-runtime = ["tvm-sys/use-arm-compute-lib-graph-runtime" use-tensorrt-codegen = ["tvm-sys/use-tensorrt-codegen"] use-tensorrt-runtime = ["tvm-sys/use-tensorrt-runtime"] use-vitis-ai = ["tvm-sys/use-vitis-ai"] +use-mrvl = ["tvm-sys/use-mrvl"] +use-mrvl-runtime = ["tvm-sys/use-mrvl-runtime"] build-static-runtime = ["tvm-sys/build-static-runtime"] [dependencies] diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index 4494e20afa31..dbd989548f81 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -74,6 +74,8 @@ use-arm-compute-lib-graph-runtime = [] use-tensorrt-codegen = [] use-tensorrt-runtime = [] use-vitis-ai = [] +use-mrvl = [] +use-mrvl-runtime = [] build-static-runtime = [] [dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 80c7efbaf894..ace0b18f85df 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -198,6 +198,16 @@ fn find_using_tvm_build() -> Result { if cfg!(feature = "use-vitis-ai") { build_config.settings.use_vitis_ai = Some(true); } + if cfg!(feature = "use-mrvl") { + // FIXME: do we need to register use_mrvl for the tvm-build v0.24 package + // on crates.io first? + build_config.settings.use_mrvl = CMakeSetting::from_str("on").ok(); + } + if cfg!(feature = "use-mrvl-runtime") { + // FIXME: do we need to register use_mrvl_runtime for the tvm-build v0.24 package + // on crates.io first? + build_config.settings.use_mrvl_runtime = CMakeSetting::from_str("on").ok(); + } if cfg!(any( feature = "static-linking", feature = "build-static-runtime" diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 8d9b23f7616b..8def271a1e18 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -80,6 +80,8 @@ use-arm-compute-lib-graph-runtime = ["tvm-rt/use-arm-compute-lib-graph-runtime"] use-tensorrt-codegen = ["tvm-rt/use-tensorrt-codegen"] use-tensorrt-runtime = ["tvm-rt/use-tensorrt-runtime"] use-vitis-ai = ["tvm-rt/use-vitis-ai"] +use-mrvl = ["tvm-rt/use-mrvl"] +use-mrvl-runtime = ["tvm-rt/use-mrvl-runtime"] [dependencies.tvm-rt] version = "0.1.0-alpha" diff --git a/src/ir/op.cc b/src/ir/op.cc index fac15a7daad4..c222bb7426fd 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include "../node/attr_registry.h" @@ -40,6 +41,21 @@ using tir::FLowerIntrinsic; using OpRegistry = AttrRegistry; +void Op::ListAllOpNames() { + auto names = OpRegistry::Global()->ListAllNames(); + std::for_each(names.begin(), names.end(), + [](String const& name) { LOG(INFO) << "op name: " << name; }); +} + +String Op::GetOpName(const Op& op) { + auto names = OpRegistry::Global()->ListAllNames(); + String op_name = ""; + std::for_each(names.begin(), names.end(), [&op_name, op](String const& name) { + if (op == Op::Get(name)) op_name = String(name); + }); + return op_name; +} + // find operator by name const Op& Op::Get(const String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ccfd30476f67..532fe8b2f7f0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -54,6 +54,7 @@ using namespace tvm::relay::transform; */ struct BuildOutput { std::string graph_json; + std::string external_graph_json{""}; runtime::Module mod; std::unordered_map params; }; @@ -141,10 +142,17 @@ struct GraphCodegen : ExecutorCodegen { auto pf = GetPackedFunc("relay.build_module._GraphExecutorCodegen"); mod = (*pf)(); } - void UpdateOutput(BuildOutput* ret) override { ret->graph_json = GetGraphJSON(); } + void UpdateOutput(BuildOutput* ret) override { + ret->graph_json = GetGraphJSON(); + ret->external_graph_json = GetExternalGraphJSON(); + } std::string GetGraphJSON() { return CallFunc("get_graph_json", nullptr); } + std::string GetExternalGraphJSON() { + return CallFunc("get_external_graph_json", nullptr); + } + ~GraphCodegen() {} }; @@ -181,6 +189,10 @@ class RelayBuildModule : public runtime::ModuleNode { if (name == "get_graph_json") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); + } else if (name == "get_external_graph_json") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetExternalGraphJSON(); + }); } else if (name == "get_module") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); @@ -236,6 +248,13 @@ class RelayBuildModule : public runtime::ModuleNode { */ const std::string& GetGraphJSON() { return ret_.graph_json; } + /*! + * \brief Get the ExternalGraphJSON for runtime + * + * \return const std::string externl_graph_json + */ + const std::string& GetExternalGraphJSON() { return ret_.external_graph_json; } + /*! * \brief Get the Module object * @@ -470,6 +489,13 @@ class RelayBuildModule : public runtime::ModuleNode { auto ext_mods = executor_codegen_->GetExternalModules(); ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, runtime_, executor_codegen_->GetMetadata()); + +// TODO(ccjoechou): to use the external codegen's metadata flow, we will need to support the +// export_library() call to generate a lib.so, which includes the metadata; and then, we +// will also need to support the load_module() call to load the generated lib.so +// - since we (MRVL) are not supporting both export_library() and load_module() steps, we +// have to keep ret_.params to store constant params as before +#ifndef TVM_USE_MRVL // Remove external params which were stored in metadata module. for (tvm::runtime::Module mod : ext_mods) { auto pf_var = mod.GetFunction("get_const_vars"); @@ -483,6 +509,7 @@ class RelayBuildModule : public runtime::ModuleNode { } } } +#endif } protected: diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 4966f3f01c7d..86507a254032 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -173,6 +173,9 @@ class JSONSerializer : public MemoizedExprTranslator GetNodes() const { return nodes_; } + protected: /*! * \brief Add a node to graph. diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc new file mode 100644 index 000000000000..fe7a5f16dfc9 --- /dev/null +++ b/src/relay/backend/contrib/mrvl/codegen.cc @@ -0,0 +1,890 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/mrvl/codegen.cc + * \brief Marvell MLIP specific API + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "picojson.h" + +#define USE_JSON_RUNTIME 1 +#ifdef USE_JSON_RUNTIME + +#include "../../../../runtime/contrib/json/json_node.h" +#include "../../../qnn/utils.h" +#include "../../utils.h" +#include "../codegen_json/codegen_json.h" + +#else + +// TODO(ccjoechou): TBA if needed -- follow "per layer" C-codegen example + +#endif + +namespace tvm { +namespace relay { +namespace contrib { +namespace mrvl { + +using namespace backend; + +extern "C" bool g_mrvlExtJsonObjInstantized; + +extern "C" void InstantiateMrvlExtJsonObj(); + +#ifndef USE_JSON_RUNTIME + +// TODO(ccjoechou): TBA if needed -- follow "per layer" C-codegen example + +#else + +/*! + * \brief Generates an MrvlModule from a relay expression. This "compilation" + * does not require Mrvl driver since the actual conversion using Mrvl APIs is + * deferred until creation of the runtime. This step simply serializes the + * relay program into a JSON string. + */ +class MrvlJSONSerializer : public backend::contrib::JSONSerializer { + using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; + + public: + MrvlJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) { + func_name_ = symbol; + } + + /*! + * \brief Struct to capture original frontend NN model's + * first/last operator names for fused Mrvl layer + */ + struct FrontendOpNames { + std::string first_op_name = "unknown"; + std::string last_op_name = "unknown"; + }; + + /*! + * \brief A series of operators that form a composite + * convolution. Supports both nn.conv2d and qnn.conv2d. + */ + struct CompositeConvNode { + const CallNode* pad = nullptr; + const CallNode* conv = nullptr; + const CallNode* add = nullptr; + const CallNode* batch_norm = nullptr; + const CallNode* activation = nullptr; + FrontendOpNames op_names; + }; + + /*! + * \brief A series of operators that form a composite + * convolution. Supports sum2d + */ + struct CompositeSum2DNode { + const CallNode* add = nullptr; + const CallNode* activation = nullptr; + FrontendOpNames op_names; + }; + + /*! + * \brief A series of operators that form a composite + * maxpool or avgpool. Supports both nn.max_pool2d and qnn.conv2d. + */ + struct CompositePoolNode { + const CallNode* pad = nullptr; + const CallNode* pool = nullptr; + FrontendOpNames op_names; + }; + + /*! + * \brief A series of operators that form a composite + * fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no. + */ + struct CompositeFcNode { + const CallNode* fc = nullptr; + const CallNode* add = nullptr; + const CallNode* activation = nullptr; + FrontendOpNames op_names; + }; + + /*! + * \brief Visit call nodes and generate appropriate JSON node. + * + * \param cn The current call node. + * \return A list of graph entry nodes. + */ + std::vector VisitExpr_(const CallNode* cn) override { + const auto* op_node = cn->op.as(); + if (op_node) { + // handle certain op node types specially + String op_name = tvm::Op::GetOpName(GetRef(op_node)); + bool handle_by_mrvl = (op_name == "reshape") || (op_name == "layout_transform") || + (op_name == "nn.batch_flatten") || (op_name == "transpose"); + if (!handle_by_mrvl) { + return JSONSerializer::VisitExpr_(cn); + } + + // setup json attributes and then add the Mrvl Layer to JSON files + std::shared_ptr json_node; + json_node = CreateMrvlLayer4OpNode(cn); + return AddNode(json_node, GetRef(cn)); + } + + // handle only mrvl composite functions + if (!cn->op.as()) { + LOG(FATAL) << "Mrvl JSON runtime does not support calls to " << cn->op->GetTypeKey(); + } + auto fn = cn->op.as(); + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()) << "Mrvl JSON runtime only supports composite functions."; + const std::string name = comp.value(); + std::shared_ptr json_node; + if (name == "mrvl.conv2d_nhwc2nhwc") { + json_node = CreateCompositeMrvlConv2DLayer(cn); + } else if (name == "mrvl.fc_ni2no") { + json_node = CreateCompositeMrvlFcLayer(cn); + } else if (name == "mrvl.maxpool2d_nhwc2nhwc") { + json_node = CreateCompositeMrvlMaxpool2DLayer(cn); + } else if (name == "mrvl.avgpool2d_nhwc2nhwc") { + json_node = CreateCompositeMrvlAvgpool2DLayer(cn); + } else if (name == "mrvl.sum2d") { + json_node = CreateCompositeMrvlSum2DLayer(cn); + } else { + LOG(FATAL) << "Unrecognized Mrvl pattern: " << name; + } + // calling codegen_json.h::AddNode() + return AddNode(json_node, GetRef(cn)); + } + + private: + std::string func_name_; + + void JsonNodeSetAttr(std::shared_ptr json_node, const std::string& key, + const std::vector& string_vec) { + std::vector json_attr; + json_attr.emplace_back(string_vec); + json_node->SetAttr(key, json_attr); + } + + void JsonNodeSetVecAttr(std::shared_ptr json_node, const std::string& key, + const std::vector& tvec) { + size_t tvec_size = tvec.size(); + std::vector tvec_str; + if (tvec_size == 4) { + tvec_str = {std::to_string(tvec[0]), std::to_string(tvec[1]), std::to_string(tvec[2]), + std::to_string(tvec[3])}; + } else if (tvec_size == 3) { + tvec_str = {std::to_string(tvec[0]), std::to_string(tvec[1]), std::to_string(tvec[2])}; + } else if (tvec_size == 2) { + tvec_str = {std::to_string(tvec[0]), std::to_string(tvec[1])}; + } else if (tvec_size == 1) { + tvec_str = {std::to_string(tvec[0])}; + } else { + LOG(INFO) << "Vector size (" << tvec_size << ") is not supported."; + } + std::vector json_attr; + json_attr.emplace_back(tvec_str); + json_node->SetAttr(key, json_attr); + } + + void setMrvlLayerCommonAttrs(std::shared_ptr json_node, const CallNode* cn, + const std::string& func_name, const std::string& mrvlLayerName, + const std::string& data_layout, const std::string& kernel_layout, + const std::string& out_layout) { + // MUST use the JSONGraphAttrs attrs_ style + // as described in tvm/src/relay/contrib/json/json_node.h + + // add other mrvl-specific attributes + JsonNodeSetAttr(json_node, "layer_name", {mrvlLayerName}); + JsonNodeSetAttr(json_node, "func_node_name", {func_name}); + std::vector data_layout_vec; + GetInputTensorShapeViaArg0(cn, &data_layout_vec); + JsonNodeSetVecAttr(json_node, "data_layout_shape", data_layout_vec); + std::vector out_layout_vec; + GetOutputTensorShape(cn, &out_layout_vec); + JsonNodeSetVecAttr(json_node, "out_layout_shape", out_layout_vec); + if (data_layout != "") { + std::vector data_layout_format_vec = {data_layout}; + JsonNodeSetAttr(json_node, "data_layout", data_layout_format_vec); + } + if (kernel_layout != "") { + std::vector kernel_layout_format_vec = {kernel_layout}; + JsonNodeSetAttr(json_node, "kernel_layout", kernel_layout_format_vec); + } + if (out_layout != "") { + std::vector out_layout_format_vec = {out_layout}; + JsonNodeSetAttr(json_node, "out_layout", out_layout_format_vec); + } + } + + void SetMrvlSpecificJsonNodeAttrs(std::shared_ptr json_node, const CallNode* cn, + const CallNode* cn_pad, const CallNode* cn_pool, + const CallNode* cn_conv, const CallNode* cn_fc, + const CallNode* cn_add, const CallNode* cn_batch_norm, + const CallNode* cn_activation, const std::string& mrvlLayerName, + const std::string& data_layout, + const std::string& kernel_layout, const std::string& out_layout, + const std::string& bias_layout, + const std::string& activation_op, + const FrontendOpNames& op_names) { + // MUST use the JSONGraphAttrs attrs_ style + // as described in tvm/src/relay/contrib/json/json_node.h + setMrvlLayerCommonAttrs(json_node, cn, func_name_, mrvlLayerName, data_layout, kernel_layout, + out_layout); + // + if (cn_conv || cn_fc) { + std::vector kernel_const_name = {func_name_ + "_const_0"}; + JsonNodeSetAttr(json_node, "kernel_const_name", kernel_const_name); + } + // + if (cn_add) { + if (mrvlLayerName == "Sum2D") { + // FIXME: any specific attributes to add here for Sum2D? + JsonNodeSetAttr(json_node, "out_layout", {out_layout}); + } else { + std::vector bias_const_name = {func_name_ + "_const_1"}; + JsonNodeSetAttr(json_node, "bias_const_name", bias_const_name); + JsonNodeSetAttr(json_node, "bias_layout", {bias_layout}); + } + } + if (cn_batch_norm) { + std::string gamma_const_name_postfix; + std::string beta_const_name_postfix; + std::string mean_const_name_postfix; + std::string var_const_name_postfix; + if (cn_add) { + gamma_const_name_postfix = "_const_2"; + beta_const_name_postfix = "_const_3"; + mean_const_name_postfix = "_const_4"; + var_const_name_postfix = "_const_5"; + } else { + gamma_const_name_postfix = "_const_1"; + beta_const_name_postfix = "_const_2"; + mean_const_name_postfix = "_const_3"; + var_const_name_postfix = "_const_4"; + } + std::string batch_norm_layout = "-O"; + std::vector gamma_const_name = {func_name_ + gamma_const_name_postfix}; + JsonNodeSetAttr(json_node, "gamma_const_name", gamma_const_name); + JsonNodeSetAttr(json_node, "gamma_layout", {batch_norm_layout}); + std::vector beta_const_name = {func_name_ + beta_const_name_postfix}; + JsonNodeSetAttr(json_node, "beta_const_name", beta_const_name); + JsonNodeSetAttr(json_node, "beta_layout", {batch_norm_layout}); + std::vector mean_const_name = {func_name_ + mean_const_name_postfix}; + JsonNodeSetAttr(json_node, "mean_const_name", mean_const_name); + JsonNodeSetAttr(json_node, "mean_layout", {batch_norm_layout}); + std::vector var_const_name = {func_name_ + var_const_name_postfix}; + JsonNodeSetAttr(json_node, "var_const_name", var_const_name); + JsonNodeSetAttr(json_node, "var_layout", {batch_norm_layout}); + } + // + if (cn_pool && (mrvlLayerName == "Maxpool2D")) { + auto pool_attrs = cn_pool->attrs.as(); + ICHECK(pool_attrs != nullptr); + std::vector kernel_layout_vec; + kernel_layout_vec.push_back(*(tir::as_const_int(pool_attrs->pool_size[0]))); + kernel_layout_vec.push_back(*(tir::as_const_int(pool_attrs->pool_size[1]))); + JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec); + } + if (cn_pool && (mrvlLayerName == "Avgpool2D")) { + auto pool_attrs = cn_pool->attrs.as(); + ICHECK(pool_attrs != nullptr); + std::vector kernel_layout_vec; + kernel_layout_vec.push_back(*(tir::as_const_int(pool_attrs->pool_size[0]))); + kernel_layout_vec.push_back(*(tir::as_const_int(pool_attrs->pool_size[1]))); + JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec); + } + // + const auto* fn = cn->op.as(); + ICHECK((fn != nullptr) && fn->IsInstance()); + auto composite = fn->GetAttr(attr::kComposite); + ICHECK(composite.defined()); + std::string composite_name = composite.value(); + JsonNodeSetAttr(json_node, "composite_name", {composite_name}); + JsonNodeSetAttr(json_node, "first_op_name", {op_names.first_op_name}); + JsonNodeSetAttr(json_node, "last_op_name", {op_names.last_op_name}); + + // Override attributes, if nn.pad() found + // - for 2D: h * w: h-begin (top), w-begin (left), h-end (bottom), w-end (right) + // -- for Conv-2D op and Pool-2D op + if (cn_pad) { + const auto* pad_attr = cn_pad->attrs.as(); + ICHECK(pad_attr); + auto p = pad_attr->pad_width; + // Convert to TVM layout for now, conversion to Mrvl layout takes place in runtime. + // Standard convolution pad layout for TVM: top, left, bottom, right. + std::vector padding = {std::to_string(p[1][0].as()->value), + std::to_string(p[2][0].as()->value), + std::to_string(p[1][1].as()->value), + std::to_string(p[2][1].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + json_node->SetAttr("padding", padding_attr); + } + // Override attributes + if (cn_activation) { + std::vector activation_type = {activation_op}; + std::vector act_attr; + act_attr.emplace_back(activation_type); + json_node->SetAttr("activation_type", act_attr); + } + } + + /*! + * \brief Extract convolution nodes from a composite function. + * + * \param call The call node of the composite function. + * \return Extracted composite convolution nodes. + */ + static CompositeConvNode UnpackCompositeConvolution(const CallNode* call) { + CompositeConvNode nodes{}; + const auto* fn = call->op.as(); + ICHECK(fn); + // we can see the pattern below (but call graph starts from right most & backward): + // - conv2d + [ bias_add ] + [ batch_norm + tuple.getitem(0) ] + [ relu ] + // + // Thus, we need to handle following cases: + // - case1: conv2d + // - case2: conv2d + relu + // - case3: conv2d + batch_norm + tuple.getitem(0) + // - case4: conv2d + batch_norm + tuple.getitem(0) + relu + // - case5: conv2d + add + // - case6: conv2d + add + relu + // - case7: conv2d + add + batch_norm + tuple.getitem(0) + // - case8: conv2d + add + batch_norm + tuple.getitem(0) + relu + + // Traverse composite convolution function from child to parent + const TupleGetItemNode* tuple_get_item_node = nullptr; + const CallNode* current_call = fn->body.as(); + if (current_call) { + // for case1, case2, case4, case5, case6, case8 + if (backend::IsOp(current_call, "nn.relu")) { + // for case2, case4, case6, case8 + nodes.activation = current_call; + + if (current_call->args[0].as()) { + // fall through for case4, case8 + tuple_get_item_node = current_call->args[0].as(); + } else { + // fall through for case2, case6: to use current_call as CallNode* + current_call = current_call->args[0].as(); + } + } else { + // fall through for case1, case5: to use current_call as CallNode* + ICHECK(current_call); + } + } else { + // for case3, case7 + tuple_get_item_node = fn->body.as(); + } + + // it can be a call node for add op or conv2d op + // OR it can be a TupleGetItem node followed by a batch_norm + if (tuple_get_item_node != nullptr) { + // for case3, case4, case7, case8 + ICHECK(tuple_get_item_node); + ICHECK(tuple_get_item_node->index == 0); + current_call = tuple_get_item_node->tuple.as(); + + ICHECK(backend::IsOp(current_call, "nn.batch_norm")); + nodes.batch_norm = current_call; + current_call = current_call->args[0].as(); + } + + ICHECK(current_call); + if (backend::IsOp(current_call, "add")) { + nodes.add = current_call; + current_call = current_call->args[0].as(); + } + + ICHECK(backend::IsOp(current_call, "nn.conv2d")); + nodes.conv = current_call; + + if (!current_call->args.empty() && current_call->args[0]->IsInstance()) { + current_call = current_call->args[0].as(); + if (backend::IsOp(current_call, "nn.pad")) { + nodes.pad = current_call; + } + } + return nodes; + } + + /*! + * \brief Extract sum2d nodes from a composite function. + * + * \param call The call node of the composite function. + * \return Extracted composite sum2d nodes. + */ + static CompositeSum2DNode UnpackCompositeSum2D(const CallNode* call) { + CompositeSum2DNode nodes{}; + const auto* fn = call->op.as(); + ICHECK(fn); + + // Traverse composite convolution function from child to parent + const auto* current_call = fn->body.as(); + if (backend::IsOp(current_call, "nn.relu")) { + nodes.activation = current_call; + current_call = current_call->args[0].as(); + } + + ICHECK(backend::IsOp(current_call, "add")); + nodes.add = current_call; + + return nodes; + } + + /*! + * \brief Extract maxpool nodes from a composite function. + * + * \param call The call node of the composite function. + * \return Extracted composite maxpool nodes. + */ + static CompositePoolNode UnpackCompositePool(const CallNode* call, + const std::string& mrvlLayerName) { + CompositePoolNode nodes{}; + const auto* fn = call->op.as(); + ICHECK(fn); + + // Traverse composite maxpool function from child to parent + const auto* current_call = fn->body.as(); + if (mrvlLayerName == "Maxpool2D") { + ICHECK(backend::IsOp(current_call, "nn.max_pool2d")); + } else { + ICHECK(mrvlLayerName == "Avgpool2D"); + ICHECK(backend::IsOp(current_call, "nn.avg_pool2d")); + } + nodes.pool = current_call; + if (!current_call->args.empty() && current_call->args[0]->IsInstance()) { + current_call = current_call->args[0].as(); + if (backend::IsOp(current_call, "nn.pad")) { + nodes.pad = current_call; + } + } + return nodes; + } + + void GetInputTensorShapeViaArg0(const CallNode* call_node_ptr, + std::vector* tensor_shape) { + ICHECK(!call_node_ptr->args.empty()); + const TensorTypeNode* tensor_type = nullptr; + if (call_node_ptr->args[0].as()) { + const auto* arg0 = call_node_ptr->args[0].as(); + tensor_type = arg0->checked_type_.as(); + } else if (call_node_ptr->args[0].as()) { + const auto* arg0 = call_node_ptr->args[0].as(); + ICHECK((arg0 != nullptr) && arg0->IsInstance()); + tensor_type = arg0->checked_type_.as(); + } else { + LOG(INFO) << "TVM Mrvl runtime does not support calls to " + << call_node_ptr->args[0]->GetTypeKey(); + } + + ICHECK((tensor_type != nullptr) && tensor_type->IsInstance()); + // use only data types supported by json.h (e.g., int or int64_t or size_t) + for (IndexExpr dim_val : tensor_type->shape) { + tensor_shape->push_back(*(tir::as_const_int(dim_val))); + } + } + + void GetTensorShape(const VarNode* var_node_ptr, std::vector* tensor_shape) { + ICHECK((var_node_ptr != nullptr) && var_node_ptr->IsInstance()); + const TensorTypeNode* tensor_type = var_node_ptr->checked_type_.as(); + ICHECK((tensor_type != nullptr) && tensor_type->IsInstance()); + // use only data types supported by json.h (e.g., int or int64_t or size_t) + for (IndexExpr dim_val : tensor_type->shape) { + tensor_shape->push_back(*(tir::as_const_int(dim_val))); + } + } + + void GetOutputTensorShape(const CallNode* call_node_ptr, std::vector* tensor_shape) { + ICHECK(call_node_ptr != nullptr); + const TensorTypeNode* tensor_type = call_node_ptr->checked_type_.as(); + ICHECK((tensor_type != nullptr) && tensor_type->IsInstance()); + // use only data types supported by json.h (e.g., int or int64_t or size_t) + for (IndexExpr dim_val : tensor_type->shape) { + tensor_shape->push_back(*(tir::as_const_int(dim_val))); + } + } + + /*! + * \brief Create a JSON representation of a composite convolution. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeMrvlConv2DLayer(const CallNode* cn) { + CompositeConvNode nodes = UnpackCompositeConvolution(cn); + const auto* conv_attrs = nodes.conv->attrs.as(); + ICHECK(conv_attrs); + + // Inputs must be added in the same order they appear in the relay graph. + std::vector inputs; + // data input tensor + inputs.push_back(VisitExpr(cn->args[0])[0]); + // weight tensor + inputs.push_back(VisitExpr(nodes.conv->args[1])[0]); + if (nodes.add) { + // bias tensor + inputs.push_back(VisitExpr(nodes.add->args[1])[0]); + } + + // Distinguish between normal and depth-wise convolution + std::string name; + std::string mrvlLayerName = ""; + std::string data_layout = conv_attrs->data_layout; + std::string kernel_layout = conv_attrs->kernel_layout; + std::string out_layout = conv_attrs->out_layout; + int groups = conv_attrs->groups; + if ((groups != 1) && conv_attrs->channels.defined() && + tvm::tir::ExprDeepEqual()(conv_attrs->channels, conv_attrs->groups)) { + name = "nn.dw_conv2d_nhwc2nhwc"; + mrvlLayerName = "DW_Conv2D"; + ICHECK(kernel_layout == "IHWO") + << "Kernel layout must be IHWO, has the module been pre-processed correctly?"; + } else { + name = "nn.conv2d_nhwc2nhwc"; + mrvlLayerName = "Conv2D"; + ICHECK(data_layout == "NHWC") + << "Data layout must be NHWC, has the module been pre-processed correctly?"; + ICHECK(kernel_layout == "OHWI") + << "Kernel layout must be OHWI, has the module been pre-processed correctly?"; + ICHECK(out_layout == "NHWC") + << "Out layout must be NHWC, has the module been pre-processed correctly?"; + } + auto json_node = std::make_shared(name, "kernel", inputs, 1); + // following attributes will be set in json_node: + // - strides, paddings, dilation, + // - groups + // - data_layout, kernel_layout, out_layout + SetCallNodeAttribute(json_node, nodes.conv); + + // add other mrvl-specific attributes + SetMrvlSpecificJsonNodeAttrs( + json_node, cn, nodes.pad, nullptr /* no cn_pool */, nodes.conv, nullptr /* no cn_fc */, + nodes.add, nodes.batch_norm, nodes.activation, mrvlLayerName, + "" /* data_layout given in nodes.conv attrs */, + "" /* kernel_layout given in nodes.conv attrs */, + "" /* out_layout given in nodes.conv attrs */, "---O" /* if node.bias: as bias_layout */, + "relu" /* if node.activation: as activation_op */, nodes.op_names); + + return json_node; + } + + /*! + * \brief Create a JSON representation of a composite sum2d. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeMrvlSum2DLayer(const CallNode* cn) { + CompositeSum2DNode nodes = UnpackCompositeSum2D(cn); + ICHECK(nodes.add != nullptr) << "attribute add can't be nullptr"; + // Inputs must be added in the same order they appear in the relay graph. + std::vector inputs; + // data input tensor 1 + inputs.push_back(VisitExpr(cn->args[0])[0]); + // data input tensor 2 + inputs.push_back(VisitExpr(cn->args[1])[0]); + std::string mrvlLayerName = "Sum2D"; + std::string name = "sum2d"; + auto json_node = std::make_shared(name, "kernel", inputs, 1); + + // add other mrvl-specific attributes + SetMrvlSpecificJsonNodeAttrs( + json_node, cn, nullptr, nullptr /* no cn_pool */, nullptr /* no cn_conv */, + nullptr /* no cn_fc */, nodes.add, nullptr /* no cn_batch_norm */, nodes.activation, + mrvlLayerName, "NHWC" /* data_layout */, "" /* kernel_layout */, "NHWC" /* out_layout */, + "" /* bias_layout */, "relu" /* activation_op */, nodes.op_names); + return json_node; + } + + /*! + * \brief Extract fc nodes from a composite function. + * + * \param call The call node of the composite function. + * \return Extracted composite convolution nodes. + */ + static CompositeFcNode UnpackCompositeFc(const CallNode* call) { + CompositeFcNode nodes{}; + const auto* fn = call->op.as(); + ICHECK(fn); + + // Traverse composite fc function from child to parent + const auto* current_call = fn->body.as(); + if (backend::IsOp(current_call, "nn.relu")) { + nodes.activation = current_call; + current_call = current_call->args[0].as(); + } + if (backend::IsOp(current_call, "add")) { + nodes.add = current_call; + current_call = current_call->args[0].as(); + } + + ICHECK(backend::IsOp(current_call, "nn.dense")); + nodes.fc = current_call; + return nodes; + } + + /*! + * \brief Create a JSON representation of a composite fc (fully-connected) operator. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeMrvlFcLayer(const CallNode* cn) { + CompositeFcNode nodes = UnpackCompositeFc(cn); + std::string name = "nn.fc_ni2no"; + std::string mrvlLayerName = "FC"; + std::string data_layout = "NC"; + std::string kernel_layout = "OI"; + std::string out_layout = "NC"; + + // Inputs must be added in the same order they appear in the relay graph. + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + inputs.push_back(VisitExpr(nodes.fc->args[1])[0]); + if (nodes.add) { + inputs.push_back(VisitExpr(nodes.add->args[1])[0]); + } + + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, nodes.fc); + + // add other mrvl-specific attributes + SetMrvlSpecificJsonNodeAttrs(json_node, cn, nullptr /* no node.pad */, nullptr /* no cn_pool */, + nullptr /* no cn_conv */, nodes.fc, nodes.add, + nullptr /* no cn_batch_norm */, nodes.activation, mrvlLayerName, + "NC" /* data_layout */, "OI" /* kernel_layout */, + "NC" /* out_layout */, "-O" /* if node.bias: as bias_layout */, + "relu" /* if node.activation: as activation_op */, nodes.op_names); + return json_node; + } + + /*! + * \brief Create a JSON representation of a composite (global) maxpooling operator. + * + * A composite function is only created when using the uint8 datatype for these operators. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeMrvlMaxpool2DLayer(const CallNode* cn) { + std::string mrvlLayerName = "Maxpool2D"; + std::string name = "nn.maxpool2d_nhwc2nhwc"; + CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName); + const auto* maxpool_attr = nodes.pool->attrs.as(); + ICHECK(maxpool_attr); + ICHECK(maxpool_attr->layout == "NHWC") + << "Layout must be NHWC, has the module been pre-processed correctly?"; + + std::string data_layout = maxpool_attr->layout; + std::string out_layout = maxpool_attr->layout; + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, nodes.pool); + + // add other mrvl-specific attributes + SetMrvlSpecificJsonNodeAttrs( + json_node, cn, nodes.pad, nodes.pool, nullptr /* no cn_conv */, nullptr /* no cn_fc */, + nullptr /* no cn_add */, nullptr /* no cn_batch_norm */, nullptr /* no cn_activation */, + mrvlLayerName, "NHWC" /* data_layout */, "HW" /* kernel_layout */, "NHWC" /* out_layout */, + "" /* bias_layout */, "" /* activation_op */, nodes.op_names); + return json_node; + } + + /*! + * \brief Create a JSON representation of a composite (global) avgpooling operator. + * + * A composite function is only created when using the uint8 datatype for these operators. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeMrvlAvgpool2DLayer(const CallNode* cn) { + std::string mrvlLayerName = "Avgpool2D"; + std::string name = "nn.avgpool2d_nhwc2nhwc"; + CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName); + + const auto* avgpool_attr = nodes.pool->attrs.as(); + ICHECK(avgpool_attr); + ICHECK(avgpool_attr->layout == "NHWC") + << "Layout must be NHWC, has the module been pre-processed correctly?"; + + std::string data_layout = avgpool_attr->layout; + std::string out_layout = avgpool_attr->layout; + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, nodes.pool); + + // add other mrvl-specific attributes + SetMrvlSpecificJsonNodeAttrs( + json_node, cn, nodes.pad, nodes.pool, nullptr /* no cn_conv */, nullptr /* no cn_fc */, + nullptr /* no cn_add */, nullptr /* no cn_batch_norm */, nullptr /* no cn_activation */, + mrvlLayerName, "NHWC" /* data_layout */, "HW" /* kernel_layout */, "NHWC" /* out_layout */, + "" /* bias_layout */, "" /* activation_op */, nodes.op_names); + return json_node; + } + + /*! + * \brief Create a JSON representation of a composite (global) maxpooling operator. + * + * A composite function is only created when using the uint8 datatype for these operators. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateMrvlLayer4OpNode(const CallNode* cn) { + const auto* op_node = cn->op.as(); + ICHECK(op_node); + String op_name = tvm::Op::GetOpName(GetRef(op_node)); + + std::string name = op_name; + std::string mrvlLayerName = op_name; + std::string data_layout = ""; + std::string out_layout = ""; + if (op_name == "transpose") { + // do nothing for now + } else if ((op_name == "reshape") || (op_name == "nn.batch_flatten")) { + // FIXME: hard coded for now -- when input data dim is 4D and output dim is 2D + { + // check for cases currently support + std::vector layout_vec; + GetInputTensorShapeViaArg0(cn, &layout_vec); + ICHECK(layout_vec.size() == 4) + << "Reshape or nn.batch_flatten with input tensor dim != 4 is not supported yet."; + layout_vec.clear(); + GetOutputTensorShape(cn, &layout_vec); + ICHECK(layout_vec.size() == 2) + << "Reshape or nn.batch_flatten with output tensor dim != 2 is not supported yet."; + } + data_layout = "NHWC"; + out_layout = "NC"; + } else if (op_name == "layout_transform") { + auto layout_transform_attr = cn->attrs.as(); + data_layout = layout_transform_attr->src_layout; + out_layout = layout_transform_attr->dst_layout; + } else { + LOG(FATAL) << "Can't handle this OpNode: " << AsText(GetRef(cn), false); + } + + std::vector inputs; + inputs.push_back(VisitExpr(cn->args[0])[0]); + auto json_node = std::make_shared(name, "kernel", inputs, 1); + struct FrontendOpNames op_names; + JsonNodeSetAttr(json_node, "first_op_name", {op_names.first_op_name}); + JsonNodeSetAttr(json_node, "last_op_name", {op_names.last_op_name}); + setMrvlLayerCommonAttrs(json_node, cn, func_name_, mrvlLayerName, data_layout, + "" /* no kernel_layout */, out_layout); + return json_node; + } +}; + +#endif + +/*! + * \brief Create a runtime module for Mrvl. + * + * This consists of a series of "serialized functions" which each represent a + * sub-graph to be computed by Mrvl and will each be executed independently from + * one another. Each function consists of serialized JSON describing the sub-graph + * and serialized constant tensors. + * + * \note The Mrvl runtime module only supports a single operator per + * sub-graph currently. + * + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * \return A runtime module. + */ +runtime::Module MrvlCompiler(const ObjectRef& ref) { +#ifdef USE_JSON_RUNTIME + + // "per mrvl layer" MrvlCompiler call + // - i.e., this is not a per mrvl "network" call + if (!g_mrvlExtJsonObjInstantized) { + // For the Mrvl BYOC flow's GraphExecutorCodegen() object, we need to register the + // Mrvl-BYOC's external JSON callback function in order to generate JSON files + // following the Mrvl-BYOC format + // - TODO(ccjoechou): not sure this the best place to instantiate a MrvlExtJson object, + // which also register the callback function + InstantiateMrvlExtJsonObj(); + } + + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relay function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + MrvlJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.mrvl_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + + return lib; + +#else + + // TODO(ccjoechou): TBA if needed -- follow "per layer" C-codegen example + +#endif +} + +// NOTE: called by compile_engine.cc CompileEngineImpl::LowerExternalFunctions() +TVM_REGISTER_GLOBAL("relay.ext.mrvl").set_body_typed(MrvlCompiler); + +/*! + * \brief Check whether Mrvl runtime is used. + * + * \return True if Mrvl runtime is enabled, False if not. + */ +inline constexpr bool IsMrvlRuntimeEnabled() { +#if TVM_RUNTIME_MRVL + return true; +#else + return false; +#endif +} + +TVM_REGISTER_GLOBAL("relay.op.is_mrvl_runtime_enabled").set_body_typed(IsMrvlRuntimeEnabled); + +} // namespace mrvl +} // namespace contrib + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/mrvl/drop_noop_transpose.cc b/src/relay/backend/contrib/mrvl/drop_noop_transpose.cc new file mode 100644 index 000000000000..9babe3b94543 --- /dev/null +++ b/src/relay/backend/contrib/mrvl/drop_noop_transpose.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/mrvl/drop_noop_trnaspose.cc + * \brief Marvell MLIP specific API + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "../../../op/tensor/transform.h" +#include "../../../transforms/pattern_utils.h" +#include "../../../transforms/simplify_expr.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace mrvl { + +/*! + * \brief DropNoOpTranspose, which does not change any axes. + */ +class DropNoOpTranspose : public DFPatternRewrite { + public: + DropNoOpTranspose() { + x_ = IsWildcard(); + auto trans = IsOp("transpose"); + pattern_ = trans({x_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto get_axes_from_call = [](const Call trans_call, int ndim) { + std::vector attr_axes; + if (auto attr = trans_call->attrs.as()) { + if (attr->axes.defined()) { + for (int i = 0; i < ndim; ++i) { + int64_t axis = attr->axes[i]; + axis += (axis < 0) ? ndim : 0; + attr_axes.push_back(axis); + } + } else { + // Empty axes means reverse + for (int i = ndim - 1; i >= 0; --i) { + attr_axes.push_back(i); + } + } + } else if (auto attr = trans_call->attrs.as()) { + Layout src_layout(attr->src_layout); + Layout dst_layout(attr->dst_layout); + for (int i = 0; i < ndim; ++i) { + attr_axes.push_back(src_layout.IndexOf(dst_layout[i])); + } + } else { + CHECK(false) << "Mrvl-TVM-ERROR: Expected transpose or layout_transform, but got " + << Downcast(trans_call->op)->name; + } + return std::move(attr_axes); + }; + + auto x = node_map[x_][0]; + + // check axes + int ndim = Downcast(pre->checked_type())->shape.size(); + + // Collect axes from the transpose + Call trans_call = Downcast(post); + std::vector actual_axes = get_axes_from_call(trans_call, ndim); + bool drop = true; + for (int i = 0; i < ndim; ++i) { + if (actual_axes[i] != i) { + drop = false; + break; + } + } + + // x is result of the node just before the pattern + if (drop) return x; + + // keep the transpose node + return post; + } + + private: + /*! \brief Pattern input */ + DFPattern x_; +}; + +Expr DropNoopTranspose(const Expr& expr, const IRModule& mod) { + // the rewrites will be applied in the given order, and repeated until fixed point + DFPatternRewriteComposer composer; + composer.AddRewrite(); + return RewritePatterns(composer.MakeCallbacks(), expr, mod); +} + +} // namespace mrvl +} // namespace contrib + +namespace transform { + +Pass DropNoopTranspose() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(tvm::relay::contrib::mrvl::DropNoopTranspose(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "DropNoopTranspose", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.DropNoopTranspose").set_body_typed(DropNoopTranspose); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/mrvl/graph_executor_codegen_mrvl.cc b/src/relay/backend/contrib/mrvl/graph_executor_codegen_mrvl.cc new file mode 100644 index 000000000000..247ba0f0cdce --- /dev/null +++ b/src/relay/backend/contrib/mrvl/graph_executor_codegen_mrvl.cc @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/mrvl/graph_executor_codegen_mrvl.cc + * \brief Marvell MLIP specific API + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "../src/relay/backend/graph_executor_codegen.h" + +namespace tvm { +namespace relay { +namespace backend { + +using IntegerArray = Array; +using ShapeVector = std::vector>; +using GraphAttrs = std::unordered_map; +using GraphObjectPtr = std::shared_ptr; + +extern "C" ExternalJsonWriterCB* GetExternalJsonWriter(); + +/*! \brief Input Node */ +class GraphInputNodeMrvlExt : public GraphInputNode { + public: + GraphInputNodeMrvlExt() : GraphInputNode() {} + GraphInputNodeMrvlExt(const std::string& name, const GraphAttrs& attrs) + : GraphInputNode(name, attrs) {} + + static std::shared_ptr make_node_ptr(const std::string& name, + const GraphAttrs& attrs) { + auto ptr = std::make_shared(name, attrs); + return std::dynamic_pointer_cast(ptr); + } + + GraphNodeType Type() const override { return kGraphInputNodeExt; } + + void Save(dmlc::JSONWriter* writer) const override { + const std::string op_name{"null"}; + writer->BeginObject(); + writer->WriteObjectKeyValue("op", op_name); + writer->WriteObjectKeyValue("name", this->name_); + writer->WriteObjectKeyValue("inputs", std::list()); + writer->WriteObjectKeyValue("attrs", this->attrs_); + writer->EndObject(); + } +}; + +class GraphOpNodeMrvlExt : public GraphOpNode { + public: + GraphOpNodeMrvlExt() {} + virtual ~GraphOpNodeMrvlExt() {} + + GraphNodeType Type() const override { return kGraphOpNodeExt; } + + void Load(dmlc::JSONReader* reader) override; + void LoadAttrs(dmlc::JSONReader* reader); + std::pair GetLoadedGraphAttrs(); + std::string func_node_name_; + GraphAttrs op_attrs_; +}; + +/*! + * \brief Load a node in the json string. + * \param reader The json reader. + */ +void GraphOpNodeMrvlExt::Load(dmlc::JSONReader* reader) { + std::string tmp_name; + std::vector tmp_int_arr; + + reader->BeginObject(); + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "op") { + reader->Read(&tmp_name); + } else if (key == "name") { + reader->Read(&tmp_name); + } else if (key == "inputs") { + reader->BeginArray(); + ICHECK(reader->NextArrayItem()) << "invalid json format"; + reader->Read(&tmp_int_arr); + if (reader->NextArrayItem()) { + reader->Read(&tmp_int_arr); + if (reader->NextArrayItem()) { + reader->Read(&tmp_int_arr); + ICHECK(!reader->NextArrayItem()) << "invalid json format"; + } + } + } else if (key == "attr" || key == "attrs") { + this->LoadAttrs(reader); + } else { + LOG(FATAL) << "Unknown key: " << key; + } + } +} + +/*! + * \brief Load the attribute of a node in the json string. + * \param reader The json reader. + */ +void GraphOpNodeMrvlExt::LoadAttrs(dmlc::JSONReader* reader) { + std::string key; + std::string tmp_str; + GraphAttrs attrs; + op_attrs_ = attrs; + + // - skip num_inputs and num_outputs (use originals) + // - skip dtype for now + // - skip shape + reader->BeginObject(); + while (reader->NextObjectItem(&key)) { + if (key == "num_inputs") { + reader->Read(&tmp_str); + } else if (key == "num_outputs") { + reader->Read(&tmp_str); + } else if (key == "dtype") { + reader->BeginArray(); + std::vector tmp_str_vec; + ICHECK(reader->NextArrayItem()); + reader->Read(&tmp_str_vec); + ICHECK(!reader->NextArrayItem()); + } else if (key == "shape") { + reader->BeginArray(); + std::vector> tmp_shape; + ICHECK(reader->NextArrayItem()); + reader->Read(&tmp_shape); + ICHECK(!reader->NextArrayItem()); + } else { + reader->BeginArray(); + std::vector tmp_str_vec; + ICHECK(reader->NextArrayItem()); + reader->Read(&tmp_str_vec); + op_attrs_[key] = tmp_str_vec; + ICHECK(!reader->NextArrayItem()); + if (key == "func_node_name") { + ICHECK(tmp_str_vec.size() == 1); + func_node_name_ = tmp_str_vec[0]; + } + } + } +} + +/*! + * \brief return generated map + */ +std::pair GraphOpNodeMrvlExt::GetLoadedGraphAttrs() { + return std::pair(func_node_name_, op_attrs_); +} + +class MrvlExtJson { + public: + MrvlExtJson() { + ICHECK(!GetExternalJsonWriter()->HasCallback()) << "ERROR: has registered callback"; + GetExternalJsonWriter()->RegisterCB(this, &MrvlExtJson::GetExternalJSON); + } + + virtual ~MrvlExtJson() {} + + void GetExternalJSON(dmlc::JSONWriter* writer, Array external_mods, + std::vector nodes, std::vector heads); + + void LoadExternalJsonAttrs(std::unordered_map* external_attrs_map, + const Array& external_mods); +}; + +/*! + * \brief Load External Json attrs map + * + * \param external_attrs_map: map to be generated + * \param external_mods: array of external-codegen mods (one per external + * composite func) + */ +void MrvlExtJson::LoadExternalJsonAttrs( + std::unordered_map* external_attrs_map, + const Array& external_mods) { + // retrieve attributes from each external composite graph + for (size_t i = 0; i < external_mods.size(); ++i) { + auto mod = external_mods[i]; + auto pfunc = mod.GetFunction("get_graph_json", false); + std::string graph_json = pfunc(); + std::istringstream tmp_is(graph_json); + dmlc::JSONReader tmp_reader(&tmp_is); + + std::vector tmp2_nodes; + std::vector tmp_int_array; + std::string key; + tmp_reader.BeginObject(); + while (tmp_reader.NextObjectItem(&key)) { + if (key == "nodes") { + tmp_reader.Read(&tmp2_nodes); + } else if (key == "arg_nodes") { + tmp_reader.Read(&tmp_int_array); + } else if (key == "node_row_ptr") { + tmp_reader.Read(&tmp_int_array); + } else if (key == "heads") { + tmp_reader.BeginArray(); + ICHECK(tmp_reader.NextArrayItem()) << "invalid json format"; + tmp_reader.Read(&tmp_int_array); + ICHECK(!tmp_reader.NextArrayItem()) << "invalid json format"; + } else { + LOG(FATAL) << "Unknown key: " << key; + } + } + std::pair mrvl_node_attrs = + tmp2_nodes[tmp2_nodes.size() - 1].GetLoadedGraphAttrs(); + external_attrs_map->insert({mrvl_node_attrs.first, mrvl_node_attrs.second}); + } +} + +/*! + * \brief Generate External Graph JSON + * + * \param writer json writer + */ +void MrvlExtJson::GetExternalJSON(dmlc::JSONWriter* writer, + Array external_mods, + std::vector nodes, + std::vector heads) { + // retrieve attributes from each external composite graph + std::unordered_map external_attrs_map; + LoadExternalJsonAttrs(&external_attrs_map, external_mods); + + /*! \brief nodes */ + std::vector external_nodes = nodes; + /*! \brief output of graph */ + std::vector external_heads = heads; + + for (size_t i = 0; i < external_nodes.size(); ++i) { + auto node = external_nodes[i]; + if (node->Type() == kGraphOpNode) { + // replace the op_attrs of this GraphOpNode node with its corresponding + // external codegen node's attrs + if (external_attrs_map.count(node->name_) == 1) { + std::dynamic_pointer_cast(node)->op_attrs_ = external_attrs_map[node->name_]; + } + } + } + + std::vector arg_nodes; + for (size_t i = 0; i < external_nodes.size(); ++i) { + auto node = external_nodes[i]; + if (node->Type() == kGraphInputNode) { + arg_nodes.push_back(i); + } + } + size_t num_entry = 0; + ShapeVector shapes; + std::vector storage_ids; + std::vector device_types; + std::vector dltypes; + std::vector node_row_ptr{0}; + for (size_t i = 0; i < external_nodes.size(); ++i) { + auto node = external_nodes[i]; + const auto& shape_vec = dmlc::get(node->attrs_["shape"]); + const auto& storage_id = dmlc::get>(node->attrs_["storage_id"]); + const auto& dtype_vec = dmlc::get>(node->attrs_["dtype"]); + + ICHECK_EQ(node->num_outputs_, shape_vec.size()); + num_entry += node->num_outputs_; + + shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end()); + dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end()); + storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end()); + if (node->attrs_.count("device_index")) { + const auto& dev_types = dmlc::get>(node->attrs_["device_index"]); + device_types.insert(device_types.end(), dev_types.begin(), dev_types.end()); + } + node_row_ptr.push_back(num_entry); + + if ((node->Type() == kGraphInputNode) && (external_attrs_map.size() > 0)) { + ICHECK(dynamic_cast(node.get())); + // copy GraphInputNode node to a GraphInputNodeMrvlExt node in order to + // use its own writer (i.e., by calling its Save() func) + auto new_input_node_ptr = GraphInputNodeMrvlExt::make_node_ptr(node->name_, node->attrs_); + external_nodes[i] = std::dynamic_pointer_cast(new_input_node_ptr); + + // add "attrs": { "layer_name": [ "input" ] } + std::vector layer_name_json_attr; + layer_name_json_attr.emplace_back(std::string("input")); + + // add "attrs": { "data_layout": [ "NCHW" or "NHWC" or "NC" etc. ] } + // TODO(ccjoechou): improve coverage to allow other networks + bool is_NC = (!shape_vec.empty()) && (shape_vec[0].size() == 2); + new_input_node_ptr->attrs_.clear(); + new_input_node_ptr->attrs_["layer_name"] = layer_name_json_attr; + std::vector data_layout_json_attr; + if (is_NC) { + data_layout_json_attr.emplace_back(std::string("NC")); + } else { + data_layout_json_attr.emplace_back(std::string("NCHW")); + } + new_input_node_ptr->attrs_["data_layout"] = data_layout_json_attr; + } + } + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", external_nodes); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", external_heads); + std::unordered_map> attrs; + attrs["shape"].emplace_back(std::string("list_shape")); + attrs["shape"].emplace_back(shapes); + attrs["storage_id"].emplace_back(std::string("list_int")); + attrs["storage_id"].emplace_back(storage_ids); + if (device_types.size()) { + attrs["device_index"].emplace_back(std::string("list_int")); + attrs["device_index"].emplace_back(device_types); + } + attrs["dltype"].emplace_back(std::string("list_str")); + attrs["dltype"].emplace_back(dltypes); + writer->WriteObjectKeyValue("attrs", attrs); + writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); + writer->EndObject(); +} + +std::shared_ptr g_mrvl_ext_json; + +extern "C" bool g_mrvlExtJsonObjInstantized; +bool g_mrvlExtJsonObjInstantized = false; + +extern "C" void InstantiateMrvlExtJsonObj() { + g_mrvl_ext_json = std::make_shared(); + g_mrvlExtJsonObjInstantized = true; +} + +void MrvlClearFlag() { g_mrvlExtJsonObjInstantized = false; } +TVM_REGISTER_GLOBAL("relay.mrvl.clear_ext_json_flag").set_body_typed(MrvlClearFlag); + +} // namespace backend +} // namespace relay +} // namespace tvm + +namespace dmlc { +namespace json { + +// JSON utils to be template specialized for Mrvl BYOC GetExternalJSON() related extensions +template +inline bool SameType(const dmlc::any& data) { + return std::type_index(data.type()) == std::type_index(typeid(T)); +} + +template <> +struct Handler> { + inline static void Write(dmlc::JSONWriter* writer, + const std::shared_ptr& data) { + data->Save(writer); + } + inline static void Read(dmlc::JSONReader* reader, + std::shared_ptr* data) { + LOG(FATAL) << "Not implemented."; + } +}; + +template <> +struct Handler> { + inline static void Write(dmlc::JSONWriter* writer, + const std::unordered_map& data) { + writer->BeginObject(); + for (const auto& kv : data) { + auto k = kv.first; + const dmlc::any& v = kv.second; + if (SameType(v)) { + writer->WriteObjectKeyValue(k, dmlc::get(v)); + } else if (SameType(v)) { + writer->WriteObjectKeyValue(k, dmlc::get(v)); + } else if (SameType>(v)) { + writer->WriteObjectKeyValue(k, dmlc::get>(v)); + } else if (SameType>>(v)) { + writer->WriteObjectKeyValue(k, dmlc::get>>(v)); + } else if (SameType>(v)) { + writer->WriteObjectKeyValue(k, dmlc::get>(v)); + } else if (SameType>(v)) { + writer->WriteObjectKeyValue(k, dmlc::get>(v)); + } else { + LOG(FATAL) << "Value type not supported for key: " << k; + } + } + writer->EndObject(); + } + inline static void Read(dmlc::JSONReader* reader, + std::unordered_map* data) { + LOG(FATAL) << "Not implemented."; + } +}; + +template <> +struct Handler> { + inline static void Write(dmlc::JSONWriter* writer, const std::vector& data) { + writer->BeginArray(); + for (const auto& v : data) { + if (SameType(v)) { + writer->WriteArrayItem(dmlc::get(v)); + } else if (SameType(v)) { + writer->WriteArrayItem(dmlc::get(v)); + } else if (SameType>(v)) { + writer->WriteArrayItem(dmlc::get>(v)); + } else if (SameType>>(v)) { + writer->WriteArrayItem(dmlc::get>>(v)); + } else if (SameType>(v)) { + writer->WriteArrayItem(dmlc::get>(v)); + } else { + LOG(FATAL) << "Not supported"; + } + } + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, std::vector* data) { + LOG(FATAL) << "Not implemented."; + } +}; + +} // namespace json +} // namespace dmlc diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index f61fe9b402b3..564ff3fc00b4 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -22,6 +22,8 @@ * \brief Graph executor codegen */ +#include "../src/relay/backend/graph_executor_codegen.h" + #include #include #include @@ -52,146 +54,63 @@ backend::StaticMemoryPlan GraphPlanMemory(const Function& func); namespace backend { -class GraphNode; class GraphInputNode; class GraphOpNode; +class GraphExecutorCodegen; using IntegerArray = Array; using ShapeVector = std::vector>; using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; -using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -/*! \brief Node types */ -enum GraphNodeType { - kGraphNop, - kGraphInputNode, - kGraphOpNode, -}; - -class GraphNodeRef { - public: - GraphNodeRef() {} - GraphNodeRef(int ident, int index, int version = 0) - : ident_(ident), index_(index), version_(version) {} - - inline void Save(dmlc::JSONWriter* writer) const { - writer->BeginArray(); - writer->WriteArrayItem(ident_); - writer->WriteArrayItem(index_); - writer->WriteArrayItem(version_); - writer->EndArray(); - } - - inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; } - - protected: - int ident_; - int index_{0}; - int version_{0}; -}; - -/*! \brief Base Node class */ -class GraphNode { - public: - GraphNode() {} - virtual void Save(dmlc::JSONWriter* writer) const {} - virtual void Load(dmlc::JSONReader* reader) {} - virtual GraphNodeType Type() const { return kGraphNop; } - virtual ~GraphNode() {} - - public: - int num_outputs_{1}; - std::string name_; - GraphAttrs attrs_; -}; - -/*! \brief Input Node */ -class GraphInputNode : public GraphNode { - public: - GraphInputNode() {} - GraphInputNode(const std::string& name, const GraphAttrs& attrs) { - name_ = name; - attrs_ = attrs; - } - - GraphNodeType Type() const override { return kGraphInputNode; } - - void Save(dmlc::JSONWriter* writer) const override { - const std::string op_name{"null"}; - writer->BeginObject(); - writer->WriteObjectKeyValue("op", op_name); - writer->WriteObjectKeyValue("name", this->name_); - writer->WriteObjectKeyValue("inputs", std::list()); - writer->EndObject(); - } - static std::shared_ptr make_node_ptr(const std::string& name, - const GraphAttrs& attrs) { - auto ptr = std::make_shared(name, attrs); - return std::dynamic_pointer_cast(ptr); - } -}; - /*! \brief Op Node */ -class GraphOpNode : public GraphNode { - public: - GraphOpNode() {} - GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name, - const std::vector& inputs, const GraphAttrs& attrs, - size_t num_outputs = 1) { - name_ = name; - attrs_ = nd_attrs; - op_name_ = op_name; - inputs_ = inputs; - op_attrs_ = attrs; - num_outputs_ = num_outputs; - op_attrs_["func_name"] = op_name_; - op_attrs_["flatten_data"] = std::string("0"); - op_attrs_["num_inputs"] = std::to_string(inputs_.size()); - op_attrs_["num_outputs"] = std::to_string(num_outputs_); - } - - GraphNodeType Type() const override { return kGraphOpNode; } - - void Save(dmlc::JSONWriter* writer) const override { - GraphAttrs attrs = op_attrs_; - attrs["func_name"] = this->op_name_; - attrs["flatten_data"] = std::string("0"); - attrs["num_inputs"] = std::to_string(this->inputs_.size()); - attrs["num_outputs"] = std::to_string(this->num_outputs_); - writer->BeginObject(); - writer->WriteObjectKeyValue("op", op_type_name_); - writer->WriteObjectKeyValue("name", name_); - writer->WriteObjectKeyValue("attrs", attrs); - writer->WriteObjectKeyValue("inputs", this->inputs_); - writer->EndObject(); - } - static std::shared_ptr make_node_ptr(const std::string& name, - const GraphAttrs& nd_attrs, - const std::string& op_name, - const std::vector& inputs, - const GraphAttrs& attrs, size_t num_outputs = 1) { - auto ptr = std::make_shared(name, nd_attrs, op_name, inputs, attrs, num_outputs); - return std::dynamic_pointer_cast(ptr); - } - - public: - std::string op_name_; - std::vector inputs_; - GraphAttrs op_attrs_; +GraphOpNode::GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, + const std::string& op_name, const std::vector& inputs, + const GraphAttrs& attrs, size_t num_outputs) { + name_ = name; + attrs_ = nd_attrs; + op_name_ = op_name; + inputs_ = inputs; + op_attrs_ = attrs; + num_outputs_ = num_outputs; + op_attrs_["func_name"] = op_name_; + op_attrs_["flatten_data"] = std::string("0"); + op_attrs_["num_inputs"] = std::to_string(inputs_.size()); + op_attrs_["num_outputs"] = std::to_string(num_outputs_); +} - private: - const std::string op_type_name_{"tvm_op"}; -}; +GraphNodeType GraphOpNode::Type() const { return kGraphOpNode; } + +void GraphOpNode::Save(dmlc::JSONWriter* writer) const { + GraphAttrs attrs = op_attrs_; + attrs["func_name"] = this->op_name_; + attrs["flatten_data"] = std::string("0"); + attrs["num_inputs"] = std::to_string(this->inputs_.size()); + attrs["num_outputs"] = std::to_string(this->num_outputs_); + writer->BeginObject(); + writer->WriteObjectKeyValue("op", op_type_name_); + writer->WriteObjectKeyValue("name", name_); + writer->WriteObjectKeyValue("attrs", attrs); + writer->WriteObjectKeyValue("inputs", this->inputs_); + writer->EndObject(); +} -/*! \brief Code generator for the graph executor, produces a module containing the graph JSON, - * module, and parameters. +/*! \brief Code generator for the graph executor, produces a module containing + * the graph JSON, module, and parameters. */ class GraphExecutorCodegen : public backend::MemoizedExprTranslator> { public: GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) - : mod_(mod), targets_(targets) {} + : mod_(mod), targets_(targets) { + // we need the following variable to be a static member of the class so we can access + // its setting in the following static GetExternalJsonWriter() function; but this static + // member can actually be used as a local Callback setting for "per" GraphExecutorCodegen + // instantiation during each TVM build-codegen flow + external_json_writer_ = std::make_shared(); + ICHECK(external_json_writer_); + } + static ExternalJsonWriterCB* GetExternalJsonWriter() { return external_json_writer_.get(); } StorageInfo GetStorageInfo(const Expr& e) { size_t count = memory_plan_->expr_to_storage_info.count(e); @@ -254,13 +173,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(lowered_mod->Lookup("main")); - // Now that we have lowered all operators to TIR code, we can proceed with compilation. + // Now that we have lowered all operators to TIR code, we can proceed with + // compilation. // - // We need to unfortunately re-plan as the previous results have been invalidated by lowering - // we will fix this in future refactors. + // We need to unfortunately re-plan as the previous results have been + // invalidated by lowering we will fix this in future refactors. memory_plan_ = GraphPlanMemory(lowered_main_func); - // The graph planner also can not handle planning calls to global variables to we must remap + // The graph planner also can not handle planning calls to global variables + // to we must remap // First we convert all the parameters into input nodes. for (auto param : lowered_main_func->params) { @@ -290,6 +211,15 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorHasCallback()) { + std::ostringstream external_os; + dmlc::JSONWriter external_writer(&external_os); + external_json_writer_->Exe(&external_writer, ret.external_mods, nodes_, heads_); + ret.external_graph_json = external_os.str(); + } + return ret; } @@ -461,7 +391,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorop.as(); ICHECK(func) << "Expected the operator to be a global var, but got " - << call_node->op->GetTypeKey(); // getting a relay fn here, not sure why. + << call_node->op->GetTypeKey(); // getting a relay fn here, + // not sure why. func_name = func->name_hint; for (const Expr& arg : call_node->args) { @@ -476,7 +407,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const GlobalVarNode* op) override { - LOG(FATAL) << "All GlobalVarNodes should be removed before graph executor's Codegen is called"; + LOG(FATAL) << "All GlobalVarNodes should be removed before graph " + "executor's Codegen is called"; return {}; } std::vector VisitExpr_(const IfNode* op) override { @@ -625,8 +558,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; @@ -638,7 +571,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator function_metadata_; /*! \brief name map */ std::unordered_map name_map_; + static std::shared_ptr external_json_writer_; }; +std::shared_ptr GraphExecutorCodegen::external_json_writer_ = + std::shared_ptr(); class GraphExecutorCodegenModule : public runtime::ModuleNode { public: @@ -663,6 +599,10 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { } else if (name == "get_graph_json") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; }); + } else if (name == "get_external_graph_json") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.external_graph_json; + }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Array ret; @@ -714,6 +654,10 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { LoweredOutput output_; }; +extern "C" ExternalJsonWriterCB* GetExternalJsonWriter() { + return GraphExecutorCodegen::GetExternalJsonWriter(); +} + runtime::Module CreateGraphCodegenMod() { auto ptr = make_object(); return runtime::Module(ptr); @@ -766,7 +710,7 @@ struct Handler> { } else if (SameType>(v)) { writer->WriteObjectKeyValue(k, dmlc::get>(v)); } else { - LOG(FATAL) << "Not supported"; + LOG(FATAL) << "Value type not supported for key: " << k; } } writer->EndObject(); diff --git a/src/relay/backend/graph_executor_codegen.h b/src/relay/backend/graph_executor_codegen.h new file mode 100644 index 000000000000..9ada92bddb91 --- /dev/null +++ b/src/relay/backend/graph_executor_codegen.h @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/src/relay/backend/graph_executor_codegen.h + * \brief Graph executor codegen + */ +#ifndef TVM_RELAY_BACKEND_GRAPH_EXECUTOR_CODEGEN_H_ +#define TVM_RELAY_BACKEND_GRAPH_EXECUTOR_CODEGEN_H_ + +#include <../../src/relay/backend/utils.h> + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace backend { + +class GraphExecutorCodegen; +class GraphInputNode; +class GraphOpNode; +class GraphNode; + +using GraphAttrs = std::unordered_map; +using GraphObjectPtr = std::shared_ptr; + +/*! \brief Node types */ +enum GraphNodeType { + kGraphNop, + kGraphInputNode, + kGraphOpNode, + kGraphInputNodeExt, + kGraphOpNodeExt, +}; + +/*! \brief Base Node class */ +class GraphNode { + public: + GraphNode() {} + virtual void Save(dmlc::JSONWriter* writer) const {} + virtual void Load(dmlc::JSONReader* reader) {} + virtual GraphNodeType Type() const { return kGraphNop; } + virtual ~GraphNode() {} + + public: + int num_outputs_{1}; + std::string name_; + GraphAttrs attrs_; +}; + +class GraphNodeRef { + public: + GraphNodeRef() {} + GraphNodeRef(int ident, int index, int version = 0) + : ident_(ident), index_(index), version_(version) {} + + inline void Save(dmlc::JSONWriter* writer) const { + writer->BeginArray(); + writer->WriteArrayItem(ident_); + writer->WriteArrayItem(index_); + writer->WriteArrayItem(version_); + writer->EndArray(); + } + + inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; } + + protected: + int ident_; + int index_{0}; + int version_{0}; +}; + +/*! \brief Input Node */ +class GraphInputNode : public GraphNode { + public: + GraphInputNode() {} + GraphInputNode(const std::string& name, const GraphAttrs& attrs) { + name_ = name; + attrs_ = attrs; + } + + GraphNodeType Type() const override { return kGraphInputNode; } + + void Save(dmlc::JSONWriter* writer) const override { + const std::string op_name{"null"}; + writer->BeginObject(); + writer->WriteObjectKeyValue("op", op_name); + writer->WriteObjectKeyValue("name", this->name_); + writer->WriteObjectKeyValue("inputs", std::list()); + writer->EndObject(); + } + static std::shared_ptr make_node_ptr(const std::string& name, + const GraphAttrs& attrs) { + auto ptr = std::make_shared(name, attrs); + return std::dynamic_pointer_cast(ptr); + } + + inline void Load(dmlc::JSONReader* reader) override { LOG(FATAL) << "Not implemented."; } +}; + +/*! \brief Op Node */ +class GraphOpNode : public GraphNode { + public: + GraphOpNode& operator=(const GraphOpNode& t) { return *this; } + GraphOpNode() {} + GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name, + const std::vector& inputs, const GraphAttrs& attrs, + size_t num_outputs = 1); + + GraphNodeType Type() const override; + void Save(dmlc::JSONWriter* writer) const override; + + static std::shared_ptr make_node_ptr(const std::string& name, + const GraphAttrs& nd_attrs, + const std::string& op_name, + const std::vector& inputs, + const GraphAttrs& attrs, size_t num_outputs = 1) { + auto ptr = std::make_shared(name, nd_attrs, op_name, inputs, attrs, num_outputs); + return std::dynamic_pointer_cast(ptr); + } + + public: + std::string op_name_; + std::vector inputs_; + GraphAttrs op_attrs_; + + private: + const std::string op_type_name_{"tvm_op"}; +}; + +class ExternalJsonWriterCB { + public: + template + void RegisterCB(T* const object, + void (T::*const mf)(dmlc::JSONWriter*, Array, + std::vector, std::vector)) { + using namespace std::placeholders; + callback_ = std::bind(mf, object, _1, _2, _3, _4); + hasCallback_ = true; + } + void RegisterCB(void (*const fun)(dmlc::JSONWriter*, Array, + std::vector, std::vector)) { + callback_ = fun; + hasCallback_ = true; + } + void Exe(dmlc::JSONWriter* external_writer, Array mod, + std::vector nodes, std::vector heads) { + ICHECK(hasCallback_) << "ERROR: no registered callback"; + callback_(external_writer, mod, nodes, heads); + } + inline bool HasCallback() { return hasCallback_; } + + private: + std::function, std::vector, + std::vector)> + callback_; + bool hasCallback_{false}; +}; + +} // namespace backend +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_GRAPH_EXECUTOR_CODEGEN_H_ diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index df25a8641792..ee34261bdd25 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -143,6 +143,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type); */ struct LoweredOutput { std::string graph_json; + std::string external_graph_json{""}; Map lowered_funcs; Array external_mods; Map function_metadata; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index b680a49af887..d418262bead0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -26,6 +26,7 @@ #include namespace tvm { +int64_t RelayExprNode::_global_en_id = 0; VirtualDevice RelayExprNode::virtual_device() const { if (virtual_device_.defined()) { @@ -592,5 +593,15 @@ void LetNode::Deleter_(Object* ptr) { auto c = GetRef(p); } +/* + * Helper functions to get en_id mgmt + * Identify an expr node + */ +TVM_REGISTER_GLOBAL("relay.ir.get_en_id").set_body_typed([](const ObjectRef& ref) { + auto* relay_expr = static_cast(ref.get()); + ICHECK(relay_expr) << "can't downclass obj to RelayExprNode"; + return relay_expr->get_en_id(); +}); + } // namespace relay } // namespace tvm diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 1735d8569215..a9882ba462eb 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -75,6 +75,10 @@ class JSONRuntimeBase : public ModuleNode { } else if (name == "get_const_vars") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_names_; }); + } else if (name == "get_graph_json") { + // add a new API to return graph_json + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_json_; }); } else if (this->symbol_name_ == name) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK(this->initialized_) << "The module has not been initialized"; diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc new file mode 100644 index 000000000000..ab91abf2435e --- /dev/null +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/mrvl/mrvl_runtime.cc + * \brief A simple JSON runtime for Mrvl + */ + +#include +#include + +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; + +class MrvlRuntime : public JSONRuntimeBase { + public: + /*! + * \brief The Mrvl runtime module. Deserialize the provided functions + * on creation and store in the layer cache. + * + * \param symbol_name The name of the function. + * \param graph_json serialized JSON representation of a sub-graph. + * \param const_names The names of each constant in the sub-graph. + */ + explicit MrvlRuntime(const std::string& symbol_name, const std::string& graph_json, + const Array& const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + /*! + * \brief The type key of the module. + * + * \return module type key. + */ + const char* type_key() const override { return "mrvl"; } + + /*! + * \brief Initialize runtime. Create Mrvl layer from JSON + * representation. + * + * \param consts The constant params from compiled model. + */ + void Init(const Array& consts) override { + ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required."; + SetupConstants(consts); + BuildEngine(); + } + + void Run() override { + LOG(FATAL) << "Cannot call run on Mrvl module without runtime enabled. " + << "Please build with USE_MRVL_RUNTIME (which is not supported yet)."; + } + + void BuildEngine() { + LOG(WARNING) << "Mrvl engine is not initialized. " + << "Please build with USE_MRVL_RUNTIME (which is not supported yet)."; + } +}; + +runtime::Module MrvlRuntimeCreate(const String& symbol_name, const String& graph_json, + const Array& const_names) { + auto n = make_object(symbol_name, graph_json, const_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.mrvl_runtime_create").set_body_typed(MrvlRuntimeCreate); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl") + .set_body_typed(JSONRuntimeBase::LoadFromBinary); + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 7317cab665cf..9204fd6730d8 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -211,6 +211,14 @@ #define TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_MRVL +#define TVM_INFO_USE_MRVL "NOT-FOUND" +#endif + +#ifndef TVM_INFO_USE_MRVL_RUNTIME +#define TVM_INFO_USE_MRVL_RUNTIME "NOT-FOUND" +#endif + #ifndef TVM_INFO_INDEX_DEFAULT_I64 #define TVM_INFO_INDEX_DEFAULT_I64 "NOT-FOUND" #endif @@ -275,6 +283,8 @@ TVM_DLL Map GetLibInfo() { {"USE_TARGET_ONNX", TVM_INFO_USE_TARGET_ONNX}, {"USE_ARM_COMPUTE_LIB", TVM_INFO_USE_ARM_COMPUTE_LIB}, {"USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR", TVM_INFO_USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR}, + {"USE_MRVL", TVM_INFO_USE_MRVL}, + {"USE_MRVL_RUNTIME", TVM_INFO_USE_MRVL_RUNTIME}, {"INDEX_DEFAULT_I64", TVM_INFO_INDEX_DEFAULT_I64}, {"TVM_CXX_COMPILER_PATH", TVM_CXX_COMPILER_PATH}}; return result; diff --git a/tests/python/contrib/test_mrvl/__init__.py b/tests/python/contrib/test_mrvl/__init__.py new file mode 100644 index 000000000000..2aef4cc58f0f --- /dev/null +++ b/tests/python/contrib/test_mrvl/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Infrastructure and tests for MRVL codegen """ diff --git a/tests/python/contrib/test_mrvl/infrastructure.py b/tests/python/contrib/test_mrvl/infrastructure.py new file mode 100644 index 000000000000..efdaabb85049 --- /dev/null +++ b/tests/python/contrib/test_mrvl/infrastructure.py @@ -0,0 +1,251 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, W0611, C0413 + +"""Expose MRVL test functions to the Python frontend""" + +import os +import json +import re + +import tvm +from tvm import relay +from tvm.relay.op.contrib import mrvl + + +def file_exists(full_path_filename): + """Check existance of given file.""" + return os.path.exists(full_path_filename) and os.path.isfile(full_path_filename) + + +def get_cpu_op_count(mod): + """Traverse graph counting ops offloaded to TVM.""" + + class Counter(tvm.relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + c = Counter() + c.visit(mod["main"]) + return c.count + + +def skip_json_codegen_test(): + """Skip test if it requires the Mrvl codegen and it's not present.""" + # Mrvl codegen not present. + if not tvm.get_global_func("relay.ext.mrvl", True): + print("Skip because Mrvl codegen is not available.") + return True + + +def skip_aot_runtime_test(): + """Skip test if it requires the Mrvl runtime and it's not present.""" + # Mrvl codegen not present. + if skip_json_codegen_test(): + return True + if not mrvl.is_mrvl_runtime_enabled(): + print("Skip because Mrvl runtime isn't present or a remote device isn't being used.") + return True + return False + + +def aot_build_and_json_codegen( + model_name, + working_dir, + raw_model_ir, + weight_bias_params, + defuse_mrvl_layers_list=[], +): + gen_non_mrvl_subgraph = True + if defuse_mrvl_layers_list is []: + gen_non_mrvl_subgraph = False + + # produce at most two subgraphs, one mrvl subgraph and/or one non-mrvl subgraph + try: + ( + model_mrvl, + model_other, + orig_params, + opt_level, + disabled_pass, + orig_mod, + mrvl_layers_in_mrvl_subgraph, + ) = mrvl.partition_for_mrvl( + raw_model_ir, + params=weight_bias_params, + tvm_custom_dict={}, + gen_non_mrvl_subgraph=gen_non_mrvl_subgraph, + flow_pass=1, + ) + assert orig_params is not None + assert opt_level is not None + assert orig_mod is not None + except Exception as e: + err_msg = f"The {model_name} model could not be partitioned into subgraph.\n" + err_msg += str(e) + raise Exception(err_msg) + + try: + build_target, device_id = "llvm", 0 + mod_name = relay.backend.utils.mangle_module_name("") + byoc_executor = relay.build(model_mrvl, target=build_target, mod_name=mod_name) + # + byoc_const_params = byoc_executor.get_params() + byoc_external_graph_json = byoc_executor.get_external_graph_json() + assert byoc_const_params is not None + assert byoc_external_graph_json is not None + except Exception as e: + err_msg = f"Subgraph(s) could not be relay.build.\n" + err_msg += str(e) + raise Exception(err_msg) + + try: + nodes_json_filename, consts_json_filename = mrvl.dump_json_meta_data_files( + byoc_external_graph_json, + byoc_const_params, + filename_prefix=f"{working_dir}{model_name}-tvm-mrvl-byoc-ir", + ) + assert nodes_json_filename + assert consts_json_filename + except Exception as e: + err_msg = f"Mrvl JSON codegen failed.\n" + err_msg += str(e) + raise Exception(err_msg) + + return ( + nodes_json_filename, + consts_json_filename, + model_mrvl, + model_other, + mrvl_layers_in_mrvl_subgraph, + # FIXME: to return mrvl_layers_in_non_mrvl_subgraph + [], + ) + + +def check_json_integrity(nodes_json_file): + json_obj = json.load(open(nodes_json_file, "r")) + assert "nodes" in json_obj, f"No nodes_json_file['nodes']" + assert "heads" in json_obj, f"No nodes_json_file['heads']" + + legal_layer_name_list = [ + "input", + "layout_transform", + "nn.batch_flatten", + "reshape", + "transpose", + "Avgpool2D", + "Conv2D", + "FC", + "Maxpool2D", + "Sum2D", + ] + mod_name = relay.backend.utils.mangle_module_name("") + name_regex = "(?P" + mod_name + "_mrvl_main_[0-9]+)" + for layer in json_obj["nodes"]: + assert "attrs" in layer, f"No json_obj['attrs']" + assert "layer_name" in layer["attrs"], f"No json_obj['attrs']['layer_name']" + layer_name = layer["attrs"]["layer_name"][0] + assert layer_name in legal_layer_name_list, f"Illegal layer name {layer_name}" + # + if layer_name != "input": + assert "name" in layer, f"No json_obj['name']" + assert re.match(name_regex, layer["name"]), f"Illegal name ({layer['name']})" + + return json_obj + + +def verify_json_codegen(nodes_json_file, model_verification_info={}): + """verify json codegen output JSON files.""" + assert nodes_json_file != "" + assert file_exists(nodes_json_file), f"{nodes_json_file} does not exist" + # + json_obj = check_json_integrity(nodes_json_file) + if model_verification_info is not {}: + if "heads_size" in model_verification_info: + expected_size = model_verification_info["heads_size"] + actual_size = len(json_obj["heads"]) + assert ( + actual_size == expected_size + ), f"heads size - expected {expected_size} != actual {actual_size}" + if "nodes_size" in model_verification_info: + expected_size = model_verification_info["nodes_size"] + actual_size = len(json_obj["nodes"]) + assert ( + actual_size == expected_size + ), f"nodes size - expected {expected_size} != actual {actual_size}" + + +def aot_runtime_gen( + nodes_json_filename, + consts_json_filename, + aot_fp16_cmd_opts, +): + """aot runtime gen.""" + # TODO(ccjoechou): add final code + mrvl_subgraph_runtime_model_binary = None + assert nodes_json_filename is not None + assert consts_json_filename is not None + assert aot_fp16_cmd_opts is not None + return mrvl_subgraph_runtime_model_binary + + +def aot_run(mrvl_subgraph_runtime_model_binary, aot_run_cmd_opts, inf_inp=[]): + """mrvl aot run output.""" + # TODO(ccjoechou): add final code + mrvl_subgraph_actual_fp16_output = [] + assert mrvl_subgraph_runtime_model_binary is not None + assert aot_run_cmd_opts is not None + assert inf_inp is not [] + return mrvl_subgraph_actual_fp16_output + + +def tvm_llvm_fp32_run(mod_mrvl_subgraph, mrvl_layers_in_mrvl_subgraph, inf_inp=[]): + """llvm run output.""" + # TODO(ccjoechou): add final code + mrvl_subgraph_golden_fp32_output = [] + assert mod_mrvl_subgraph is not None + assert mrvl_layers_in_mrvl_subgraph is not None + assert data_inp is not [] + return mrvl_subgraph_golden_fp32_output + + +def verify_mrvl_subgraph_aot_inf_result( + mrvl_subgraph_actual_fp16_output, + mrvl_subgraph_golden_fp32_output, + delta_config, +): + """verify inf output of mrvl subgraph.""" + # TODO(ccjoechou): add final code + assert mrvl_subgraph_actual_fp16_output is not None + assert mrvl_subgraph_golden_fp32_output is not None + assert delta_config is not None + return + + +def verify_aot_inf_result(actual_inf_output, delta_config): + """verify final inf output of model.""" + # TODO(ccjoechou): add final code + assert actual_inf_output is not None + assert delta_config is not None diff --git a/tests/python/contrib/test_mrvl/test_mrvl_codegen.py b/tests/python/contrib/test_mrvl/test_mrvl_codegen.py new file mode 100644 index 000000000000..2b31017ed362 --- /dev/null +++ b/tests/python/contrib/test_mrvl/test_mrvl_codegen.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, W0611, C0413 + +"""MRVL MLIP codegen tests""" + +import sys +import os +import numpy as np +import pytest +import logging + +logging.basicConfig(level=logging.DEBUG) +mylogger = logging.getLogger() + +import tvm +import tvm.relay.testing +from tvm import testing +from tvm import relay + +from test_mrvl.infrastructure import skip_json_codegen_test, skip_aot_runtime_test +from test_mrvl.infrastructure import aot_build_and_json_codegen, verify_json_codegen + + +def _get_single_random_data_inp(name, ishape, dtype): + data = {} + np.random.seed(0) + + if dtype == "uint8": + low, high = 0, 255 + else: + low, high = -127, 128 + data[name] = np.random.uniform(low, high, ishape).astype(dtype) + return data + + +def _get_single_input_mxnet_model_and_data_inp(mxnet_model, input_info): + try: + from gluoncv import data + except ImportError: + pytest.skip("Missing Gluoncv Package") + + """Convert Mxnet graph to relay.""" + (ishape, dtype, layout, im_fname, short) = input_info + # FIXME: can't use mxnet_model.input_names[0] + name = "data" + inputs = {name: ishape} + mod, params = relay.frontend.from_mxnet(mxnet_model, inputs) + + # we pre-process input for the NN model + # FIXME: to force data values of input_x to be inside range: [low=-127, high=128] + input_x, img = data.transforms.presets.ssd.load_test(im_fname, short=512) + data_inp = {} + data_inp[name] = input_x + return mod, params, data_inp + + +def _get_single_input_keras_model_and_random_data_inp(keras_model, input_info): + """Convert Keras graph to relay.""" + (ishape, dtype, layout) = input_info + name = keras_model.input_names[0] + inputs = {} + inputs[name] = ishape + mod, params = relay.frontend.from_keras(keras_model, inputs, layout=layout) + data_inp = _get_single_random_data_inp(name, ishape, dtype) + return mod, params, data_inp + + +def _exec_unix_cmd(os_cmd, verbose_prefix=""): + mylogger.info(f"Debug: {verbose_prefix}run cmd: {os_cmd}") + os.system(os_cmd) + + +def _aot_json_codegen_and_fp16_run_for_network( + model_name, + mod, + params, + data_inp, + model_verification_info={}, + json_codegen_only=False, +): + """Helper function to build and run a network.""" + mylogger.info(f"\nDebug: in _aot_json_codegen_and_fp16_run_for_network:\n{mod.astext(False)}") + if skip_json_codegen_test(): + return + + my_cwd = os.getcwd() + mylogger.info(f"\nDebug: cwd: {my_cwd}") + working_dir = f"{my_cwd}/test_mrvl_{model_name}/" + _exec_unix_cmd(f"rm -rf {working_dir}") + _exec_unix_cmd(f"mkdir -p {working_dir}") + ( + nodes_json_filename, + consts_json_filename, + mod_mrvl_subgraph, + mod_non_mrvl_subgraph, + mrvl_layers_in_mrvl_subgraph, + mrvl_layers_in_non_mrvl_subgraph, + ) = aot_build_and_json_codegen( + model_name, + working_dir, + mod, + params, + ) + mylogger.info(f"\nDebug: cwd: {model_verification_info}") + verify_json_codegen(nodes_json_filename, model_verification_info=model_verification_info) + + if json_codegen_only: + print("aot json codegen only") + return + + # check whether a Mrvl distribution package has been installed + if skip_aot_runtime_test(): + return + + # TODO(ccjoechou): add final code for fp16-fp32-mixed inf run and then + # uncomment following calls when they become available + # mrvl_subgraph_runtime_model_binary = aot_runtime_gen( + # nodes_json_filename, consts_json_filename, aot_fp16_cmd_opts, + # ) + # mrvl_subgraph_actual_fp16_output = aot_run( + # mrvl_subgraph_runtime_model_binary, aot_run_cmd_opts, inf_inp=data_inp, + # ) + # mrvl_subgraph_golden_fp32_output = tvm_llvm_fp32_run( + # mod_mrvl_subgraph, mrvl_layers_in_mrvl_subgraph, inf_inp=data_inp, + # ) + # verify_mrvl_subgraph_aot_inf_result( + # mrvl_subgraph_actual_fp16_output, + # mrvl_subgraph_golden_fp32_output, + # delta_config, + # ) + # actual_inf_output = tvm_llvm_fp32_run( + # mod_non_mrvl_subgraph, + # mrvl_layers_in_non_mrvl_subgraph, + # inf_inp=mrvl_subgraph_actual_fp16_output, + # ) + # verify_aot_inf_result(actual_inf_output, delta_config) + + +# TODO(ccjoechou): re-enable this test after a Mrvl BYOC bug can be resolved +@pytest.mark.skipif(True, reason="Skip test_relay_resnet18_aot_json_codegen() for now") +def test_relay_resnet18_aot_json_codegen(): + """Mrvl MLIP codegen (to JSON files) with ResNet18 model""" + + def get_model(dtype): + model_name = "resnet18" + ishape = (1, 3, 224, 224) + layout = "NCHW" + input_info = (ishape, layout) + mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1) + name = "data" + data_inp = _get_single_random_data_inp(name, ishape, dtype) + return model_name, mod, params, data_inp + + dtype = "float32" + assert dtype == "float32" + _aot_json_codegen_and_fp16_run_for_network( + *get_model(dtype), + model_verification_info={}, + json_codegen_only=True, + ) + + +def test_ssd_resnet50_aot_json_codegen(): + """Mrvl MLIP codegen (to JSON files) with SSD-ResNet50 model""" + + def get_model(dtype): + try: + from gluoncv import model_zoo + except ImportError: + pytest.skip("Missing Gluoncv Package") + + model_name = "ssd_512_resnet50_v1_voc" + ssd_resnet50 = model_zoo.get_model(model_name, pretrained=True) + short = 512 + ishape = (1, 3, short, short) + layout = "NCHW" + # we will use the street_small.jpg image as the raw input tensor + im_fname = tvm.contrib.download.download_testdata( + "https://github.com/dmlc/web-data/blob/main/" + + "gluoncv/detection/street_small.jpg?raw=true", + "street_small.jpg", + ) + input_info = (ishape, dtype, layout, im_fname, short) + (mod, params, data_inp) = _get_single_input_mxnet_model_and_data_inp( + ssd_resnet50, input_info + ) + return model_name, mod, params, data_inp + + # setup per-test verification info to be checked + model_verification_info = {} + model_verification_info["nodes_size"] = 104 + model_verification_info["heads_size"] = 18 + + dtype = "float32" + assert dtype == "float32" + _aot_json_codegen_and_fp16_run_for_network( + *get_model(dtype), + model_verification_info=model_verification_info, + json_codegen_only=True, + ) + + +# TODO(ccjoechou): re-enable this test after either (1) relay Keras frontend can also support +# data_format = channels_first (currently relay Keras frontend supports +# data_format = channels_last only); or (2) Mrvl BYOC backend can support NHWC as +# the input data format (currently, Mrvl BYOC backend supports only NCHW format) +@pytest.mark.skipif(True, reason="Skip test_mobilenet_aot_json_codegen() for now") +def test_mobilenet_aot_json_codegen(): + """Mrvl MLIP codegen (to JSON files) with MobileNet model""" + + def get_model(dtype): + try: + from tensorflow.python.keras import backend_config + from keras.applications import MobileNetV2 + + backend_config.set_image_data_format("channels_first") + except ImportError: + pytest.skip("Missing keras module") + + mobilenet = MobileNetV2( + include_top=True, weights="imagenet", input_shape=(3, 224, 224), classes=1000 + ) + model_name = "mobilenet" + ishape = (1, 3, 224, 224) + layout = "NCHW" + input_info = (ishape, dtype, layout) + mod, params, data_inp = _get_single_input_keras_model_and_random_data_inp( + mobilenet, input_info + ) + return model_name, mod, params, data_inp + + dtype = "float32" + assert dtype == "float32" + _aot_json_codegen_and_fp16_run_for_network( + *get_model(dtype), + model_verification_info={}, + json_codegen_only=True, + ) + + +if __name__ == "__main__": + if sys.platform == "win32": + print("Skip test on Windows for now") + sys.exit(0) + pytest.main([__file__]) diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index bcd9066b1ba7..291a5f7c8c2a 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -208,7 +208,7 @@ def test_conv2d_attrs(): check_json_roundtrip(out) -def test_large_grpah(): +def test_large_graph(): # Test large graphs to avoid stack overflow in serialize/deserialize size = int(1e5) var = [relay.var("var_" + str(i), shape=(2, 3)) for i in range(size)] @@ -233,4 +233,4 @@ def test_large_grpah(): test_tuple_get_item() test_op() test_conv2d_attrs() - test_large_grpah() + test_large_graph() diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 3a4500d0e2f0..4380c17583a0 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -50,3 +50,4 @@ echo set\(USE_LIBBACKTRACE ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(USE_ETHOSU ON\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake +echo set\(USE_MRVL ON\) >> config.cmake diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 55f5b96db3c3..28bb454cecdd 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -61,6 +61,7 @@ run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module apps/dso_p run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME} tests/python/integration +sudo pip3 install gluoncv # Ignoring Arm(R) Ethos(TM)-U NPU tests in the collective to run to run them in parallel in the next step. run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib --ignore=tests/python/contrib/test_ethosu run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib-test_ethosu tests/python/contrib/test_ethosu -n auto