From eb344fdfde4bc18afa04cbabc104f9fd2f09d06d Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 24 Jan 2022 15:50:40 -0800 Subject: [PATCH 01/49] [MetatSchedule] testcase for TensorRT builder/runner --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py new file mode 100644 index 000000000000..ca38a7d118a7 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc_trt.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ + + +import sys +import pytest +import itertools +import tvm +from tvm import relay +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +import numpy as np +from typing import List, Tuple + +# from tvm import script +# from tvm._ffi import register_func +# from tvm.runtime import Module +from tvm._ffi import register_func +from tvm.relay.testing.init import Initializer +from tvm.target import Target +from tvm.runtime import Module +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) + +from tvm.tir import FloatImm +from tvm.meta_schedule.testing import get_network + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + + # Build + if use_trt: + + def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + return tvm.relay.build_module._build_module_no_factory( + mod, "cuda", "llvm", params + ) + + builder = LocalBuilder(f_build=relay_build_with_tensorrt) + else: + + def relay_build_without_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + # @Sung: Weird. Cannot pass keyword arg + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + + builder = LocalBuilder(f_build=relay_build_without_tensorrt) + + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + + (builder_result,) = builder.build([builder_input]) + assert builder_result.error_msg is None + assert builder_result.artifact_path is not None + + # Run + evaluator_config = EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + + runner_input = RunnerInput( + builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] + ) + + def eval_func(rt_mod, device, evaluator_config, repeated_args): + rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) + + eval = rt_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = eval(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + runner = LocalRunner( + evaluator_config=evaluator_config, + f_run_evaluator=eval_func, + ) + + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.run_secs is not None + assert runner_result.error_msg is None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + + mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) + verify_meta_schedule_with_tensorrt( + mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" + ) + + +# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 03646ac81f0e7ac0b0d672b408a8b5455d69b981 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Mon, 24 Jan 2022 16:40:08 -0800 Subject: [PATCH 02/49] [Runtime][PipelineExecutor] Add Pipeline Executor Interface (#10010) Adding interfaces into Pipeline Executor to "run", "stop","set input", and "get input" from the pipeline executor, In this patch, we also implemented the "BackendRuntime" structure to wrap the graph runtime interface in order to support pipeline executor interface and implement data copy method. This method is used to transfer data between two backend runtimes. --- python/tvm/contrib/pipeline_executor.py | 44 +++++++- src/runtime/pipeline/pipeline_executor.cc | 75 ++++++++++++- src/runtime/pipeline/pipeline_executor.h | 31 ++++++ src/runtime/pipeline/pipeline_scheduler.cc | 12 ++- src/runtime/pipeline/pipeline_scheduler.h | 4 +- src/runtime/pipeline/pipeline_struct.h | 105 +++++++++++++++++++ tests/python/relay/test_pipeline_executor.py | 11 +- 7 files changed, 271 insertions(+), 11 deletions(-) diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index c75aa3dad43b..6e991f0c8d7a 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -115,10 +115,22 @@ def __init__(self, module): else: self.module = module # Get the packed functions from the pipeline executor. - self._get_num_outputs = self.module["get_num_outputs"] - self._get_input_pipeline_map = self.module["get_input_pipeline_map"] self._get_params_group_pipeline_map = self.module["get_params_group_pipeline_map"] + self._run = self.module["run"] + self._stop = self.module["stop"] self._set_param = self.module["set_param"] + self._set_input = self.module["set_input"] + self._get_input = self.module["get_input"] + self._get_num_outputs = self.module["get_num_outputs"] + self._get_input_pipeline_map = self.module["get_input_pipeline_map"] + + def run(self, sync=False): + """Run the pipeline executor.""" + self._run(sync) + + def stop(self): + """Stop the pipeline executor.""" + self._stop() def get_input_pipeline_map(self, name): """Using the "name" to get the corresponding subgraph index and also get the "input name" @@ -145,6 +157,21 @@ def get_params_group_pipeline_map(self, name): """ return self._get_params_group_pipeline_map(name) + def set_input(self, key, value): + """Set the input via input name. + + Parameters + ---------- + key : str + The input name + value : array_like. + The input value + """ + v = self._get_input(key) + if v is None: + raise RuntimeError("Could not find '%s' in pipeline's inputs" % key) + v.copyfrom(value) + def set_params(self, params_group_name, params_data): """Set the parameter group value given the parameter group name. Note that the parameter group name is declared in the pipeline executor config. @@ -163,6 +190,19 @@ def set_params(self, params_group_name, params_data): for key, val in params_data.items(): self._set_param(params_group_name, key, val) + def get_input(self, key): + """Get the input via an input name. + Parameters + ---------- + key : str + The input key + Returns + ------- + data : NDArray + The input data. + """ + return self._get_input(key) + @property def num_outputs(self): """Get the number of outputs. diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 0ca291a2fbbe..30c09514480f 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -58,6 +58,26 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, LOG(FATAL) << "Function only support the parameter name and the key in the form of string"; } }); + } else if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + this->SetInput(args[0].operator String(), args[1]); + } else { + LOG(FATAL) << "Function only support the input name value in the form of string"; + } + }); + } else if (name == "get_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + *rv = this->GetInput(args[0].operator String()); + } else { + LOG(FATAL) << "Function only support the input name value in the form of string"; + } + }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(args[0]); }); + } else if (name == "stop") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Stop(); }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(); @@ -65,6 +85,32 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, return nullptr; } +/*! + * \brief set input to the runtime module. + * \param input_name The input name. + * \param data_in The input data. + */ +void PipelineExecutor::SetInput(std::string input_name, DLTensor* data_in) { + std::pair indexs = this->GetInputIndex(input_name); + if (indexs.first < 0 || indexs.first >= static_cast(runtimes_.size())) { + this->Stop(); + LOG(FATAL) << "input name " << input_name << " not found."; + } + runtimes_[indexs.first]->SetInput(indexs.second, data_in); +} +/*! + * \brief get input from the runtime module. + * \param input_name The input name. + * \return Return the input data for a specific input name. + */ +NDArray PipelineExecutor::GetInput(std::string input_name) { + std::pair indexs = this->GetInputIndex(input_name); + if (indexs.first < 0 || indexs.first >= static_cast(runtimes_.size())) { + this->Stop(); + LOG(FATAL) << "input name " << input_name << " not found."; + } + return runtimes_[indexs.first]->GetInput(indexs.second); +} /*! * \brief Using the global input name to get the index, and also get the input interface name of corresponding subgraph from the input connection configuration. @@ -85,6 +131,20 @@ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) { return param_connection_config[name]; } +/*! + * \brief Run the pipeline executor. + * \param serialized_mode Whether run the pipeline executor in serialized mode. + */ +void PipelineExecutor::Run(bool serialized_mode) { + // TODO(huajsj): Run the pipeline executor. +} +/*! + * \brief Stop the pipeline executor. + */ +void PipelineExecutor::Stop() { + // TODO(huajsj): Stop the pipeline executor. +} + /*! * \brief Use the mod_config information to create a graph runtime list. * \param mod_config The config information that generates by the export library function call. @@ -152,6 +212,16 @@ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_ int module_index = this->GetParamsGroupPipelineMap(param_group_name); // TODO(huajsj): set the parameters into runtime module. } +/*! + * \brief Return the input index and module index for a given input name. + * \param name The input name. + * \return std::pair A pair of module index and the input index. + */ +std::pair PipelineExecutor::GetInputIndex(const std::string& name) { + std::pair index = input_connection_config[name]; + auto gruntime = runtimes_[index.first]; + return std::make_pair(index.first, gruntime->GetInputIndex(index.second)); +} /*! * \brief Initialize the pipeline executor with a list of modules to be pipelined * and config in JSON format. @@ -165,9 +235,10 @@ void PipelineExecutor::Init(const std::vector& modules, const std::strin dmlc::JSONReader reader(&is); this->LoadConfig(&reader); ICHECK(!pipeline_config_.Empty()) << "The pipeline config information is empty."; + num_outputs_ = pipeline_config_.GetGlobalOutputNum(); // Initialize the pipeline function class used for pipeline thread pool management - // and schedule etc. This function returns the number of output. - num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_); + // and schedule etc. This function returns a list of runtime. + runtimes_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_); return; } diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index 6d4c7ba1fa4f..7dc5baf17ee1 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -82,6 +83,18 @@ class TVM_DLL PipelineExecutor : public ModuleNode { * \return Returning a runtime module index. */ int GetParamsGroupPipelineMap(const std::string& name); + /*! + * \brief Use the input name to set the input data of pipeline executor. + * \param input_name The input name. + * \param data_in The input data. + */ + void SetInput(std::string input_name, DLTensor* data_in); + /*! + * \brief Use the input name to get the input data. + * \param input name The input name. + * \return Return input data. + */ + NDArray GetInput(std::string input_name); /*! * \brief Use the parameters group name to get the specific backend runtime then use * the param_key_name to set param data for the said backend runtime. @@ -96,6 +109,22 @@ class TVM_DLL PipelineExecutor : public ModuleNode { * \return The number of outputs. */ int NumOutputs() const { return num_outputs_; } + /*! + * \brief Run the pipeline executor. + * \param serialized_mode Whether run the pipeline executor in serialized mode. + */ + void Run(bool serialized_mode); + /*! + * \brief Stop the pipeline executor. + */ + void Stop(); + /*! + * \brief A pipeline input with a specific name correspond with a input of a specific + * backend module, this function return a module index and a input index in "pair" + * form for a input name. + * return Return a module index and a input index. + */ + std::pair GetInputIndex(const std::string& name); /*!\brief Load the module files information.*/ ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) { reader->BeginArray(); @@ -145,6 +174,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode { ModuleConfig mod_config_; /*!\brief How many outputs are in this pipeline executor.*/ size_t num_outputs_ = 0; + /*!The list of backend runtime module.*/ + std::vector> runtimes_; /*!\brief Json loader.*/ void LoadConfig(dmlc::JSONReader* reader) { reader->BeginObject(); diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc index 67a9795c47d4..499d75784a15 100644 --- a/src/runtime/pipeline/pipeline_scheduler.cc +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -27,11 +27,15 @@ namespace runtime { * \param modules The list of graph executor modules. * \param pipeline_conf The dependency information of each graph executor module. */ -size_t PipelineScheduler::PipelineInit(const std::vector& modules, - const ConfigPipelineExecution& pipeline_config) { +std::vector> PipelineScheduler::PipelineInit( + const std::vector& modules, const ConfigPipelineExecution& pipeline_config) { + std::vector> runtimes; graph_modules_ = modules; - int num_output = pipeline_config.GetGlobalOutputNum(); - return num_output; + for (size_t i = 0; i < graph_modules_.size(); i++) { + auto runItem = std::make_shared(graph_modules_[i], i); + runtimes.push_back(runItem); + } + return runtimes; } } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h index 0572e060a1b8..02c44420bd51 100644 --- a/src/runtime/pipeline/pipeline_scheduler.h +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -41,8 +41,8 @@ class PipelineScheduler { * \param modules The list of graph executor module. * \param pipeline_config The dependency information of each graph executor module. */ - size_t PipelineInit(const std::vector& modules, - const ConfigPipelineExecution& pipeline_config); + std::vector> PipelineInit( + const std::vector& modules, const ConfigPipelineExecution& pipeline_config); private: /*!\brief The list of graph executors.*/ diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index aa831070ccdb..40628e989a90 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -21,12 +21,16 @@ #include #include #include +#include +#include #include #include #include #include #include +namespace tvm { +namespace runtime { /*! * \brief All binding information of a output interface. */ @@ -292,7 +296,106 @@ struct ParamConnectionConfig { } } }; +/* + *\brief Backend Runtime. + */ +class BackendRuntime { + private: + /*\brief The index of runtime indicates the runtime position in the pipeline.*/ + int runtime_idx_; + /*\brief The Runtime module of a backend graph executor.*/ + Module module_; + /*! + *\brief In order to transfer data from one backend runtime to another, we need a local + * tensor variable as a medium. "input_tensor_local_copy_" is a map including + * input data and local tensor vairable. + */ + std::unordered_map input_tensor_local_copy_; + /*!\brief The packed functions.*/ + tvm::runtime::PackedFunc set_input_; + tvm::runtime::PackedFunc get_input_; + tvm::runtime::PackedFunc get_num_output_; + tvm::runtime::PackedFunc get_num_inputs_; + tvm::runtime::PackedFunc get_input_index_; + /*! + * \brief Copying from a given tensor and using 'CPU' as the device. + */ + inline DLTensor* CopyDLTensorToCPU(const DLTensor* from) { + DLTensor* ret = NULL; + TVMArrayAlloc(from->shape, from->ndim, from->dtype.code, from->dtype.bits, from->dtype.lanes, + kDLCPU, 0, &ret); + return ret; + } + /*!\brief Creating a new NDArray with same shape and data type as the given DLTensor.*/ + NDArray CreateNDArrayFromDLTensor(const DLTensor* from) { + std::vector shape; + for (int i = 0; i < from->ndim; i++) { + shape.push_back(from->shape[i]); + } + auto ndarray = NDArray::Empty(shape, from->dtype, from->device); + ndarray.CreateView(shape, from->dtype); + return ndarray; + } + /* + *\brief Copying data from one DLTensor to another. + */ + void CopyFromTo(DLTensor* from, DLTensor* to) { + // When the 'from' device and the 'to' device are not the same, we use a temporary CPU + // DLTensor as the bridge. + if (from->device.device_type != to->device.device_type && from->device.device_type != kDLCPU && + to->device.device_type != kDLCPU) { + DLTensor* dltensor_local = nullptr; + if (input_tensor_local_copy_.find(to) == input_tensor_local_copy_.end()) { + dltensor_local = CopyDLTensorToCPU(from); + input_tensor_local_copy_[to] = dltensor_local; + } else { + dltensor_local = input_tensor_local_copy_[to]; + } + TVMArrayCopyFromTo(from, dltensor_local, nullptr); + from = dltensor_local; + } + TVMArrayCopyFromTo(from, to, nullptr); + } + + public: + BackendRuntime(Module mod, int mod_idx) { + module_ = mod; + runtime_idx_ = mod_idx; + get_input_index_ = module_.GetFunction("get_input_index"); + get_num_output_ = module_.GetFunction("get_num_outputs"); + get_num_inputs_ = module_.GetFunction("get_num_inputs"); + set_input_ = module_.GetFunction("set_input"); + get_input_ = module_.GetFunction("get_input"); + } + BackendRuntime(void) {} + ~BackendRuntime() { + for (auto data : input_tensor_local_copy_) { + TVMArrayFree(data.second); + } + } + /*!\brief Return the index of the current module.*/ + int GetModuleIndex() { return runtime_idx_; } + /*!\brief Return the number of output*/ + int NumOutputs() const { return get_num_output_(); } + /*!\brief Return the number of input*/ + int NumInputs() const { return get_num_inputs_(); } + /*!\brief Setting the data to this module via input index.*/ + void SetInput(const int index, DLTensor* data_in) { + NDArray input = get_input_(index); + DLTensor* dltensor_input = const_cast(input.operator->()); + CopyFromTo(data_in, dltensor_input); + } + /*!\brief Setting the data to the current runtime moduel via the input name. */ + void SetInput(const std::string name, DLTensor* data_in) { + int index = this->GetInputIndex(name); + SetInput(index, data_in); + } + /*!\brief Getting the input data via the input index.*/ + NDArray GetInput(int index) const { return get_input_(index); } + /*!\bief Getting the input data via the input name.*/ + int GetInputIndex(const std::string& name) { return get_input_index_(name); } +}; /*! * \brief The information used to initialize the graph executor module, the information * come from the export library function call. @@ -309,4 +412,6 @@ struct GraphModuleLoadInfo { }; /*! The Module information of each module.The 'int' is module index. */ using ModuleConfig = std::unordered_map; +}; // namespace runtime +}; // namespace tvm #endif // TVM_RUNTIME_PIPELINE_PIPELINE_STRUCT_H_ diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 83cf237dbfcc..99c24ef93b80 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -292,8 +292,17 @@ def test_pipeline(): assert input_map[0] == "0" and input_map[1] == "data_0" module_index = pipeline_module_test.get_params_group_pipeline_map("param_0") assert module_index == 1 - # Use the parameters group name to set parameters. + # Using the parameters group name to set parameters. pipeline_module_test.set_params("param_0", customized_parameters) + # Getting the result from the pipeline executor + data_a = np.full(dshape, 1).astype("float32") + data_b = np.full(dshape, 2).astype("float32") + pipeline_module_test.set_input("data_a", data_a) + pipeline_module_test.set_input("data_b", data_b) + input_data = pipeline_module_test.get_input("data_b") + tvm.testing.assert_allclose(data_b, input_data.numpy()) + input_data = pipeline_module_test.get_input("data_a") + tvm.testing.assert_allclose(data_a, input_data.numpy()) if __name__ == "__main__": From 29a77e3b1344623b96108e31ce9b04e23839e230 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Jan 2022 09:51:26 +0900 Subject: [PATCH 03/49] [skip ci][Docker, CI] Update DGL installation, temp disable DGL tutorial (#10067) --- docker/install/ubuntu_install_dgl.sh | 2 +- gallery/how_to/work_with_relay/build_gcn.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_dgl.sh b/docker/install/ubuntu_install_dgl.sh index 50c540983aa0..94591ce29e73 100755 --- a/docker/install/ubuntu_install_dgl.sh +++ b/docker/install/ubuntu_install_dgl.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip3 install dgl +pip3 install dgl==v0.7.2 -f https://data.dgl.ai/wheels/repo.html diff --git a/gallery/how_to/work_with_relay/build_gcn.py b/gallery/how_to/work_with_relay/build_gcn.py index 2dcc7ba49b80..4352088026ef 100644 --- a/gallery/how_to/work_with_relay/build_gcn.py +++ b/gallery/how_to/work_with_relay/build_gcn.py @@ -120,6 +120,11 @@ def evaluate(data, logits): """ dataset = "cora" +# Temporary disable running load_dataset(dataset) until the CI issue is resolved +import sys + +sys.exit() + g, data = load_dataset(dataset) num_layers = 1 From 452e168cd2c5797d7b835c0a9febe6caba9d0bb7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 26 Jan 2022 15:42:54 +0900 Subject: [PATCH 04/49] [CUTLASS] Profile only the largest-possible alignment by default (#10036) * introduce profile_all_alignments option * add profile_all_alignment option to API * wip * fixed dynamic case * black * update gen_gemm too * minor improvement * fix * all tests work * add doc * fixed for sm = 75 case * fix typo * remove unused import * profile_all -> find_first_valid * fix --- python/tvm/contrib/cutlass/build.py | 46 ++++++++++++------- python/tvm/contrib/cutlass/gen_conv2d.py | 36 +++++++-------- python/tvm/contrib/cutlass/gen_gemm.py | 50 +++++++++++---------- python/tvm/contrib/cutlass/gen_tensor_op.py | 44 +++++++++++++++--- python/tvm/runtime/module.py | 6 ++- tests/python/contrib/test_cutlass.py | 8 +++- 6 files changed, 124 insertions(+), 66 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index c919ff283343..fb59d02f9450 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -104,7 +104,7 @@ def select_gemm_kernel( arg1_dtype, use_3xtf32, batched, - profile_all, + find_first_valid, use_multiprocessing, ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic @@ -126,10 +126,10 @@ def select_gemm_kernel( arg1_dtype, use_3xtf32, batched=batched, - profile_all=profile_all, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) - if profile_all: + if not find_first_valid: logger.info("The best kernel is %s", name) else: logger.info("Picked the first kernel found %s", name) @@ -146,7 +146,7 @@ def handle_batch_matmul( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for batch_matmul op workload.""" @@ -165,7 +165,7 @@ def handle_batch_matmul( arg1_dtype, use_3xtf32, True, - profile_all, + find_first_valid, use_multiprocessing, ) @@ -191,7 +191,7 @@ def handle_dense( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for dense op workload.""" @@ -210,7 +210,7 @@ def handle_dense( arg1_dtype, use_3xtf32, False, - profile_all, + find_first_valid, use_multiprocessing, ) @@ -237,7 +237,8 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, - profile_all, + profile_all_alignments, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for conv2d op workload.""" @@ -257,10 +258,11 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, - profile_all=profile_all, + profile_all_alignments, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) - if profile_all: + if not find_first_valid: logger.info("The best kernel is %s", name) else: logger.info("Picked the first kernel found %s", name) @@ -272,7 +274,13 @@ def handle_conv2d( def tune_cutlass_kernels( - mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" + mod, + sm, + use_3xtf32=True, + profile_all_alignments=False, + find_first_valid=False, + use_multiprocessing=False, + tmp_dir="./tmp", ): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. @@ -286,7 +294,14 @@ def tune_cutlass_kernels( An integer specifying the compute capability. For example, 75 for Turing and 80 or 86 for Ampere. - profile_all : bool + use_3xtf32 : bool + Wheter or not use slower but very accurate (compared to tf32) 3xtf32 mode for + fp32 inputs on tensorcore. + + profile_all_alignments : bool + When True, profile all kernal variants with smaller alignments than the largest possible. + + find_first_valid : bool Whether or not profile all candidate kernels, or stop profiling after the first applicable kernel is found. @@ -342,7 +357,8 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + profile_all_alignments, + find_first_valid, use_multiprocessing, ) ) @@ -357,7 +373,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ) ) @@ -372,7 +388,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ) ) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index c09017adfd95..b6dba009f2b2 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name """Conv2d kernel generator and profiler for CUTLASS.""" -import re from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler from .conv2d_profiler import Conv2dProfilerEmitter @@ -168,14 +167,6 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32): ) return {"name": name, "opdef": opdef} - def check_align(self, op_name, C, K): - """Filter out kernels that cannot be supported.""" - match = re.match(".*_align([1-9]+)", op_name) - assert match is not None and len(match.groups()) == 1 - # The same alignment is used for all axes - align = int(match.groups()[0]) - return all([dim % align == 0 for dim in [C, K]]) - def select_op( self, d_shape, @@ -187,7 +178,8 @@ def select_op( data_dtype, weight_dtype, use_3xtf32, - profile_all=True, + profile_all_alignments=False, + find_first_valid=False, use_multiprocessing=False, ): """ @@ -216,12 +208,16 @@ def select_op( return self.cache[workload] ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32 + out_dtype, + data_dtype, + weight_dtype, + enumerate_conv2d_operators, + lambda align: all([dim % align == 0 for dim in [IC, OC]]), + use_3xtf32, + profile_all_alignments, ) - ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) - - if profile_all: + if not find_first_valid: self.engine.compile_all(ops, use_multiprocessing) args = ( @@ -232,7 +228,7 @@ def select_op( for op in ops: out = self.engine.evaluate(op, args.split(" ")) op["runtime"] = out - if out < float("inf") and not profile_all: + if out < float("inf") and find_first_valid: self.cache[workload] = op return op @@ -252,11 +248,12 @@ def profile( data_dtype, weight_dtype, use_3xtf32=True, - profile_all=True, + profile_all_alignments=False, + find_first_valid=False, use_multiprocessing=False, ): """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. + If find_first_valid is True, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ op = self.select_op( @@ -269,8 +266,9 @@ def profile( data_dtype, weight_dtype, use_3xtf32, - profile_all=profile_all, - use_multiprocessing=use_multiprocessing, + profile_all_alignments, + find_first_valid, + use_multiprocessing, ) name, opdef = create_conv2d_operator_with_epilogue( diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 445acb9305c8..bb591985cab5 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" -import re from .gemm_operation import GemmOperation, EmitGemmInstance from .gemm_profiler import GemmProfilerEmitter from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP @@ -63,8 +62,9 @@ def create_gemm_operator_with_epilogue( swizzling_functor, ) - return op.procedural_name(), EmitGemmInstance().emit( - op, no_beta_scaling=no_beta_scaling, batched=batched + return ( + op.procedural_name(), + EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched), ) @@ -150,17 +150,6 @@ def __init__(self, sm, cutlass_path, binary_path): self.sm = sm self.cache = {} - def check_align(self, op_name, M, N, K): - """Filter out kernels that cannot be supported.""" - match = re.match(".*_align([1-9]+)", op_name) - assert match is not None and len(match.groups()) == 1 - # The same alignment is used for all axes - align = int(match.groups()[0]) - # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. - # See https://github.com/NVIDIA/cutlass/issues/362. - # When the above issue is resolved, we can remove the alignment check on M below. - return all([dim % align == 0 for dim in [M, N, K]]) - def get_default( self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False ): @@ -168,8 +157,15 @@ def get_default( For now, the default kernel was picked arbitrary. """ ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32 + out_dtype, + arg0_dtype, + arg1_dtype, + enumerate_gemm_operators, + lambda align: align == 1, # Only request align1 kernels + use_3xtf32, + profile_all_alignments=True, # To include all align1 kernels ) + default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] if arg0_dtype == "float32": @@ -200,7 +196,8 @@ def select_op( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all=True, + profile_all_alignments=False, + find_first_valid=False, use_multiprocessing=False, ): """ @@ -211,22 +208,27 @@ def select_op( op = self.cache[(M, N, K)] return op + # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. + # See https://github.com/NVIDIA/cutlass/issues/362. + # When the above issue is resolved, we can remove the alignment check on M below. + ops = GENERATOR_FUNC_TABLE[self.sm]( out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, - use_3xtf32=use_3xtf32, + lambda align: all([dim % align == 0 for dim in [M, N, K]]), + use_3xtf32, + profile_all_alignments=profile_all_alignments, ) - ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) - if profile_all: + if not find_first_valid: self.engine.compile_all(ops, use_multiprocessing) for op in ops: out = self.engine.evaluate(op, [M, N, K]) op["runtime"] = out - if out < float("inf") and not profile_all: + if out < float("inf") and find_first_valid: self.cache[(M, N, K)] = op return op @@ -244,12 +246,13 @@ def profile( arg0_dtype, arg1_dtype, use_3xtf32=True, - profile_all=True, + profile_all_alignments=False, + find_first_valid=False, use_multiprocessing=False, batched=False, ): """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. + If find_first_valid is True, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ op = self.select_op( @@ -260,7 +263,8 @@ def profile( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all=profile_all, + profile_all_alignments=profile_all_alignments, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6bb4f290233e..97af84e76990 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -62,7 +62,9 @@ def generate_tensor_op_common( return ops -def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): +def generate_sm75_tensor_op_1688( + out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False +): """Generate GEMM or Conv2D kernels for Turing.""" assert out_dtype in ["float32", "float16", "int32"] min_cc = 75 @@ -114,6 +116,12 @@ def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): ([64, 64, 64], 2, [2, 2, 1], min_cc, max_cc), ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + assert len(alignment_constraints) > 0 + + if not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + def get_tile_descriptions(math_inst): return [ TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) @@ -125,7 +133,15 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, use_3xtf32=True): +def generate_sm80_tensor_op_16816( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + use_3xtf32=True, + profile_all_alignments=False, +): """Generate GEMM or Conv2D kernels for Ampere.""" min_cc = 80 max_cc = 1024 @@ -218,15 +234,31 @@ def get_tile_descriptions(math_inst): for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + + if len(alignment_constraints) > 0 and not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + if arg0_dtype != "float32" and arg1_dtype != "float32": - sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator) + sm75_kernels = generate_sm75_tensor_op_1688( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + False, + profile_all_alignments, + ) else: # TF32 (float32 + float32 case) is only supported on sm80 sm75_kernels = [] - sm80_kernels = generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions, op_creator - ) + if len(alignment_constraints) > 0: + sm80_kernels = generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator + ) + else: + sm80_kernels = [] # TODO(masahi): For int8 kernels, The CUTLASS generator modifies the output tensor alignment # after ops are created. Revisit how important this modification is. diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index da7c52ad119e..cf2787dda750 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -411,7 +411,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No "c", "cc", "cpp", - ], "The module.format needs to be either c, cc or cpp" + "cu", + ], "The module.format needs to be either c, cc, cpp or cu." object_format = module.format has_c_module = True else: @@ -426,7 +427,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No "c", "cc", "cpp", - ], "The module.format needs to be either c, cc or cpp" + "cu", + ], "The module.format needs to be either c, cc, cpp, or cu." object_format = module.format else: object_format = "c" diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 57f2f39c641b..00506ecf0527 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -188,7 +188,8 @@ def profile_and_build( mod, sm, use_3xtf32=use_3xtf32, - profile_all=False, + profile_all_alignments=False, + find_first_valid=True, use_multiprocessing=False, tmp_dir=tmp_dir, ) @@ -239,6 +240,9 @@ def verify_dense( ): if not has_cutlass(): return + if sm < 80 and data_dtype == "float32": + return + mod = tvm.IRModule.from_expr(func) typ = relay.transform.InferType()(mod)["main"].body.checked_type out_dtype = typ.dtype @@ -450,6 +454,8 @@ def verify_conv2d( ): if not has_cutlass(): return + if sm < 80 and data_dtype == "float32": + return mod_nchw = tvm.IRModule.from_expr(expr_nchw) mod_ref = tvm.IRModule.from_expr(expr_ref) From d2ac944b6487bb56ff905c81136c41986b4f9f50 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 25 Jan 2022 23:45:24 -0800 Subject: [PATCH 05/49] [Meta Schedule] Add `ApplyHisotryBest` Meta Schedule Context (#10049) * Add ApplyHisotryBest. Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng * Retrigger CI. * Update integration.py Co-authored-by: Junru Shao Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Siyuan Feng --- python/tvm/meta_schedule/integration.py | 9 ++- src/meta_schedule/integration.cc | 22 ++++++- src/meta_schedule/utils.h | 16 +++++ .../test_meta_schedule_integration.py | 58 +++++++++++++++++++ 4 files changed, 103 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 47003c6faa25..794591cefed3 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -25,6 +25,7 @@ from tvm.target import Target from tvm.tir import PrimFunc +from .database import Database from . import _ffi_api @@ -174,7 +175,13 @@ def __init__(self) -> None: @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): - pass + """An integration context that allows application of historically best record from database""" + + database: Database + """ The database to be queried from""" + + def __init__(self, database) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member def extract_task( diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index cf4262814947..e9d3012f789d 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace meta_schedule { @@ -112,7 +114,21 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Optional> dispatched) { - throw; + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + ICHECK(HasOnlyOneFunction(mod)) << mod; + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + // Unify func name to make sure it can be found in database + prim_mod = UnifyFuncName(prim_mod); + if (database->HasWorkload(prim_mod)) { + Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); + if (records.size() == 1) { + LOG(INFO) << "Applied history best for " << task_name << "."; + return records[0]->workload->mod; + } + } + return NullOpt; } /**************** FFI ****************/ @@ -146,6 +162,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { return TaskExtraction(); }); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") + .set_body_typed([](Database database) -> ApplyHistoryBest { + return ApplyHistoryBest(database); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index bd76ca794a9a..afeb159052ee 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -351,6 +351,22 @@ inline int GetTargetNumCores(const Target& target) { return num_cores; } +/*! + * \brief Unify the function name in workload to "main". + * \param mod The workload. + * \return The new workload with unified function name. + * \note If the name is not unified, the workload may not be found in database. + */ +inline IRModule UnifyFuncName(const IRModule& mod) { + if (!mod->ContainGlobalVar("main") && mod->GetGlobalTypeVars().size() == 1) { + IRModule new_mod = IRModule( + Map({{GlobalVar("main"), mod->functions[mod->GetGlobalVars()[0]]}})); + return new_mod; + } else { + return mod; + } +} + } // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index f508c7d252e1..bc1d5f268ba0 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -22,10 +22,14 @@ import tvm from tvm import meta_schedule as ms from tvm.ir.module import IRModule +from tvm.tir import Schedule +from tvm.target import Target +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord from tvm.meta_schedule.integration import ( ExtractedTask, MetaScheduleContext, TaskExtraction, + ApplyHistoryBest, ) from tvm.meta_schedule.testing import get_network from tvm.script import tir as T @@ -116,5 +120,59 @@ def test_meta_schedule_integration_extract_from_resnet(): assert len(extracted_tasks) == 30 +def test_meta_schedule_integration_apply_history_best(): + class DummyDatabase(PyDatabase): + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + database = DummyDatabase() + env = ApplyHistoryBest(database) + workload = database.commit_workload(MockModule) + database.commit_tuning_record( + TuningRecord(Schedule(MockModule).trace, [1.0], workload, Target("llvm"), []) + ) + mod = env.query(task_name="mock-task", mod=mod, dispatched=[MockModule]) + assert tvm.ir.structural_equal(mod, workload.mod) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 15e06244b386d846c38a7576e9c3de8cd2fbcb82 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Wed, 26 Jan 2022 02:48:38 -0500 Subject: [PATCH 06/49] [MetaSchedule] Mutator Rule: Mutate Unroll (#10045) * mutate-unroll * mutate-unroll --- python/tvm/meta_schedule/mutator/__init__.py | 1 + .../meta_schedule/mutator/mutate_unroll.py | 31 ++++ src/meta_schedule/mutator/mutate_unroll.cc | 141 ++++++++++++++++++ ...est_meta_schedule_mutator_mutate_unroll.py | 114 ++++++++++++++ 4 files changed, 287 insertions(+) create mode 100644 python/tvm/meta_schedule/mutator/mutate_unroll.py create mode 100644 src/meta_schedule/mutator/mutate_unroll.cc create mode 100644 tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index f232566785d9..85deb7253e86 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -21,3 +21,4 @@ """ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation +from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py new file mode 100644 index 000000000000..f81953d008d4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates auto unroll step""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateUnroll") +class MutateUnroll(Mutator): + """Mutator that mutates auto unroll step""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc new file mode 100644 index 000000000000..e4454184071c --- /dev/null +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -0,0 +1,141 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if an instruction is annotate with + * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit` + * \param inst The instruction to be checked + * \return Whether the instruction is annotated + */ +bool IsAnnotateWithUnroll(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_unroll_explicit || + ann_key == attr::meta_schedule_unroll_implicit; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates auto unroll step */ +class MutateUnrollNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief A candidate to be mutated */ +struct MutateUnrollNode::Candidate { + /*! \brief The sampling instruction to be mutated */ + Instruction inst; + /*! \brief The probability */ + std::vector probs; + /*! \brief The decision made */ + int decision; +}; + +/*! + * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidates The mutation candidate + * \return Whether a decision is found + */ +bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, + MutateUnrollNode::Candidate* candidate) { + using tir::InstructionKind; + using tir::InstructionNode; + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + std::unordered_map sample_insts; + std::vector ann_insts; + sample_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + sample_insts[var_rv] = inst.get(); + } else if (IsAnnotateWithUnroll(inst)) { + ann_insts.push_back(inst.get()); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + ICHECK(sample_insts.count(var_rv)); + const InstructionNode* sample_inst = sample_insts.at(var_rv); + ICHECK_EQ(sample_inst->attrs.size(), 2); + candidate->inst = GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); + return true; +} + +Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { + Candidate candidate; + if (!FindUnrollDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + if (candidate.probs.size() == 0) { + return NullOpt; + } + candidate.probs.erase(candidate.probs.begin() + candidate.decision); + int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + if (result >= candidate.decision) { + result += 1; + } + return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py new file mode 100644 index 000000000000..3f3fbcafc0db --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateUnroll, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + v57 = sch.sample_categorical( + candidates=[0, 16, 64, 512], + probs=[0.25, 0.25, 0.25, 0.25], + decision=0, + ) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.unroll_explicit", ann_val=v57) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateUnroll() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_unroll_matmul(): + mutator = _make_mutator(target=Target("llvm --num-cores=16")) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + results.add(decision) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {1, 2, 3} + + +if __name__ == """__main__""": + test_mutate_unroll_matmul() From f9e1ff86eaab393dd64374154e833dc5182d7312 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 26 Jan 2022 03:00:07 -0500 Subject: [PATCH 07/49] [TIR][Schedule] Blockize and Tensorize (#9871) * WIP * WIP * WIP * test cases * add examples * lint * Amend co-authors information Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou * WIP * address comments and changed tensorized comparator * update * nit * fix example * lint * lint * lint * remove unused * trigger ci * clang-format * fix * rebase Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou --- include/tvm/arith/iter_affine_map.h | 7 + include/tvm/tir/function.h | 52 ++ include/tvm/tir/schedule/schedule.h | 19 + python/tvm/tir/__init__.py | 2 +- python/tvm/tir/function.py | 48 ++ python/tvm/tir/schedule/schedule.py | 229 ++++++ src/arith/int_set.cc | 2 +- src/tir/ir/function.cc | 53 ++ src/tir/schedule/concrete_schedule.cc | 23 + src/tir/schedule/concrete_schedule.h | 3 + src/tir/schedule/ir_comparator.cc | 363 +++++++++ src/tir/schedule/ir_comparator.h | 116 +++ src/tir/schedule/primitive.h | 18 + .../schedule/primitive/blockize_tensorize.cc | 698 ++++++++++++++++++ src/tir/schedule/schedule.cc | 14 + src/tir/schedule/state.cc | 4 +- src/tir/schedule/traced_schedule.cc | 31 + src/tir/schedule/traced_schedule.h | 3 + .../unittest/test_tir_schedule_blockize.py | 210 ++++++ .../unittest/test_tir_schedule_tensorize.py | 431 +++++++++++ 20 files changed, 2321 insertions(+), 5 deletions(-) create mode 100644 src/tir/schedule/ir_comparator.cc create mode 100644 src/tir/schedule/ir_comparator.h create mode 100644 src/tir/schedule/primitive/blockize_tensorize.cc create mode 100644 tests/python/unittest/test_tir_schedule_blockize.py create mode 100644 tests/python/unittest/test_tir_schedule_tensorize.py diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 22b4cd580e18..eb69c188abf3 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -350,6 +350,13 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer, DiagnosticContext diag_ctx); +/*! + * \brief Given an IterMapExpr, transform it to normal PrimExpr. + * \param expr The input IterMapExpr. + * \return The corresponding normal PrimExpr. + */ +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e482a18c4a5b..1ab911b756df 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,58 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! + * \brief Tensor intrinsics for tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc desc; + /*! \brief The function of the implementation for the execution. */ + PrimFunc impl; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("desc", &desc); + v->Visit("impl", &impl); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +/*! + * \brief Managed reference to TensorIntrinNode. + */ +class TensorIntrin : public ObjectRef { + public: + /*! + * \brief Constructor + * \param desc The function to describe the computation. + * \param impl The function of the implementation for the execution. + */ + TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl); + + /*! + * \brief Create and register a TensorIntrin. After registration, the TensorIntrin can be looked + * up with its name. + * \param name The name of the TensorIntrin to register + * \param intrin The TensorIntrin to register. + * \throws This method throws an exception if the TensorIntrin with the specified name already + * exists. + */ + TVM_DLL static void Register(String name, TensorIntrin intrin); + + /*! + * \brief Look up TensorIntrin by name. Raises an exception if not found. + * \param name The name of the TensorIntrin. + * \return The TensorIntrin with the specified name. + * \throws This method throws an exception if the TensorIntrin does not exist. + */ + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 43f2379a0b56..be06b44820cd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -473,6 +473,25 @@ class ScheduleNode : public runtime::Object { */ virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ + /*! + * \brief Convert the subtree rooted at a specific loop into a block. + * \param loop_rv the root of the subtree + * \return the new block + */ + virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with the tensor intrin. + * \param loop_rv The loop to be tensorized + * \param intrin Name of the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with the tensor intrin. + * \param block_rv The block to be tensorized + * \param intrin Name of the tensor intrinsic + */ + virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0; + /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 07ceb29ebf98..5854b9369c16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc +from .function import PrimFunc, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ecbcd837cb72..bcebab9ddc0a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -162,3 +162,51 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: return tvm._ffi.get_global_func("script.AsTVMScript")( self, tir_prefix, show_meta ) # type: ignore + + +@tvm._ffi.register_object("tir.TensorIntrin") +class TensorIntrin(Object): + """A tensor intrinsic. + + Parameters + ---------- + desc : PrimFunc + The function to describe the computation. + + impl : PrimFunc + The function of the implementation for the execution. + """ + + def __init__(self, desc, impl): + self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl) + + @staticmethod + def register(name: str, desc: PrimFunc, impl: PrimFunc): + """Register a tensor intrinsic with its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to register. + desc : PrimFunc + The function to describe the computation. + impl : PrimFunc + The function of the implementation for the execution. + """ + return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl)) # type: ignore + + @staticmethod + def get(name: str): + """Look up a tensor intrinsic by its name. + + Parameters + ---------- + name : str + The name of the TensorIntrin to look up. + + Returns + ------- + result : TensorIntrin + The TensorIntrin with the specified name. + """ + return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7d352f156a31..96fa21f30020 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1759,6 +1759,235 @@ def after_set_scope( ########## Schedule: Blockize & Tensorize ########## + @type_checked + def blockize(self, loop: LoopRV) -> BlockRV: + """Convert the subtree rooted at a specific loop into a block. + + Parameters + ---------- + loop : LoopRV + The root of the subtree. + + Returns + ------- + result : BlockRV + The new block. + + Examples + -------- + + Before blockize, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_blockize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + ) -> None: + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): + with T.block("B"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + Create the schedule and do set_scope: + + .. code-block:: python + + sch = tir.Schedule(before_blockize) + B = sch.get_block("B") + _, _, i1, _ = sch.get_loops(B) + sch.blockize(i1) + print(sch.mod["main"].script()) + + After applying blockize, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_blockize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"] + )-> None: + for i_0, j_0 in T.grid(8, 8): + with T.block("B_o"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vio * 16 + vi, vjo * 16 + vj]) + T.writes(B[vio * 16 + vi, vjo * 16 + vj]) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] \ + * T.float32(2) + + Note + ---- + blockize requires there is exactly one block under the given loop and the bindings of the + block are divisible by the subspace represented by the loops starting at the given loop. + """ + + return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member + + @type_checked + def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None: + """Tensorize the computation enclosed by loop with the tensor intrinsic. + + Parameters + ---------- + block_or_loop : Union[BlockRV, LoopRV] + The loop to be tensorized. + tensor_intrin : str + The tensor intrin or the name of the tensor intrin. + + Examples + -------- + + Before tensorize, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_tensorize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16, 16): + with T.block("update"): + vi = T.axis.spatial(128, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 16 + j_1) + vk = T.axis.reduce(128, k_0 * 16 + k_1) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + Declare and register the tensor intrinsic: + + .. code-block:: python + + @T.prim_func + def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + + @T.prim_func + def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) + + Create the schedule and do tensorize: + + .. code-block:: python + + sch = tir.Schedule(before_tensorize) + update = sch.get_block("update") + _, _, _, i1, _, _ = sch.get_loops(update) + sch.tensorize(i1, "test_mma_intrin") + print(sch.mod["main"].script()) + + After applying tensorize, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_tensorize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + for i_0, j_0, k_0 in T.grid(8, 8, 8): + with T.block("update_o"): + vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0]) + T.reads( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + ) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer( + A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + B_1 = T.match_buffer( + B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + C_1 = T.match_buffer( + C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + [16, 16], + dtype="float32", + offset_factor=1, + ) + with T.init(): + for i_1, j_1 in T.grid(16, 16): + with T.block("update_init"): + vi_init, vj_init = T.axis.remap("SS", [i_1, j_1]) + T.reads() + T.writes(C[vio * 16 + vi_init, vjo * 16 + vj_init]) + C[vio * 16 + vi_init, vjo * 16 + vj_init] = T.float32(0) + T.evaluate( + T.tvm_mma_sync( + C_1.data, + C_1.elem_offset // 256, + A_1.data, + A_1.elem_offset // 256, + B_1.data, + B_1.elem_offset // 256, + C_1.data, + C_1.elem_offset // 256, + dtype="handle", + ) + ) + """ + _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member + self, block_or_loop, tensor_intrin + ) + ########## Schedule: Annotation ########## @type_checked diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 55a1a5a1830e..3d30eef99d7d 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -511,7 +511,7 @@ Range IntSet::CoverRange(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); ICHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::FromMinExtent(s_int->min_value, + return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52ea1..1c34e34468b5 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -64,6 +64,51 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +class TensorIntrinManager { + public: + Map reg; + + static TensorIntrinManager* Global() { + static TensorIntrinManager* inst = new TensorIntrinManager(); + return inst; + } +}; + +TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { + // Check the number of func var is equal + CHECK_EQ(desc->params.size(), impl->params.size()) + << "ValueError: The number of parameters of the description and the implementation of the " + "tensor intrinsic doesn't match."; + for (size_t i = 0; i < desc->params.size(); i++) { + CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the " + "tensor intrinsic should be handle only."; + CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of " + "the tensor intrinsic should be handle only."; + } + ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); + + ObjectPtr n = make_object(); + n->desc = std::move(desc); + n->impl = std::move(impl); + data_ = std::move(n); +} + +void TensorIntrin::Register(String name, TensorIntrin intrin) { + TensorIntrinManager* manager = TensorIntrinManager::Global(); + CHECK_EQ(manager->reg.count(name), 0) + << "ValueError: TensorIntrin '" << name << "' has already been registered"; + manager->reg.Set(name, intrin); +} + +TensorIntrin TensorIntrin::Get(String name) { + const TensorIntrinManager* manager = TensorIntrinManager::Global(); + auto it = manager->reg.find(name); + CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name << "' is not registered"; + return manager->reg.at(name); +} + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // TODO(tvm-team) redirect to Text printer once we have a good text format. @@ -85,5 +130,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc") return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); +TVM_REGISTER_GLOBAL("tir.TensorIntrin") + .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { + return TensorIntrin(desc_func, intrin_func); + }); + +TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9f8dc6dd2daf..fc63f305ff5e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -606,6 +606,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { } /******** Schedule: Blockize & Tensorize ********/ +BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Blockize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); + return CreateRV(result); +} + +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + /******** Schedule: Annotation ********/ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 96cb0f728835..5f108178a83b 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -123,6 +123,9 @@ class ConcreteScheduleNode : public ScheduleNode { int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) override; + void Tensorize(const BlockRV& loop_rv, const String& intrin) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc new file mode 100644 index 000000000000..3e61e953a95b --- /dev/null +++ b/src/tir/schedule/ir_comparator.cc @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./ir_comparator.h" + +namespace tvm { + +namespace tir { + +/******** Tensorize Comparator ********/ + +class TensorIntrinMismatchError : public ScheduleError { + public: + explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt rhs_stmt, + std::vector error_messages) + : lhs_mod_(std::move(lhs_mod)), + lhs_stmt_(std::move(lhs_stmt)), + rhs_stmt_(std::move(rhs_stmt)), + error_messages_(std::move(error_messages)) { + ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); + } + + String FastErrorString() const final { + return "ScheduleError: The stmt doesn't match the tensor intrin."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_; + for (const auto& msg : error_messages_) { + os << msg << std::endl; + } + return os.str(); + } + + IRModule mod() const final { return lhs_mod_; } + + Array LocationsOfInterest() const final { return {lhs_stmt_}; } + + private: + IRModule lhs_mod_; + Stmt lhs_stmt_; + Stmt rhs_stmt_; + std::vector error_messages_; +}; + +/* Override the dispatcher to make sure RHS is always valid */ +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + bool equal = n.same_as(other) || + ((n->type_index() == other->type_index()) && StmtComparator::VisitStmt(n, other)); + if (!equal && assert_mode_ && (n->IsInstance() || n->IsInstance())) { + throw TensorIntrinMismatchError(lhs_mod_, n, other, std::move(error_messages_)); + } + return equal; +} + +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + bool equal = + n.same_as(other) || ((n->type_index() == other->type_index()) && n->dtype == other->dtype && + ExprComparator::VisitExpr(n, other)); + if (!equal && assert_mode_) { + std::ostringstream os; + os << "Expression mismatch: " << n << " vs " << other; + EmitError(os.str()); + } + return equal; +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (op->thread_binding.defined() != rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) { + return false; + } + if (op->kind != rhs->kind) return false; + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false; + return VisitStmt(op->body, rhs->body); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!is_scope_block) { + if (!CompareArray(op->iter_values, rhs->iter_values, &TensorizeComparator::VisitExpr)) { + return false; + } + } + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equality. + // All iter vars and buffer regions including the order should match. + // When checking iter vars, DefEqual is used to remap variables. + if (!is_scope_block) { + if (!CompareArray(op->iter_vars, rhs->iter_vars, &TensorizeComparator::CompareIterVar)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (op->dtype != rhs->dtype) return false; + auto it = equal_map_.find(lhs); + return it != equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value, rhs->true_value) && + VisitExpr(op->false_value, rhs->false_value); +} + +bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + // Otherwise remap lhs to rhs + equal_map_[lhs] = rhs; + analyzer_.Bind(lhs, rhs); + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret(map.begin(), map.end()); + sort(ret.begin(), ret.end()); + return ret; + }; + + std::vector> lhs_array = sort_map(lhs); + std::vector> rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Remap both buffer itself and buffer data, skip buffer shape + equal = + DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } + } + return equal; +} + +bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer; + EmitError(os.str()); + } + return false; + } + int offset = static_cast(lhs->region.size()) - static_cast(rhs->region.size()); + // Number of indices in RHS (desc of the tensor intrinsic) must be smaller than it in LHS + if (offset < 0) return false; + + auto it = buffer_indices_.find(lhs->buffer); + if (it == buffer_indices_.end()) { + // Update base indices for the buffer, this can only happen if it is visiting the scope block. + ICHECK(is_scope_block); + std::vector indices_base; + indices_base.reserve(lhs->region.size()); + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + indices_base.emplace_back(lhs->region[i]->min); + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // save base index + indices_base.emplace_back(lhs->region[i + offset]->min); + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + } + buffer_indices_.emplace(lhs->buffer, std::move(indices_base)); + } else { + // Check the base indices are consistent. + const std::vector& indices_base = it->second; + for (int i = 0; i < offset; i++) { + // High-dim region must be element-wise + if (!is_one(lhs->region[i]->extent)) return false; + if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) return false; + } + for (size_t i = 0; i < rhs->region.size(); i++) { + // check extent match + if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + return false; + } + PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); + if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { + return false; + } + } + } + return true; +} + +// Comparator for BufferStoreNode and BufferLoadNode +template +bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + int offset = static_cast(lhs->indices.size()) - static_cast(rhs->indices.size()); + if (offset < 0) return false; + auto it = buffer_indices_.find(lhs->buffer); + ICHECK(it != buffer_indices_.end()); + const std::vector& indices_base = (*it).second; + ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); + for (size_t i = 0; i < rhs->indices.size(); i++) { + PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; + if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { + if (assert_mode_) { + std::ostringstream os; + os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs " << rhs->indices[i]; + EmitError(os.str()); + } + return false; + } + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar& rhs) { + return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type; +} + +void TensorizeComparator::EmitError(const std::string& error_message) { + error_messages_.push_back(error_message); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h new file mode 100644 index 000000000000..359677d8852f --- /dev/null +++ b/src/tir/schedule/ir_comparator.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ +#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ + +#include +#include +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace tir { + +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +/*! \brief Deep comparison to check if two IR ASTs are equivalent for tensorization*/ +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + /*! + * \brief Constructor of TensorizeComparator + * \param assert_mode Whether to raise an error if the two IR ASTs do not match. + * \param lhs_mod The IRModule of the LHS. This is used for error reporting. + */ + explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) + : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {} + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override; + + /*! \brief Map from RHS buffer to LHS buffer */ + std::unordered_map rhs_buffer_map_; + /*! \brief Base indices of the LHS buffer. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + + protected: + bool DefEqual(const Var& lhs, const Var& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); + void EmitError(const std::string& error_message); + + /*! \brief IRModule of the LHS stmt. */ + IRModule lhs_mod_; + /*! \brief Whether assertion mode is enabled. */ + bool assert_mode_; + /*! \brief Whether it is visiting the scope block (the outermost block). */ + bool is_scope_block = true; + /*! \brief The arithmetic analyzer. */ + arith::Analyzer analyzer_; + /*! \brief Additional error messages. Only used when assert_mode is true. */ + std::vector error_messages_; + // variable remap if any + std::unordered_map equal_map_; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index f0b38af01b5f..2368411e6f09 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -378,6 +378,24 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer const String& storage_scope); /******** Schedule: Blockize & Tensorize ********/ + +/*! + * \brief Convert the subtree rooted at a specific loop into a block. + * \param self The state of the schedule + * \param loop_sref The root of the subtree + * \return The new block + */ +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); + +/*! + * \brief Tensorize the computation enclosed by loop with the tensor intrinsic. + * \param self The state of the schedule + * \param block_or_loop_sref The block or loop to be tensorized. + * \param intrin The tensor intrinsic. + */ +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrin); + /******** Schedule: Annotation ********/ /*! * \brief Annotate a block/loop with a key value pair diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc new file mode 100644 index 000000000000..bbeb9caaab9b --- /dev/null +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -0,0 +1,698 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace + * represented by the outer loops. + */ +class SubspaceNotDivisibleError : public ScheduleError { + public: + explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) + : mod_(std::move(mod)), + scope_loop_(std::move(scope_loop)), + inner_block_(std::move(inner_block)) {} + + String FastErrorString() const final { + return "ScheduleError: The bindings of the inner block can not be blockized."; + } + + String DetailRenderTemplate() const final { + return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " + "starting at {1}."; + } + + IRModule mod() const final { return mod_; } + + Array LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } + + private: + IRModule mod_; + For scope_loop_; + Block inner_block_; +}; + +/*! + * \brief Detect if bindings are a trivial case of the subspace division where we can divide the + * block iter bindings into two categories: + * 1. The binding covers no inner loop vars. + * 2. The binding covers only inner loop vars. + * + * The bindings are not required to be quasi-affine. + * + * \param iter_vars The input iterators + * \param bindings The values of iter_vars + * \param outer_loops Iterators outside the subspace. + * \param inner_loops Iterators of the subspace + * \param predicate The predicate constraint on the input iterators. + * \return The result of the subspace division. + */ +Array> TrivialSubspaceDivision(const Array& iter_vars, + const Array& bindings, + const Array& outer_iters, + const Array& inner_iters, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + Array> res; + std::unordered_set outer_loop_vars; + std::unordered_set inner_loop_vars; + + auto make_uses_var = [](const Array& vars) -> std::function { + std::unordered_set var_set; + var_set.reserve(vars.size()); + for (const Var& var : vars) { + var_set.insert(var.get()); + } + return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&var_set](const VarNode* var) { + return var_set.count(var); // + }); + }; + }; + auto use_outer_loop_vars = make_uses_var(outer_iters); + auto use_inner_loop_vars = make_uses_var(inner_iters); + arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); + + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = use_outer_loop_vars(bindings[i]); + bool inner = use_inner_loop_vars(bindings[i]); + arith::IterMark iter_mark; + if (bindings[i]->IsInstance()) { + iter_mark = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + if (outer && !inner) { + res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark}); + } else if (inner && !outer) { + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark}); + } else if (!outer && !inner) { + res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark}); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + return res; +} + +/*! + * \brief Generate the blockized init block. + * \param block The original block with init. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \return The subtree of the init block and its outer loops. + */ +Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, + const std::vector& inner_loops) { + Array init_block_iters; + Array init_bindings; + const Block& inner_block = inner_block_realize->block; + + // Step 1: Collect data-parallel block iters + for (size_t i = 0; i < inner_block->iter_vars.size(); i++) { + const IterVar& iter_var = inner_block->iter_vars[i]; + const PrimExpr& binding = inner_block_realize->iter_values[i]; + if (iter_var->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) { + init_block_iters.push_back(iter_var); + init_bindings.push_back(binding); + } + } + + // Step 2: Collect loops related to iters of the init block + std::vector init_loops; + for (const ForNode* inner_loop : inner_loops) { + for (const PrimExpr& init_binding : init_bindings) { + if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { + return var == tgt_var; + })) { + init_loops.push_back(inner_loop); + break; + } + } + } + + // Step 3: Create new block iters for the init block + Map subst_map; + for (size_t i = 0; i < init_block_iters.size(); i++) { + IterVar new_iter_var = init_block_iters[i]; + Var old_var = new_iter_var->var; + Var new_var = old_var.copy_with_suffix("_init"); + new_iter_var.CopyOnWrite()->var = new_var; + subst_map.Set(old_var, new_var); + init_block_iters.Set(i, std::move(new_iter_var)); + } + + // Step 4: Generate loop nests and the init block + Stmt new_init = BlockRealize( + /*iter_values=*/init_bindings, + /*predicate=*/inner_block_realize->predicate, + /*block=*/ + Block{/*iter_vars=*/init_block_iters, + /*reads=*/{}, + /*writes=*/block->writes, + /*name_hint=*/block->name_hint + "_init", + /*body=*/block->init.value(), + /*init=*/NullOpt}); + + // Step 5: Generate the parent loops for the init block + for (const ForNode* init_loop : init_loops) { + ObjectPtr new_loop = make_object(*init_loop); + new_loop->loop_var = init_loop->loop_var.copy_with_suffix(""); + subst_map.Set(init_loop->loop_var, new_loop->loop_var); + new_loop->body = std::move(new_init); + new_init = For(new_loop); + } + + // Step 6: Substitute with new loop variables and block iters to prevent duplication of + // variables in the outer block. + new_init = Substitute(new_init, subst_map); + + return new_init; +} + +/*! + * \brief A helper to collect the parent loops of the block. The loops are divided into two groups, + * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the + * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its + * successor loops. It is possible that 'outer_loops' is empty. + */ +class LoopSubspaceCollector { + public: + /*! + * \brief Collect the parent loops of the block and store the result in the corresponding fields. + * \param block_sref The sref to the target block. + * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop. + */ + void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { + bool inner = true; + for (StmtSRefNode* current_sref = block_sref->parent; + current_sref && current_sref->stmt->IsInstance(); + current_sref = current_sref->parent) { + const auto* current_loop = current_sref->StmtAs(); + ICHECK(current_loop); + if (inner) { + inner_loops.push_back(current_loop); + inner_loop_vars.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_loop_vars.push_back(current_loop->loop_var); + } + loop_var_domain.Set(current_loop->loop_var, + Range::FromMinExtent(current_loop->min, current_loop->extent)); + if (current_sref == loop_sref.get()) inner = false; + } + } + /*! \brief Outer loops which are ancestors of the separator. */ + std::vector outer_loops; + /*! \brief Inner loops which are the separator itself or its successors. */ + std::vector inner_loops; + /*! \brief Loop variables of the outer loops. */ + Array outer_loop_vars; + /*! \brief Loop variables of the inner loops. */ + Array inner_loop_vars; + /*! \brief Domain of the loop variables. */ + Map loop_var_domain; +}; + +/*! + * \brief Check the bindings of the block iters can be divided by a subspace collected by the + * collector. + * \param mod The current IR module. + * \param block_realize The block realize to be checked. + * \param collector The collector which has collected the loops of the block. + * \param analyzer The arithmetic analyzer. + * \return The result of the subspace division. + * \throws ScheduleError If the bindings are not divisible by the subspace. + */ +Array> CheckSubspaceDivisible(const IRModule& mod, + const BlockRealize& block_realize, + const LoopSubspaceCollector& collector, + arith::Analyzer* analyzer) { + const Block& block = block_realize->block; + DiagnosticContext diag_ctx(DiagnosticContext::Default(mod)); + + Array> division = + arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain, + collector.inner_loop_vars, block_realize->predicate, + /*require_bijective=*/false, analyzer, diag_ctx); + + if (division.empty()) { + // If we can't do perfect subspace division, check if it is a trivial case of subspace division. + // In this case, we can still blockize. + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, + collector.outer_loop_vars, collector.inner_loop_vars, + block_realize->predicate); + } + if (division.empty()) { + throw SubspaceNotDivisibleError(mod, GetRef(collector.inner_loops.back()), block); + } + return division; +} + +/*! + * \brief The binding extractor to compute the bindings of the outer and the inner blocks after + * blockize. + */ +class BlockizedBindingExtractor { + public: + /*! + * \brief Extract bindings for blockize. + * \param iter_vars The iter vars of the original inner block. + * \param division The result of the subspace division. + */ + void ExtractBindings(const Array& iter_vars, + const Array>& division, arith::Analyzer* analyzer) { + ICHECK_EQ(iter_vars.size() + 1, division.size()); + for (size_t i = 0; i < iter_vars.size(); ++i) { + const IterVar& iter_var = iter_vars[i]; + arith::IterMark outer_mark = division[i][0]; + arith::IterMark inner_mark = division[i][1]; + const auto* outer_binding = + TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); + const auto* inner_binding = + TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); + + // After computing the subspace division, bindings[i] can be written as + // outer_binding * inner_binding->extent + inner_binding + // The outer block will have binding: iter_outer -> outer_binding + // The inner block will have binding: iter_inner -> inner_binding + // The iter in the original block will be substituted with base + iter_inner where + // base == iter_outer * iter_inner_extent + + if (is_one(division[i][1]->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(iter_var); + } else { + // create iter var for the outer block + const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent), + /*var=*/iter_var->var.copy_with_suffix("_o"), + /*iter_type=*/iter_var->iter_type); + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_iter_vars.push_back(outer_var); + PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; + // create iter var for the inner block + IterVar new_iter = iter_var; + auto* new_iter_node = new_iter.CopyOnWrite(); + new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent); + inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom)); + analyzer->Bind(new_iter->var, new_iter->dom); + inner_iter_vars.push_back(new_iter); + inner_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + inner_iter_subst_map.Set(iter_var->var, base + new_iter->var); + } + } + } + Map inner_iter_subst_map; + /*! \brief Iters of the outer block. */ + Array outer_iter_vars; + /*! \brief Iters of the outer block. */ + Array inner_iter_vars; + /*! \brief Binding values of the outer block. */ + Array outer_bindings; + /*! \brief Binding values of the inner block. */ + Array inner_bindings; + /*! \brief The domain of the inner block iters. */ + Map inner_iter_dom_map; +}; + +/*! + * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with + * base + inner_iter and the expressions after substituion will be simplified if possible. + */ +class InnerIterReplacer : public StmtExprMutator { + public: + /*! + * \brief The constructor + * \param subst_map The substitution map of the inner block iters. + * \param analyzer The arithmetic analyzer. + * \param block_sref_reuse The map to save the block reuse information. + */ + InnerIterReplacer(Map subst_map, arith::Analyzer* analyzer, + Map* block_sref_reuse) + : subst_map_(std::move(subst_map)), + analyzer_(analyzer), + block_sref_reuse_(block_sref_reuse) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = subst_map_.find(GetRef(op)); + if (it != subst_map_.end()) { + return (*it).second; + } + return StmtExprMutator::VisitExpr_(op); + } + + PrimExpr VisitExpr(const PrimExpr& op) final { + PrimExpr result = StmtExprMutator::VisitExpr(op); + if (!result.same_as(op)) { + return analyzer_->Simplify(result); + } + return result; + } + + Stmt VisitStmt_(const BlockNode* op) final { + Stmt result = StmtExprMutator::VisitStmt_(op); + if (!result.same_as(GetRef(op))) { + block_sref_reuse_->Set(GetRef(op), Downcast(result)); + } + return result; + } + + private: + Map subst_map_; + arith::Analyzer* analyzer_; + Map* block_sref_reuse_; +}; + +/*! + * \brief Compute the access region of the outer block by relaxing the inner loops. + * \param buffer_region The original buffer region. + * \param The range of the inner loops. + * \return The new buffer region. + */ +BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, + const Map& inner_iter_relaxed_range) { + Array new_region; + new_region.reserve(buffer_region->region.size()); + Array relaxed_int_set = + arith::EvalSet(buffer_region->region, inner_iter_relaxed_range); + ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size()); + for (size_t i = 0; i < buffer_region->region.size(); i++) { + Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]); + new_region.push_back(relaxed_int_set[i].CoverRange(max_range)); + } + return BufferRegion(buffer_region->buffer, std::move(new_region)); +} + +/*! + * \brief Generate the outer block after blockize. + * \param extractor The binding extractor which has extracted the blockized bindings. + * \param block The original inner block. + * \param inner_block_realize The block realize of the inner block after blockize. + * \param inner_loops The inner loops after blockize. + * \param predicate The outer predicate of the subspace division. + * \return The block realize of the outer block after blockize. + */ +BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, + const Block& block, BlockRealize inner_block_realize, + const std::vector& inner_loops, + PrimExpr predicate) { + // Step 1: Generate the init block if needed + Optional new_init = NullOpt; + if (block->init.defined()) { + new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops); + } + + // Step 2: Compute the access regions of the outer block by relaxing the inner loops + Array new_reads = block->reads; + Array new_writes = block->writes; + + auto f_mutate = [&](const BufferRegion& buffer_region) { + return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map); + }; + new_reads.MutateByApply(f_mutate); + new_writes.MutateByApply(f_mutate); + + // Step 3: Generate the body of the outer block. The body of the outer block is the inner block + // realize and its surrounding loops. + Stmt outer_block_body = inner_block_realize; + for (const ForNode* loop : inner_loops) { + ObjectPtr new_loop = make_object(*loop); + new_loop->body = std::move(outer_block_body); + outer_block_body = For(new_loop); + } + + // Step 4: Generate the outer block and block realize. + return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings), + /*predicate=*/std::move(predicate), + /*block=*/ + Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), // + /*reads=*/std::move(new_reads), // + /*writes=*/std::move(new_writes), // + /*name_hint=*/block->name_hint + "_o", // + /*body=*/std::move(outer_block_body), // + /*init=*/std::move(new_init))); +} + +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + arith::Analyzer analyzer; + + // Step 1: Check the loop has a single child BlockRealize on the sref tree. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); + Block block = block_realize->block; + StmtSRef block_sref = self->stmt2ref.at(block.get()); + + // Step 2: Collect loops inside and outside loop_sref. + LoopSubspaceCollector collector; + collector.Collect(block_sref, loop_sref); + + // Step 3: Calculate subspace division for the inner loops. + Array> division = + CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer); + + // Step 4: Generate bindings for the outer block and the inner block based on the result of + // the subspace division. + BlockizedBindingExtractor extractor; + extractor.ExtractBindings(block->iter_vars, division, &analyzer); + const PrimExpr& outer_pred = division.back()[0]->extent; + const PrimExpr& inner_pred = division.back()[1]->extent; + + // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace + // division + Map block_sref_reuse; + InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, + &block_sref_reuse); + Block new_block = Downcast(replacer(block)); + + // Step 6: Generate the inner block. + BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); + inner_block_realize->iter_values = extractor.inner_bindings; + inner_block_realize->predicate = inner_pred; + inner_block_realize->block = new_block; + BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); + inner_block->iter_vars = extractor.inner_iter_vars; + inner_block->init = NullOpt; + block_sref_reuse.Set(block, inner_block_realize->block); + + // Step 6: Generate the outer block. + BlockRealize outer_realize = + GenerateBlockizedOuterBlock(extractor, new_block, GetRef(inner_block_realize), + collector.inner_loops, outer_pred); + // Step 7: Do the actual replacement + self->Replace(loop_sref, outer_realize, block_sref_reuse); + + // Step 8: Update the cached flags + StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); + self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); + self->block_info[scope_root].affine_binding = scope_block_affine_binding; + return outer_block_sref; +} + +/*! + * \brief Update the map from the buffers in the desc to the impl of the tensor + * intrinsic. + * \param intrinsic The tensor intrinsic. + * \param buffer_map The map to be updated. + */ +void RemapTensorIntrinBuffers( + const TensorIntrin& intrinsic, + std::unordered_map* buffer_map) { + ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size()); + for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) { + const Var& lhs_var = intrinsic->desc->params[i]; + const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; + const Var& rhs_var = intrinsic->impl->params[i]; + const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; + (*buffer_map)[rhs_buffer] = lhs_buffer; + } +} + +void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, + const TensorIntrin& intrinsic) { + /*! + * Check: + * - Check buffer binding, including type, alignment, shape and etc. + * - Check the sub AST is equal to the desc function. + * + * Mutate: + * - Blockize the sub AST (please refer blockize for details) + * - Bind buffers + * - Mutate the impl of the tensor intrinsic by replacing its buffers with new + * buffers created via match buffer region. + * - Replace the sub tree with the mutated function. + */ + const BlockRealize& desc_block_realize = Downcast(intrinsic->desc->body); + const BlockRealize& impl_block_realize = Downcast(intrinsic->impl->body); + Block impl_block = impl_block_realize->block; + + // Step 1: Blockize the subtree rooted at the given loop if needed + StmtSRef block_sref{nullptr}; + if (block_or_loop_sref->StmtAs()) { + block_sref = Blockize(self, block_or_loop_sref); + } else { + ICHECK(block_or_loop_sref->StmtAs()); + block_sref = block_or_loop_sref; + } + const BlockRealize& block_realize = GetBlockRealize(self, block_sref); + + // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence + // between buffers in the block and the desc. + TensorizeComparator comparator(self->mod, /*assert_mode=*/true); + comparator.VisitStmt(block_realize, desc_block_realize); + + // Step 3: Find the correspondence between buffers in the current AST and the impl of + // the tensor intrinsic + // Step 3.1: Map from intrinsic func buffer to desc func buffer + std::unordered_map intrin_buffer_map; + RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); + // Step 3.2: Map form intrinsic func buffer to current AST buffer + std::unordered_map buffer_map; + for (const auto& pair : intrin_buffer_map) { + auto it = comparator.rhs_buffer_map_.find(pair.second); + ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second; + buffer_map[pair.first] = it->second; + } + + // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor + // intrin to make them subregions of the buffer in the original IR. + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; + for (const BufferRegion& read : impl_block->reads) { + buffer_region_map.emplace(read->buffer, read->region); + } + for (const BufferRegion& write : impl_block->writes) { + buffer_region_map.emplace(write->buffer, write->region); + } + Array match_buffer_regions; + match_buffer_regions.reserve(intrinsic->impl->params.size()); + for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) { + const auto& param = intrinsic->impl->params[i]; + const auto& buffer = intrinsic->impl->buffer_map.at(param); + const auto& source = buffer_map.at(buffer); + // add the detected base indices to each buffer access region of the tensor intrinsic + Region old_region = buffer_region_map.at(buffer); + const auto& indices_base = comparator.buffer_indices_.at(source); + int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); + ICHECK(offset >= 0); + Region new_region; + new_region.reserve(source->shape.size()); + for (int i = 0; i < offset; i++) { + new_region.push_back(Range::FromMinExtent(indices_base[i], 1)); + } + for (int i = 0; i < static_cast(old_region.size()); i++) { + new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent)); + } + match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region))); + } + + // Step 5: Replace the subtree in the original IR with the tensor intrin impl. + ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); + new_block_ptr->body = impl_block->body; + ICHECK(new_block_ptr->match_buffers.empty()); + new_block_ptr->match_buffers = std::move(match_buffer_regions); + Block new_block(new_block_ptr); + + self->Replace(block_sref, new_block, {{block_realize->block, new_block}}); + + // Step 6: Update the cached flags. + StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_subtree_compact_dataflow=*/false); + self->UpdateScopeBlockInfo(static_cast(scope_root->stmt)->body); +} + +/******** InstructionKind Registration ********/ + +struct BlockizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Blockize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Blockize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("blockize"); + py.Input("loop", loop_rv); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct TensorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Tensorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin) { + if (const auto* block = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(block), intrin); + } else if (const auto* loop = block_or_loop_rv.as()) { + sch->Tensorize(GetRef(loop), intrin); + } else { + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); + } + } + + static String UnpackedAsPython(Array outputs, String block_or_loop_rv, String intrin) { + PythonAPICall py("tensorize"); + py.Input("block_or_loop", block_or_loop_rv); + py.Input("intrin", intrin); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 6e33862c07ca..b466843f9459 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -185,6 +185,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") + .set_body_method(&ScheduleNode::Blockize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") + .set_body_typed([](Schedule self, ObjectRef rv, String intrin) { + if (const auto* block_rv = rv.as()) { + self->Tensorize(GetRef(block_rv), intrin); + } else if (const auto* loop_rv = rv.as()) { + self->Tensorize(GetRef(loop_rv), intrin); + } else { + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + } + }); + /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..3a37f81b5dbc 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -201,9 +201,7 @@ class BlockInfoCollector : private StmtVisitor { bool is_root_block = srefs_.empty(); // Calculate `BlockInfo::scope` Array child_block_srefs = std::move(block_frames_.back()); - BlockInfo& info = - self_->block_info.emplace(scope_root, BlockInfo(BlockScope(child_block_srefs))) - .first->second; + BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); // Set `affine_binding` if (is_root_block) { // If the block doesn't have outer loops and BlockRealize, diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index da7a2641b162..1e2e57eb6eca 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -356,6 +356,37 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{new_block})); + return new_block; +} + +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{intrin}, + /*outputs=*/{})); +} + +void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) { + ConcreteScheduleNode::Tensorize(block_rv, intrin); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{intrin}, + /*outputs=*/{})); +} + /******** Schedule: Annotation ********/ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b35f1b6e17bb..3a88e869d309 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -87,6 +87,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) final; + void Tensorize(const BlockRV& block_rv, const String& intrin) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py new file mode 100644 index 000000000000..b4a16a8231b8 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +from tvm.script import tir as T +from tvm import tir +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@T.prim_func +def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + +@T.prim_func +def single_elementwise_blockized1( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + with T.block("blockized_B"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + T.reads(A[0:128, 0:128]) + T.writes(B[0:128, 0:128]) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + +@T.prim_func +def single_elementwise_blockized2( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] +) -> None: + for i in T.serial(128): + with T.block("blockized_B"): + vi = T.axis.spatial(128, i) + vjo = T.axis.spatial(1, 0) + T.reads(A[vi, 0:128]) + T.writes(B[vi, 0:128]) + for j in T.serial(128): + with T.block("B"): + vj = T.axis.remap("S", [j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + + +@T.prim_func +def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def two_elementwise_blockized( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"] +) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0 in T.grid(8, 8): + with T.block("blockized_B"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vio * 16 + vi, vjo * 16 + vj]) + T.writes(B[vio * 16 + vi, vjo * 16 + vj]) + B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2) + with T.block("blockized_C"): + vio, vjo = T.axis.remap("SS", [i_0, j_0]) + T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + for ax0, ax1 in T.grid(16, 16): + with T.block("C"): + vi, vj = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[vio * 16 + vi, vjo * 16 + vj]) + T.writes(C[vio * 16 + vi, vjo * 16 + vj]) + C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1) + + +@T.prim_func +def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: + for k, i in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = B[vi] + A[vi, vk] + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +def test_blockize_outer(): + func = single_elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + x, y = s.get_loops(B) + s.blockize(x) + print(s.mod['main'].script()) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_inner(): + func = single_elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + x, y = s.get_loops(B) + s.blockize(y) + tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_two_elementwise_blockize_reverse_compute_at(): + func = two_elementwise + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.reverse_compute_at(C, yo) + s.blockize(s.get_loops(C)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_two_elementwise_blockize_compute_at(): + func = two_elementwise + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(C) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.compute_at(B, yo) + s.blockize(s.get_loops(B)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_init_loops(): + s = tir.Schedule(rowsum, debug_mask="all") + k, _ = s.get_loops(s.get_block("B")) + s.blockize(k) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + verify_trace_roundtrip(sch=s, mod=rowsum) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py new file mode 100644 index 000000000000..401a39f379b7 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +@T.prim_func +def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16]) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, ()) + + with T.block("root"): + T.reads(C[()], A[0 : 4], B[0 : 4]) + T.writes(C[()]) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.remap("R", [i]) + C[()] = C[()] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), offset_factor=1) + B = T.match_buffer(b, (4,), offset_factor=1) + C = T.match_buffer(c, (), offset_factor=1) + + with T.block("root"): + T.reads(C[()], A[0 : 4], B[0 : 4]) + T.writes(C[()]) + T.evaluate( + T.call_extern( + "vec4add", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) + + with T.block("root"): + T.reads( + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], + ) + T.writes(C[0 : 16, 0 : 16]) + for i, j in T.grid(16, 16): + with T.block("update"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0] + + +@T.prim_func +def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 1), offset_factor=1) + B = T.match_buffer(b, (16, 1), offset_factor=1) + C = T.match_buffer(c, (16, 16), offset_factor=1) + + with T.block("root"): + T.reads( + C[0 : 16, 0 : 16], + A[0 : 16, 0 : 1], + B[0 : 16, 0 : 1], + ) + T.writes(C[0 : 16, 0 : 16]) + T.evaluate( + T.call_extern( + "outer_product", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def matmul( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], +) -> None: + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + + for i_outer, j_outer in T.grid(8, 8): + for i_inner_init, j_inner_init in T.grid(16, 16): + with T.block("init"): + vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) + vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) + C[vi_init, vj_init] = T.float32(0) + for k_outer in T.grid(8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads( + [ + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ] + ) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + [16, 16], + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + [16, 16], + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def batch_matmul( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def tensorized_batch_matmul_mma( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n in range(0, 16): + for i, j, k in T.grid(8, 8, 8): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + T.reads( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + ) + T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A_elem_offset = T.var("int32") + B_elem_offset = T.var("int32") + C_elem_offset = T.var("int32") + A_sub = T.match_buffer( + A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=A_elem_offset, + ) + B_sub = T.match_buffer( + B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], + (16, 16), + elem_offset=B_elem_offset, + ) + C_sub = T.match_buffer( + C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], + (16, 16), + elem_offset=C_elem_offset, + ) + T.evaluate( + T.tvm_mma_sync( + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + A_sub.data, + T.floordiv(A_sub.elem_offset, 256), + B_sub.data, + T.floordiv(B_sub.elem_offset, 256), + C_sub.data, + T.floordiv(C_sub.elem_offset, 256), + dtype="handle", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_dot_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i, j, k_0 in T.grid(16, 128, 128, 32): + with T.block("blockized_update"): + vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0]) + T.reads( + C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4] + ) + T.writes(C[vn, vi, vj]) + A_1 = T.match_buffer( + A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + B_1 = T.match_buffer( + B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1 + ) + C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", offset_factor=1) + T.evaluate( + T.call_extern( + "vec4add", + C_1.data, + C_1.elem_offset, + A_1.data, + A_1.elem_offset, + B_1.data, + B_1.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def tensorized_batch_matmul_outer_product( + A: T.Buffer[(16, 128, 128), "float32"], + B: T.Buffer[(16, 128, 128), "float32"], + C: T.Buffer[(16, 128, 128), "float32"], +) -> None: + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + T.reads() + T.writes(C[vn, vi, vj]) + C[vn, vi, vj] = T.float32(0) + for n, i_0, j_0, k in T.grid(16, 8, 8, 128): + with T.block("blockized_update"): + vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k]) + T.reads( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], + A[vn, vio * 16 : vio * 16 + 16, vk], + B[vn, vjo * 16 : vjo * 16 + 16, vk], + ) + T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) + A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1) + B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1 + ) + C_1 = T.match_buffer( + C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32", offset_factor=1 + ) + T.evaluate( + T.call_extern("outer_product", C_1.data, C_1.elem_offset, A_1.data, A_1.elem_offset, + B_1.data, B_1.elem_offset, dtype="int32" + ) + ) + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) +tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin) +tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin) + + +def test_tensorize_matmul(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_matmul, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_batch_matmul(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + _, i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=batch_matmul) + + +def test_tensorize_dot_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, _, _, k = s.get_loops(C) + _, ki = s.split(k, factors=[None, 4]) + s.tensorize(ki, "test_dot_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_tensorize_outer_product(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, i, j, k = s.get_loops(C) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + s.reorder(io, jo, k, ii, ji) + s.tensorize(ii, "test_outer_product_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From fa00dc6e0fdcef02ef40c3ff635550e90f9b2523 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Wed, 26 Jan 2022 02:30:17 -0800 Subject: [PATCH 08/49] [microTVM][tutorial] Add ENV variable to enable testing on physical hardware (#9993) * Add env variable to micro tflite tutorial * Address @gromero comments * address @areusch comment * fix scope * trigger * trigger --- .../how_to/work_with_microtvm/micro_tflite.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/gallery/how_to/work_with_microtvm/micro_tflite.py b/gallery/how_to/work_with_microtvm/micro_tflite.py index bd70fc581c5c..3d871ba783ad 100644 --- a/gallery/how_to/work_with_microtvm/micro_tflite.py +++ b/gallery/how_to/work_with_microtvm/micro_tflite.py @@ -123,12 +123,18 @@ # directory into a buffer import os +import json +import tarfile +import pathlib +import tempfile import numpy as np import tvm -from tvm.contrib.download import download_testdata from tvm import relay +import tvm.contrib.utils +from tvm.contrib.download import download_testdata +use_physical_hw = bool(os.getenv("TVM_MICRO_USE_HW")) model_url = "https://people.linaro.org/~tom.gall/sine_model.tflite" model_file = "sine_model.tflite" model_path = download_testdata(model_url, model_file, module="data") @@ -181,7 +187,7 @@ # RUNTIME = tvm.relay.backend.Runtime("crt", {"system-lib": True}) TARGET = tvm.target.target.micro("host") -BOARD = "qemu_x86" + # # Compiling for physical hardware # When running on physical hardware, choose a TARGET and a BOARD that describe the hardware. The @@ -190,8 +196,15 @@ # board but a couple of wirings and configs differ, it's necessary to select the "stm32f746g_disco" # board to generated the right firmware image. # -# TARGET = tvm.target.target.micro("stm32f746xx") -# BOARD = "nucleo_f746zg" # or "stm32f746g_disco#" + +if use_physical_hw: + boards_file = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) / "boards.json" + with open(boards_file) as f: + boards = json.load(f) + + BOARD = os.getenv("TVM_MICRO_BOARD", default="nucleo_f746zg") + TARGET = tvm.target.target.micro(boards[BOARD]["model"]) + # # For some boards, Zephyr runs them emulated by default, using QEMU. For example, below is the # TARGET and BOARD used to build a microTVM firmware for the mps2-an521 board. Since that board @@ -237,15 +250,12 @@ # (:doc:`Model Library Format` `). This is a tarball with a standard layout: # Get a temporary path where we can store the tarball (since this is running as a tutorial). -import tempfile fd, model_library_format_tar_path = tempfile.mkstemp() os.close(fd) os.unlink(model_library_format_tar_path) tvm.micro.export_model_library_format(module, model_library_format_tar_path) -import tarfile - with tarfile.open(model_library_format_tar_path, "r:*") as tar_f: print("\n".join(f" - {m.name}" for m in tar_f.getmembers())) @@ -264,9 +274,6 @@ # this lives in a file ``microtvm_api_server.py`` in the root directory). Let's use the example ``host`` # project in this tutorial, which simulates the device using a POSIX subprocess and pipes: -import subprocess -import pathlib - template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt")) project_options = {} # You can use options to provide platform-specific options through TVM. @@ -275,11 +282,12 @@ # For physical hardware, you can try out the Zephyr platform by using a different template project # and options: # -# template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) -# project_options = {"project_type": "host_driven", zephyr_board": "nucleo_f746zg"}} + +if use_physical_hw: + template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) + project_options = {"project_type": "host_driven", "zephyr_board": BOARD} # Create a temporary directory -import tvm.contrib.utils temp_dir = tvm.contrib.utils.tempdir() generated_project_dir = temp_dir / "generated-project" From c90311ae333f7037bb91f78e41cdd3543febb9fc Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Wed, 26 Jan 2022 12:19:28 +0000 Subject: [PATCH 09/49] [microNPU] Refactor base address determination to codegen (#9929) This commit introduces BaseAddress ObjectRef to determine base addresses in the codegen for microNPU. This is required when multiple memory pools become available. Thus, base addresses could not be statically determined in the source module. --- .../relay/backend/contrib/ethosu/codegen.py | 12 +-- .../contrib/ethosu/tir_to_cs_translator.py | 62 ++++++++++++- .../tvm/relay/backend/contrib/ethosu/util.py | 39 ++++++-- .../backend/contrib/ethosu/source_module.cc | 93 +++++++++---------- src/relay/backend/contrib/ethosu/utils.cc | 49 ++++++---- src/relay/backend/contrib/ethosu/utils.h | 86 ++++++++++++----- src/tir/transforms/make_unpacked_api.cc | 5 + .../test_ethosu/test_tir_to_cs_translator.py | 32 ++++--- 8 files changed, 258 insertions(+), 120 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 7666691aa19f..98ee41f428b2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -297,8 +297,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: This returns the scheduled PrimFunc """ assert len(ext_func.params) == 1 - input_size = util.calculate_size_bytes(ext_func.params[0]) - output_size = util.calculate_size_bytes(ext_func.body) mod = tvm.IRModule() mod["main"] = ext_func mod = LegalizeEthosU()(mod) @@ -317,8 +315,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) primfunc = primfunc.with_attr("ethos-u.constants", const_dict) - primfunc = primfunc.with_attr("ethos-u.input_size", input_size) - primfunc = primfunc.with_attr("ethos-u.output_size", output_size) return primfunc @@ -342,8 +338,6 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact """ symbol = str(primfunc.attrs["global_symbol"]) const_dict = primfunc.attrs["ethos-u.constants"] - input_size = primfunc.attrs["ethos-u.input_size"] - output_size = primfunc.attrs["ethos-u.output_size"] tir_mod = tvm.IRModule() tir_mod[symbol] = primfunc @@ -351,9 +345,7 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact for idx in const_dict.keys(): const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() - cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate( + cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate( tir_mod, const_dict_with_int_keys ) - return util.CompilationArtifact( - cmms, encoded_constants, scratch_size, input_size, output_size, symbol - ) + return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 77fbc3e8628d..d7254511ebfc 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -18,7 +18,7 @@ the Relay to TIR compilation process, to Vela API calls to generate command stream. """ -from typing import Dict, NamedTuple, Tuple, Union +from typing import Dict, NamedTuple, Tuple, Union, List from enum import auto from enum import Enum import numpy as np # type: ignore @@ -102,8 +102,8 @@ def translate(tir_module, params): encoded_constants : str An hex string of the bytes that includes concat'd encoded weights, encoded biases and scales. - scratch_size : int - The size of the scratch buffer needed. + base_addresses : List[util.BaseAddress] + base addresses to be used by the driver """ buffer_info = extract_buffer_info(tir_module, params) @@ -112,10 +112,60 @@ def translate(tir_module, params): for call_extern in call_extern_list: _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops) + base_addresses = extract_param_base_addresses(tir_module, buffer_info) + if scratch_size > 0: + base_addresses.append( + util.BaseAddress( + "scratch", + None, + _REGION_MAP[BufferType.scratch], + scratch_size, + True, + ) + ) target_accel_config = vela_api.get_accelerator_config() cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) payload = vapi.npu_create_driver_payload(cmds, target_accel_config) - return payload.hex(), constant_data, scratch_size + return payload.hex(), constant_data, base_addresses + + +def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]: + """This function extracts base addresses to be used by the driver + + Parameters + ---------- + mod : tvm.IRModule + The TIR Module for NPU + buffer_info : Dict[tvm.tir.Var, BufferInfo] + Information regarding buffer vars used in the PrimFunc + + Returns + ------- + List[util.BaseAddress] + base addresses to be used by the driver + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + base_addresses = list() + idx = 0 + for param in primfunc.params: + # constants are pooled together and handled specially + # this will change after tir.allocate_const. + # For now, we are skipping generating buffer addresses here + if buffer_info[param].btype == BufferType.constant: + continue + buffer = primfunc.buffer_map[param] + dtype = buffer.dtype + element_size_bytes = np.iinfo(dtype).bits // 8 + size_bytes = element_size_bytes * np.prod(list(buffer.shape)) + base_addresses.append( + util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes) + ) + idx += 1 + + return base_addresses def extract_call_extern_list(mod): @@ -171,6 +221,7 @@ def extract_buffer_info( # There should only be a single function assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] + for idx, const_data in param_dict.items(): param = primfunc.params[idx] buffer_info[param] = BufferInfo( @@ -301,6 +352,9 @@ def classify_io(buffer): assert buffer_type in (BufferType.input, BufferType.output) address = 0 buffer_addresses[_buffer] = (address, buffer_type) + buffer_info[_buffer] = BufferInfo( + values=None, shape=info.dtype, dtype=info.dtype, btype=buffer_type + ) elif info.btype == BufferType.shram: accl_config = util.get_accelerator_config() arch_config = get_accelerator_arch_config(accl_config) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 21b0ecf789d2..fcc8e9e9df30 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -23,7 +23,7 @@ from inspect import signature from enum import Enum -from typing import Union, Tuple +from typing import Union, Tuple, List import numpy as np # type: ignore import tvm # type: ignore @@ -239,6 +239,31 @@ def calculate_size_bytes(expr): return element_size * elements +@register_object("relay.ext.ethos-u.BaseAddress") +class BaseAddress(Object): + """ + This is a structure to hold base addresses for pointers + provided for the driver. + """ + + def __init__( + self, + name: str, + primfunc_param_idx: int, + region: int, + size: int, + is_runtime_allocation: bool = False, + ): + self.__init_handle_by_constructor__( + _ffi_api.BaseAddress, # type: ignore # pylint: disable=no-member + name, + primfunc_param_idx, + region, + size, + is_runtime_allocation, + ) + + @register_object("relay.ext.ethos-u.CompilationArtifact") class CompilationArtifact(Object): """ @@ -248,19 +273,15 @@ class CompilationArtifact(Object): def __init__( self, + function_name: str, command_stream: str, encoded_constants: str, - scratch_size: int, - input_size: int, - output_size: int, - function_name: str, + base_addresses: List[BaseAddress], ): self.__init_handle_by_constructor__( _ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member + function_name, command_stream, encoded_constants, - scratch_size, - input_size, - output_size, - function_name, + base_addresses, ) diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index 66955f8b201f..7d25505ab59c 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -124,7 +124,9 @@ class EthosUModuleNode : public ModuleNode { private: std::string c_source; Array compilation_artifacts_; + Map pool_var_names_; int indent_{0}; + constexpr static int kMaxBaseAddresses_ = 6; /*! * \brief Convert the raw string of hex values into a hex string @@ -150,7 +152,7 @@ class EthosUModuleNode : public ModuleNode { * \return string of code that updates the base_addrs array with the base address of the given * array */ - std::string SetBaseAddress(int index, std::string name, std::string size) { + std::string SetBaseAddress(int index, std::string name, int size) { std::stringstream ss; ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; ss << " base_addrs_size[" << index << "] = " << size << ";\n"; @@ -178,11 +180,24 @@ class EthosUModuleNode : public ModuleNode { } /*! - * \brief Creates a runtime function header + * \brief Creates a runtime function signature */ - void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) { - ss << "TVM_DLL int32_t "; - ss << func_name << "(void* input, void* output, void* resource_handle) {\n"; + void PrintRuntimeFunctionSignature(std::stringstream& ss, + const relay::contrib::ethosu::CompilationArtifact& artifact, + std::string func_name) { + ss << "TVM_DLL int32_t " << func_name; + ss << "("; + std::unordered_map param_idx_to_base_address; + for (const relay::contrib::ethosu::BaseAddress& base_address : artifact->base_addresses) { + if (base_address->primfunc_param_idx.defined()) { + param_idx_to_base_address[base_address->primfunc_param_idx] = base_address; + } + } + for (unsigned int i = 0; i < param_idx_to_base_address.size(); i++) { + relay::contrib::ethosu::BaseAddress base_address = param_idx_to_base_address[i]; + ss << "void* " << base_address->name << ","; + } + ss << "void* resource_handle) {\n"; } /*! @@ -216,7 +231,6 @@ class EthosUModuleNode : public ModuleNode { std::stringstream ss; size_t weights_size = (compilation_artifact->encoded_constants.size() / 2); - size_t scratch_size = compilation_artifact->scratch_size; ss << "// Update linker script to place .rodata.tvm in memory that can be accessed by the " "NPU\n"; if (weights_size > 0) { @@ -234,61 +248,44 @@ class EthosUModuleNode : public ModuleNode { ss << "\n"; PrintExternCPrefix(ss); - ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " - << "size_t in0_size, int8_t* out0, size_t out0_size, void* resource_handle) {\n"; - ss << " int num_tensors = 5;\n"; + PrintRuntimeFunctionSignature(ss, compilation_artifact, func_no_dashes); ss << " void* cms_data = (void*)(" << func_no_dashes << "_cms_data_data);\n"; ss << " int64_t device_type = kDLCPU;\n"; ss << " int64_t device_id = 0;\n"; - ss << " const size_t weights_size = " << std::to_string(weights_size) << ";\n"; - ss << " const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; ss << " const size_t cms_data_size = sizeof(" << func_no_dashes << "_cms_data_data);\n"; - if (scratch_size > 0) { - ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " - "(uint64_t)scratch_size, 0, 16);\n"; - } else { - ss << " int8_t* scratch = NULL;\n"; - } - ss << " size_t base_addrs_size[num_tensors];\n"; - ss << " uint64_t base_addrs[num_tensors];\n"; + ss << " size_t base_addrs_size[" << kMaxBaseAddresses_ << "] = {0};\n"; + ss << " uint64_t base_addrs[" << kMaxBaseAddresses_ << "] = {0};\n"; ss << "\n"; - ss << SetBaseAddress(0, func_no_dashes + "_weights", "weights_size"); - ss << SetBaseAddress(1, "scratch", "scratch_size"); - ss << SetBaseAddress(2, "scratch", "scratch_size"); - ss << SetBaseAddress(3, "in0", "in0_size"); - ss << SetBaseAddress(4, "out0", "out0_size"); + + ss << SetBaseAddress(0, func_no_dashes + "_weights", weights_size); + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " int8_t* " << base_address->name + << " = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " + "(uint64_t)" + << base_address->size << ", 0, 16);\n"; + } + ss << SetBaseAddress(base_address->region->value, base_address->name.c_str(), + base_address->size->value); + } ss << "\n"; + ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data, cms_data_size, " - "base_addrs, base_addrs_size, num_tensors);\n"; - if (scratch_size > 0) { - ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n"; + "base_addrs, base_addrs_size, " + << kMaxBaseAddresses_ << ");\n"; + + for (const relay::contrib::ethosu::BaseAddress& base_address : + compilation_artifact->base_addresses) { + if (base_address->is_runtime_allocation) { + ss << " TVMBackendFreeWorkspace(device_type, device_id, " << base_address->name << ");\n"; + } } ss << " return result;\n"; ss << "}\n"; ss << "\n"; PrintExternCPostfix(ss); ss << "\n"; - PrintExternCPrefix(ss); - ss << "// Wrapper function is provided to allow for easier debugging\n"; - ss << "inline static int32_t " + func_no_dashes + - "_wrapper_(void* input, void* output, void* resource_handle) {\n"; - ss << " size_t input_data_size = " << compilation_artifact->input_size << ";\n"; - ss << " size_t output_data_size = " << compilation_artifact->output_size << ";\n"; - ss << " return " + func_no_dashes + - "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size, " + - "resource_handle);\n"; - ss << "}\n"; - PrintExternCPostfix(ss); - ss << "\n"; - PrintExternCPrefix(ss); - PrintRuntimeFunctionHeader(ss, func_no_dashes); - EnterScope(); - PrintIndents(ss); - ss << "return " << func_no_dashes << "_wrapper_(input, output, resource_handle);\n"; - ExitScope(); - ss << "}\n"; - PrintExternCPostfix(ss); - return ss.str(); } }; diff --git a/src/relay/backend/contrib/ethosu/utils.cc b/src/relay/backend/contrib/ethosu/utils.cc index 7e6c1c2ac840..01bd4d10324d 100644 --- a/src/relay/backend/contrib/ethosu/utils.cc +++ b/src/relay/backend/contrib/ethosu/utils.cc @@ -36,37 +36,54 @@ namespace relay { namespace contrib { namespace ethosu { -CompilationArtifact::CompilationArtifact(String command_stream, String encoded_constants, - Integer scratch_size, Integer input_size, - Integer output_size, String function_name) { +BaseAddress::BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + auto base_address_node = make_object(); + base_address_node->name = name; + base_address_node->primfunc_param_idx = primfunc_param_idx; + base_address_node->region = region; + base_address_node->size = size; + base_address_node->is_runtime_allocation = is_runtime_allocation; + data_ = std::move(base_address_node); +} + +TVM_REGISTER_NODE_TYPE(BaseAddressNode); +TVM_REGISTER_GLOBAL("relay.ext.ethos-u.BaseAddress") + .set_body_typed([](String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation) { + if (is_runtime_allocation.defined()) { + return BaseAddress(name, primfunc_param_idx, region, size, is_runtime_allocation); + } else { + return BaseAddress(name, primfunc_param_idx, region, size); + } + }); + +CompilationArtifact::CompilationArtifact(String function_name, String command_stream, + String encoded_constants, + Array base_addresses) { auto compilation_artifact_node = make_object(); + compilation_artifact_node->function_name = function_name; compilation_artifact_node->command_stream = command_stream; compilation_artifact_node->encoded_constants = encoded_constants; - compilation_artifact_node->scratch_size = scratch_size; - compilation_artifact_node->input_size = input_size; - compilation_artifact_node->output_size = output_size; - compilation_artifact_node->function_name = function_name; + compilation_artifact_node->base_addresses = base_addresses; data_ = std::move(compilation_artifact_node); } TVM_REGISTER_NODE_TYPE(CompilationArtifactNode); TVM_REGISTER_GLOBAL("relay.ext.ethos-u.CompilationArtifact") - .set_body_typed([](String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name) { - return CompilationArtifact(command_stream, encoded_constants, scratch_size, input_size, - output_size, function_name); + .set_body_typed([](String function_name, String command_stream, String encoded_constants, + Array base_addresses) { + return CompilationArtifact(function_name, command_stream, encoded_constants, base_addresses); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "CompilationArtifactNode(\n" - << "command_stream=" << node->command_stream + << "function_name=" << node->function_name + << ",\n command_stream=" << node->command_stream << ",\n encoded_constants=" << node->encoded_constants - << ",\n scratch_size=" << node->scratch_size - << ",\n input_size=" << node->input_size - << ",\n output_size=" << node->output_size - << ",\n function_name=" << node->function_name << ")"; + << ",\n base_addresses=" << node->base_addresses << ")"; }); } // namespace ethosu diff --git a/src/relay/backend/contrib/ethosu/utils.h b/src/relay/backend/contrib/ethosu/utils.h index 5e9e337c3f17..5c61271d3425 100644 --- a/src/relay/backend/contrib/ethosu/utils.h +++ b/src/relay/backend/contrib/ethosu/utils.h @@ -34,47 +34,91 @@ namespace relay { namespace contrib { namespace ethosu { +/*! + * \brief Base addresses are input pointers to + * the driver that get accessed by the command stream + * using offsets to read/write data. + */ +struct BaseAddressNode : public Object { + /*! \brief The identifier, usually it the param name of the PrimFunc that gets lowered */ + String name; + /*! \brief The index in the params array of the PrimFunc. This is needed to keep aligned + * between the PrimFunc arguments ordering and argument ordering of generated code */ + Integer primfunc_param_idx; + /*! \brief The region used by the command stream. This needs to match with base address + * index passed into the driver */ + Integer region; + /*! \brief The size of the buffer accessible by this base address */ + Integer size; + /*! \brief This is a runtime allocation that needs to be done in the function */ + Bool is_runtime_allocation{Bool(false)}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("primfunc_param_idx", &primfunc_param_idx); + v->Visit("region", ®ion); + v->Visit("size", &size); + v->Visit("is_runtime_allocation", &is_runtime_allocation); + } + + bool SEqualReduce(const BaseAddressNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(primfunc_param_idx, other->primfunc_param_idx) && + equal(region, other->region) && equal(size, other->size) && + equal(is_runtime_allocation, other->is_runtime_allocation); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(primfunc_param_idx); + hash_reduce(region); + hash_reduce(size); + hash_reduce(is_runtime_allocation); + } + + static constexpr const char* _type_key = "relay.ext.ethos-u.BaseAddress"; + TVM_DECLARE_FINAL_OBJECT_INFO(BaseAddressNode, Object); +}; + +class BaseAddress : public ObjectRef { + public: + TVM_DLL BaseAddress(String name, Integer primfunc_param_idx, Integer region, Integer size, + Bool is_runtime_allocation = Bool(false)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BaseAddress, ObjectRef, BaseAddressNode); +}; + /*! * \brief Captures all the binary artifactes required to create * the C-source runtime module */ struct CompilationArtifactNode : public Object { + /*! \brief The function name for this artifact belongs to */ + String function_name; /*! \brief The binary command stream (CS) in hex format */ String command_stream; /*! \brief The encoded biases and weights in hex format */ String encoded_constants; - /*! \brief The intermediary scratch area required for the execution of the CS */ - Integer scratch_size; - /*! \brief The size of the input tensor in bytes */ - Integer input_size; - /*! \brief The size of the output tensor in bytes */ - Integer output_size; - /*! \brief The name of the function */ - String function_name; + /*! \brief The information regarding the base addresses */ + Array base_addresses; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("function_name", &function_name); v->Visit("command_stream", &command_stream); v->Visit("encoded_constants", &encoded_constants); - v->Visit("scratch_size", &scratch_size); - v->Visit("input_size", &input_size); - v->Visit("output_size", &output_size); - v->Visit("function_name", &function_name); + v->Visit("base_addresses", &base_addresses); } bool SEqualReduce(const CompilationArtifactNode* other, SEqualReducer equal) const { - return equal(command_stream, other->command_stream) && + return equal(function_name, other->function_name) && + equal(command_stream, other->command_stream) && equal(encoded_constants, other->encoded_constants) && - equal(scratch_size, other->scratch_size) && equal(input_size, other->input_size) && - equal(output_size, other->output_size) && equal(function_name, other->function_name); + equal(base_addresses, other->base_addresses); } void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(function_name); hash_reduce(command_stream); hash_reduce(encoded_constants); - hash_reduce(scratch_size); - hash_reduce(input_size); - hash_reduce(output_size); - hash_reduce(function_name); + hash_reduce(base_addresses); } static constexpr const char* _type_key = "relay.ext.ethos-u.CompilationArtifact"; @@ -83,8 +127,8 @@ struct CompilationArtifactNode : public Object { class CompilationArtifact : public ObjectRef { public: - TVM_DLL CompilationArtifact(String command_stream, String encoded_constants, Integer scratch_size, - Integer input_size, Integer output_size, String function_name); + TVM_DLL CompilationArtifact(String function_name, String command_stream, String encoded_constants, + Array base_addresses); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CompilationArtifact, ObjectRef, CompilationArtifactNode); }; diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6365e09246fc..fc43e1449d6a 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -60,12 +60,16 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { // Collect variables and buffers to map between Array args; + Map new_buffer_map; for (const Var& param : func->params) { // Ideally all func params should have Buffers defined in the buffer_map // We should look to insert buffer_maps for all PrimFuncs that are returned // to the core compiler. if (func->buffer_map.find(param) != func->buffer_map.end()) { args.push_back(func->buffer_map[param]->data); + // Rewiring the buffer_var to map to Buffers for low-level passes + // retain information about the buffer. + new_buffer_map.Set(func->buffer_map[param]->data, func->buffer_map[param]); } else { args.push_back(param); } @@ -79,6 +83,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { func_ptr->body = MergeNest(device_init, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); + func_ptr->buffer_map = new_buffer_map; // return the function. return std::move(func); diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index c14deb636c25..0cadf96e7a18 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -227,8 +227,16 @@ def test_buffer_info_extraction(): "uint8", tir_to_cs_translator.BufferType.input_or_output, ), - "ethosu_conv2d_2": ([1024], "uint8", tir_to_cs_translator.BufferType.scratch), - "ethosu_conv2d_3": ([2048], "uint8", tir_to_cs_translator.BufferType.scratch), + "ethosu_conv2d_2": ( + [1024], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), + "ethosu_conv2d_3": ( + [2048], + "uint8", + tir_to_cs_translator.BufferType.scratch, + ), }, }, ] @@ -805,12 +813,10 @@ def _check_buffer(address, region, length, buffer_var): # Every buffer is adjusted to align to 16 bytes size_in_bytes = util.round_up(size_in_bytes, 16) assert address + size_in_bytes <= scratch_size - # The scratch area should not be used by anyother buffer - assert not scratch_allocation_mask[address : address + size_in_bytes].any() + # The scratch area should not be used by any other buffer + assert not scratch_mask[address : address + size_in_bytes].any() # The scratch area is marked as used - scratch_allocation_mask[address : address + size_in_bytes] = np.ones( - size_in_bytes, dtype="uint8" - ) + scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8") elif buffer_type == tir_to_cs_translator.BufferType.input: assert address == 0 else: @@ -887,14 +893,16 @@ def check_buffer(address, region, length, buffer_var): for extern_call in extern_calls: _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) - _npu_ops, constant_hex_string, scratch_size = tir_to_cs_translator.assign_addresses( - buffer_info, _npu_ops - ) - scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8") + ( + _npu_ops, + constant_hex_string, + scratch_size, + ) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops) + scratch_mask = np.zeros(scratch_size, dtype="uint8") constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8") verify(_npu_ops) # This will be only 1 if all allocated scratch is used. - assert np.prod(scratch_allocation_mask) == 1 + assert np.prod(scratch_mask) == 1 # This will be only 1 if all constant tensors is read at least once. assert np.prod(constant_tensor_read_mask) == 1 From b3fcb4b211b896bf0bf8492acd8f0d9070d97b3f Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Wed, 26 Jan 2022 15:22:33 +0300 Subject: [PATCH 10/49] Add FP requantize flow. Set float32 flow by default for llvm x86 targets with (#9637) sse4.1 support --- include/tvm/relay/qnn/attrs.h | 8 +- python/tvm/relay/qnn/op/_requantize.py | 21 + python/tvm/relay/qnn/op/qnn.py | 102 ++- python/tvm/topi/x86/utils.py | 22 + src/relay/qnn/op/requantize.cc | 223 +++++- src/relay/qnn/op/requantize_config.cc | 93 +++ src/relay/qnn/op/requantize_config.h | 126 ++++ src/relay/qnn/utils.cc | 16 + src/relay/qnn/utils.h | 14 +- src/relay/transforms/pattern_utils.h | 25 + tests/python/relay/test_op_qnn_requantize.py | 750 +++++++++++-------- 11 files changed, 1060 insertions(+), 340 deletions(-) create mode 100644 python/tvm/relay/qnn/op/_requantize.py create mode 100644 src/relay/qnn/op/requantize_config.cc create mode 100644 src/relay/qnn/op/requantize_config.h diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index f0280a90c604..deb900d52d09 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -36,6 +36,7 @@ namespace qnn { struct RequantizeAttrs : public tvm::AttrsNode { int axis; std::string rounding; + std::string compute_dtype; DataType out_dtype; TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { @@ -44,7 +45,7 @@ struct RequantizeAttrs : public tvm::AttrsNode { "The output channel axis for channel wise quantization. Default value is -1," "which corresponds to the last axis.") .set_default(-1); - TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + TVM_ATTR_FIELD(rounding).set_default("None").describe( "Defines the rounding direction when the value is midway between" "two representable values. There are two supported modes - UPWARD" "or TONEAREST. Both modes behave exactly same except at the" @@ -54,6 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode { "value is rounded away from zero at midpoints (for example, -1.5" "rounds to -2). More context can be found at following gblic manual" "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + TVM_ATTR_FIELD(compute_dtype) + .set_default("None") + .describe( + "Specifies the data type used during requantize. Supported " + "options: \"int64\", \"float32\", \"float64\""); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/python/tvm/relay/qnn/op/_requantize.py b/python/tvm/relay/qnn/op/_requantize.py new file mode 100644 index 000000000000..2e2fd9fd2980 --- /dev/null +++ b/python/tvm/relay/qnn/op/_requantize.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Internal module for qnn requantization.""" +import tvm._ffi + +tvm._ffi._init_api("relay._requantize", __name__) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 7f707c093ff3..aef514d81cc1 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -14,19 +14,109 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name,unused-argument, not-context-manager """QNN dialect operators.""" from __future__ import absolute_import as _abs +import tvm +import tvm.ir from tvm import relay +from tvm.runtime import Object from tvm.relay.expr import Tuple, TupleWrapper from tvm.relay.op.nn.utils import get_pad_tuple2d from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE - +from tvm.target import Target +from tvm.topi.x86.utils import target_has_sse41 from ... import op as reg from ...op import OpPattern from . import _make +from . import _requantize + + +@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig") +class RequantizeConfig(Object): + """Configure the requantization behavior by setting config variables. + + Note + ---- + This object is backed by node system in C++, with arguments that can be + exchanged between python and C++. + + Do not construct directly, use requantize_config instead. + + The fields that are backed by the C++ node are immutable once an instance + is constructed. Use _node_defaults getters to get results for the fields. + """ + + @staticmethod + def _get_node_default_rounding(): + return "UPWARD" + + @staticmethod + def _get_node_default_compute_dtype(): + target = Target.current(True) + if target and str(target.kind) == "llvm" and target_has_sse41(target.mcpu): + return "float32" + + return "int64" + + _node_defaults = { + "rounding": _get_node_default_rounding.__func__, + "compute_dtype": _get_node_default_compute_dtype.__func__, + } + + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + super(RequantizeConfig, self).__init__(handle) + self.handle = handle + + def __enter__(self): + # pylint: disable=protected-access + _requantize._EnterRequantizeConfigScope(self) + return self + + def __exit__(self, ptype, value, trace): + _requantize._ExitRequantizeConfigScope() + + def __setattr__(self, name, value): + if name in RequantizeConfig._node_defaults: + raise AttributeError("'%s' object cannot set attribute '%s'" % (str(type(self)), name)) + return super(RequantizeConfig, self).__setattr__(name, value) + + +def current_requantize_config(): + """Get the current requantization configuration.""" + return _requantize._GetCurrentRequantizeConfig() + + +def requantize_config(**kwargs): + """Configure the requantization behavior by setting config variables. + + Parameters + --------- + rounding: "UPWARD" or "TONEAREST" + Rounding direction for fixed point multiplications. + compute_dtype: + Specifies the data type used during requantize. + Supported options: \"int64\", \"float32\", \"float64\" + + Returns + ------- + config: RequantizeConfig + The requantization configuration + """ + node_args = { + k: v() if k not in kwargs else kwargs[k] for k, v in RequantizeConfig._node_defaults.items() + } + return tvm.ir.make_node("relay.qnn.op.RequantizeConfig", **node_args) def requantize( @@ -36,7 +126,8 @@ def requantize( output_scale, output_zero_point, axis=-1, - rounding="UPWARD", + rounding="None", + compute_dtype="None", out_dtype="int8", ): r"""Requantized operator. @@ -70,7 +161,9 @@ def requantize( rounding : string, optional Defines the rounding direction when the value is midway between two representable values. - + compute_dtype: + Specifies the data type used during requantize. + Supported options: \"int64\", \"float32\", \"float64\" out_dtype : str, optional Specifies the output data type. @@ -88,6 +181,7 @@ def requantize( output_zero_point, axis, rounding, + compute_dtype, out_dtype, ) diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index 50c5c848ee0a..c364027022da 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -18,6 +18,23 @@ import tvm +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41") +def target_has_sse41(target): + return ( + target_has_sse42(target) + or target_has_avx(target) + or target_has_avx2(target) + or target_has_avx512(target) + or target_has_vnni(target) + or target + in { + "btver2", + "penryn", + } + ) + + +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42") def target_has_sse42(target): return ( target_has_avx(target) @@ -42,6 +59,7 @@ def target_has_sse42(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx") def target_has_avx(target): return ( target_has_avx2(target) @@ -51,6 +69,7 @@ def target_has_avx(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2") def target_has_avx2(target): return ( target_has_avx512(target) @@ -70,6 +89,7 @@ def target_has_avx2(target): ) +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512") def target_has_avx512(target): return target in { "skylake-avx512", @@ -89,6 +109,7 @@ def target_has_avx512(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni") def target_has_vnni(target): return target in { "cascadelake", @@ -102,6 +123,7 @@ def target_has_vnni(target): } +@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") def get_simd_32bit_lanes(): mcpu = tvm.target.Target.current().mcpu fp32_vec_len = 4 diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a7d214761b9b..ea143fe41713 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -26,9 +26,11 @@ #include #include +#include "../../op/op_common.h" #include "../../transforms/infer_layout_utils.h" #include "../../transforms/pattern_utils.h" #include "../utils.h" +#include "./requantize_config.h" namespace tvm { namespace relay { @@ -111,6 +113,65 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param)); } +bool has_current_target_sse41_support() { + auto target = Target::Current(true); + Optional mcpu = + target.defined() ? target->GetAttr("mcpu") : Optional(nullptr); + auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41"); + ICHECK(target_has_sse41_fn_ptr) << "Function tvm.topi.x86.utils.target_has_sse41 not found"; + return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value()); +} + +/* + * \brief TONEAREST is the standard rounding where the value is rounded away + * from zero at midpoints (for example, -1.5 rounds to -2). + * \param input_tensor The input tensor to rounding op. + * \return The sequence of existing Relay ops. + */ +template +Expr Tonearest(const Expr& input_tensor) { + if (has_current_target_sse41_support()) return Round(input_tensor); + + auto half = MakeConstantScalar(DataType::Float(Bits), 0.5f); + auto zero = MakeConstantScalar(DataType::Float(Bits), 0.f); + auto pos_one = MakeConstantScalar(DataType::Float(Bits), +1.f); + auto neg_one = MakeConstantScalar(DataType::Float(Bits), -1.f); + auto multiplier = Where(Less(input_tensor, zero), neg_one, pos_one); + auto half_multiplied = Multiply(half, multiplier); + auto input_tensor_biased = Add(input_tensor, half_multiplied); + auto input_tensor_biased_multiplied = Multiply(input_tensor_biased, multiplier); + auto input_tensor_biased_multiplied_int = + Cast(input_tensor_biased_multiplied, DataType::Int(Bits)); + auto input_tensor_biased_multiplied_float = + Cast(input_tensor_biased_multiplied_int, DataType::Float(Bits)); + auto input_tensor_rounded = Multiply(input_tensor_biased_multiplied_float, multiplier); + return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor); +} + +/* + * \brief UPWARD is the standard rounding except at midpoints where the value + * is rounded to positive infinity (for example, -1.5 rounds to -1). + * \param input_tensor The input tensor to rounding op. + * \return The sequence of existing Relay ops. + */ +template +Expr Upward(const Expr& input_tensor) { + auto half = MakeConstantScalar(DataType::Float(Bits), 0.5f); + auto input_tensor_biased = Add(input_tensor, half); + if (has_current_target_sse41_support()) return Floor(input_tensor_biased); + + auto zero = MakeConstantScalar(DataType::Float(Bits), 0.f); + auto one = MakeConstantScalar(DataType::Float(Bits), +1.f); + auto input_tensor_biased_int = Cast(input_tensor_biased, DataType::Int(Bits)); + auto input_tensor_biased_float = Cast(input_tensor_biased_int, DataType::Float(Bits)); + auto is_subtraction_not_necessary = + LogicalOr(Equal(input_tensor_biased, input_tensor_biased_float), + GreaterEqual(input_tensor_biased, zero)); + auto input_tensor_rounded = Where(is_subtraction_not_necessary, input_tensor_biased_float, + Subtract(input_tensor_biased_float, one)); + return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor); +} + // Lowering of qnn.requantize op /* @@ -119,7 +180,7 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, * \param param The requantize op attrs. * \param input_shape The input tensor shape of the requantize op. * \return The sequence of existing Relay ops. - * \note Requantization using only integer computation. Here, the computation is + * \note RequantizationInt using only integer computation. Here, the computation is * converted to a fixed point computation by computing output multiplier * and shift. This is useful, if the target device does not support/have * very expensive floating point computations. @@ -131,10 +192,10 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, * 4) Add the output zero point. * 5) Cast to the out_dtype. */ -Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, - const Expr& input_zero_point, const Expr& output_scale, - const Expr& output_zero_point, const RequantizeAttrs* param, - const Array& input_shape, const DataType& out_dtype) { +Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale, + const Expr& input_zero_point, const Expr& output_scale, + const Expr& output_zero_point, const RequantizeAttrs* param, + const Array& input_shape, const DataType& out_dtype) { auto tensor = Cast(input_tensor, DataType::Int(32)); auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { @@ -208,6 +269,142 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, return Cast(clipped_t, out_dtype); } +// Lowering of qnn.requantize op + +/* + * \brief Lower requantize to a sequence of ops. + * \param input_tensor The input tensor to requantize op. + * \param param The requantize op attrs. + * \param input_shape The input tensor shape of the requantize op. + * \return The sequence of existing Relay ops. + * \note RequantizationFP using floating computation. All multiplication/sub/sum + * occurs in floating point data type and only at the end is converted to + * int32 data type and clamped for output data type. + * + * The whole computation this can be broken down into following steps + * 1) Subtract the input zero point. + * 2) Perform multiplication. + * 3) Add the output zero point. + * 4) Cast to the out_dtype. + */ +template +Expr RequantizeLowerFP(const Expr& input_tensor, const Expr& input_scale, + const Expr& input_zero_point, const Expr& output_scale, + const Expr& output_zero_point, const RequantizeAttrs* param, + const Array& input_shape, const DataType& out_dtype) { + auto tensor = Cast(input_tensor, DataType::Float(Bits)); + auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); + if (!IsEqualScalar(input_zero_point, zero_scalar)) { + // Broadcast input zero point if needed. + int rank = static_cast(input_shape.size()); + int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis; + Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, + { + -1, + }), + rank, {axis}); + tensor = Subtract(Cast(tensor, DataType::Float(Bits)), + Cast(input_zero_broadcast, DataType::Float(Bits))); + } else { + tensor = Cast(tensor, DataType::Float(Bits)); + } + + // 2) If the input and output scales are same, we can skip the multiplication. Check + // if the input scale is per-tensor or per-channel. If it is per-tensor, there is single scale for + // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input + // tensor. Depending on the quantization type, the fixed point multiplication routing is called. + auto scaled_fp_t = tensor; + double output_scale_float = GetScalarFromConstant(output_scale); + if (IsConstScalar(input_scale)) { + // This is per-tensor quantization. Single scale. + double input_scale_float = GetScalarFromConstant(input_scale); + double double_multiplier = static_cast(input_scale_float) / output_scale_float; + // Skip if input and output scales are same. + if (!IsEqualScalar(input_scale, output_scale)) { + double multiplier = double_multiplier; + auto m_scalar = MakeConstantScalar(DataType::Float(Bits), multiplier); + scaled_fp_t = Multiply(m_scalar, scaled_fp_t); + } + + } else { + // This is per-channel (per=axis) quantization. + std::vector double_multipliers; + auto input_axis_scales = GetFloatVectorFromConstant(input_scale); + double output_scale_float = GetScalarFromConstant(output_scale); + for (auto input_axis_scale : input_axis_scales) { + double multiplier = static_cast(input_axis_scale) / output_scale_float; + double_multipliers.push_back(multiplier); + } + int axis = param->axis; + axis = (axis == -1) ? input_shape.size() - 1 : axis; + + auto fixed_pt_multiplier_expr = MakeConstantTensor( + DataType::Float(Bits), {(int64_t)double_multipliers.size()}, double_multipliers); + size_t n_dim = input_shape.size(); + auto exp_fixed_pt_multiplier_expr = + ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {axis}); + + scaled_fp_t = Multiply(scaled_fp_t, exp_fixed_pt_multiplier_expr); + } + + // 3) Add the output zero point. + auto shifted_fp_t = scaled_fp_t; + if (!IsEqualScalar(output_zero_point, zero_scalar)) { + shifted_fp_t = Add(shifted_fp_t, Cast(output_zero_point, DataType::Float(Bits))); + } + + if (param->rounding == "UPWARD") { + shifted_fp_t = Upward(shifted_fp_t); + } else /*if (param->rounding == "TONEAREST")*/ { + shifted_fp_t = Tonearest(shifted_fp_t); + } + + shifted_fp_t = Cast(shifted_fp_t, DataType::Int(32)); + // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point + // multiplication keeps the value in int32 range. + if (out_dtype == DataType::Int(32)) { + return shifted_fp_t; + } + + auto q_min = GetQmin(out_dtype); + auto q_max = GetQmax(out_dtype); + auto clipped_t = Clip(shifted_fp_t, q_min, q_max); + return Cast(clipped_t, out_dtype); +} + +// Lowering of qnn.requantize op +/* + * \brief Lower requantize to a sequence of ops. + * \param input_tensor The input tensor to requantize op. + * \param param The requantize op attrs. + * \param input_shape The input tensor shape of the requantize op. + * \return The sequence of existing Relay ops. + */ +Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, + const Expr& input_zero_point, const Expr& output_scale, + const Expr& output_zero_point, const RequantizeAttrs* param, + const Array& input_shape, const DataType& out_dtype) { + // Check rounding validity. + ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") + << "QNN requantize supports two rounding modes - UPWARD and " + << "TONEAREST"; + // Check compute_dtype validity. + ICHECK(param->compute_dtype == "int64" || param->compute_dtype == "float32" || + param->compute_dtype == "float64") + << "QNN requantize supports three compute_dtype variants - \"int64\", \"float32\" and " + "\"float64\""; + if (param->compute_dtype == "float32") { + return RequantizeLowerFP<32>(input_tensor, input_scale, input_zero_point, output_scale, + output_zero_point, param, input_shape, out_dtype); + } else if (param->compute_dtype == "float64") { + return RequantizeLowerFP<64>(input_tensor, input_scale, input_zero_point, output_scale, + output_zero_point, param, input_shape, out_dtype); + } else /*if (param->compute_dtype == "int64") */ { + return RequantizeLowerInt(input_tensor, input_scale, input_zero_point, output_scale, + output_zero_point, param, input_shape, out_dtype); + } +} + /* * \brief Forward rewrite the requantize op. * \param ref_call The original call that will be lowered. @@ -230,8 +427,15 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, auto& output_scale = new_args[3]; auto& output_zero_point = new_args[4]; const auto* param = attrs.as(); + const RequantizeConfig& cfg = RequantizeConfig::Current(); + ICHECK(param != nullptr); + const_cast(param)->rounding = + SelectRequntizeParameter(param->rounding, cfg->get_rounding(), cfg->is_default, "rounding"); + const_cast(param)->compute_dtype = SelectRequntizeParameter( + param->compute_dtype, cfg->get_compute_dtype(), cfg->is_default, "compute_dtype"); + // Find input shape. ICHECK_EQ(types.size(), 6); auto in_type = types[0]; @@ -246,11 +450,6 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, ICHECK(out_tensor_type != nullptr) << "Type information missing." << " Please run infer_type pass."; auto out_dtype = out_tensor_type->dtype; - - // Check rounding validity. - ICHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") - << "QNN requantize supports two rounding modes - UPWARD and " - << "TONEAREST"; return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale, output_zero_point, param, input_shape, out_dtype); } @@ -317,11 +516,13 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create qnn requantize operator // used by frontend FFI. Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale, - Expr output_zero_point, int axis, String rounding, DataType out_dtype) { + Expr output_zero_point, int axis, String rounding, String compute_dtype, + DataType out_dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); + attrs->compute_dtype = std::move(compute_dtype); static const Op& op = Op::Get("qnn.requantize"); return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs), {}); diff --git a/src/relay/qnn/op/requantize_config.cc b/src/relay/qnn/op/requantize_config.cc new file mode 100644 index 000000000000..4a52f56400c9 --- /dev/null +++ b/src/relay/qnn/op/requantize_config.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/requantize_config.cc + * \brief QNN requantize config. + */ + +#include "./requantize_config.h" + +#include +#include +#include + +#include + +namespace tvm { +namespace relay { +namespace qnn { + +/*! \brief Entry to hold the BuildConfig context stack. */ +struct TVMRequantizeConfigThreadLocalEntry { + /*! \brief The default build config if the stack is empty */ + RequantizeConfig default_config; + + /*! \brief The current build config context */ + std::stack context_stack; + + TVMRequantizeConfigThreadLocalEntry() : default_config(make_object(true)) {} +}; + +/*! \brief Thread local store to hold the BuildConfig context stack. */ +typedef dmlc::ThreadLocalStore + TVMRequantizeConfigThreadLocalStore; + +void RequantizeConfig::EnterRequantizeConfigScope(const RequantizeConfig& build_config) { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + entry->context_stack.push(build_config); +} + +void RequantizeConfig::ExitRequantizeConfigScope() { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + entry->context_stack.pop(); +} + +RequantizeConfig& RequantizeConfig::Current() { + TVMRequantizeConfigThreadLocalEntry* entry = TVMRequantizeConfigThreadLocalStore::Get(); + if (entry->context_stack.size() > 0) { + return entry->context_stack.top(); + } + + return entry->default_config; +} + +TVM_REGISTER_NODE_TYPE(RequantizeConfigNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* op = static_cast(ref.get()); + p->stream << "requantize_config("; + p->stream << "rounding==" << op->get_rounding() << ", "; + p->stream << "compute_dtype==" << op->get_compute_dtype(); + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("relay._requantize._GetCurrentRequantizeConfig") + .set_body_typed([]() -> RequantizeConfig { return RequantizeConfig::Current(); }); + +TVM_REGISTER_GLOBAL("relay._requantize._EnterRequantizeConfigScope") + .set_body_typed(RequantizeConfig::EnterRequantizeConfigScope); + +TVM_REGISTER_GLOBAL("relay._requantize._ExitRequantizeConfigScope") + .set_body_typed(RequantizeConfig::ExitRequantizeConfigScope); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h new file mode 100644 index 000000000000..f1cd9219c32b --- /dev/null +++ b/src/relay/qnn/op/requantize_config.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/requantize_config.h + * \brief QNN requantize config. + */ + +#ifndef TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ +#define TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include + +#include "../../op/op_common.h" + +namespace tvm { +namespace relay { +namespace qnn { + +class RequantizeConfig; +/*! + * \brief Container for build configuration options + */ +class RequantizeConfigNode : public Object { + std::string rounding; + std::string compute_dtype; + + public: + explicit RequantizeConfigNode(bool is_default = false) : is_default(is_default) {} + + std::string get_rounding() const { + if (!rounding.empty()) return rounding; + return "UPWARD"; + } + + std::string get_compute_dtype() const { + if (!compute_dtype.empty()) return compute_dtype; + + // For the x86 architecture, the float32 computation is expected to give significant speedup, + // with little loss in the accuracy of the requantize operation. + auto target = Target::Current(true); + auto target_has_sse41 = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41"); + ICHECK(target_has_sse41) << "Function tvm.topi.x86.utils.target_has_sse41 not found"; + if (target.defined() && target->kind->name == "llvm" && + (target->GetAttr("mcpu") && + (*target_has_sse41)(target->GetAttr("mcpu").value()))) { + return "float32"; + } + + return "int64"; + } + + const bool is_default = false; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("rounding", &rounding); + v->Visit("compute_dtype", &compute_dtype); + } + + static constexpr const char* _type_key = "relay.qnn.op.RequantizeConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(RequantizeConfigNode, Object); +}; + +/*! + * \brief Container for build configuration options + */ +class RequantizeConfig : public ObjectRef { + public: + RequantizeConfig() {} + explicit RequantizeConfig(ObjectPtr n) : ObjectRef(n) {} + + const RequantizeConfigNode* operator->() const { + return static_cast(get()); + } + + RequantizeConfigNode* operator->() { return static_cast(get_mutable()); } + + /*! + * \brief Push a new BuildConfig context onto the thread local stack. + * \param build_config The configuration to set as the current context. + */ + static void EnterRequantizeConfigScope(const RequantizeConfig& requantize_config); + + /*! + * \brief Pop a build config off the thread local context stack, restoring the previous + * configuration as the current context. + */ + static void ExitRequantizeConfigScope(); + + /*! + * \brief Get the current BuildConfig context from thread local storage, or a default + * configuration if a BuildConfig scope has not been entered. + * \return The configuration that is the current context. + */ + static RequantizeConfig& Current(); + + using ContainerType = RequantizeConfigNode; +}; + +} // namespace qnn +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_QNN_OP_REQUANTIZE_CONFIG_H_ diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index 982efa0a61c1..7dfd788d96c6 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -199,6 +199,22 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, return Cast(tensor, DataType::Int(32)); } +std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value, + const bool is_cfg_default, const std::string& name) { + if (arg_value == "None") { + return cfg_value; + } else { + if (!is_cfg_default && arg_value != cfg_value) { + DLOG(INFO) << "The value of parameter \"" << name + << "\" from the non-default requantize config will not be used. The value " + "provided from " + "requantize function argument will be used instead. The value used is \"" + << arg_value << "\"."; + } + return arg_value; + } +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index c8f3524d51ea..0f3645a9882a 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -35,6 +35,8 @@ #include #include +#include "./op/requantize_config.h" + namespace tvm { namespace relay { namespace qnn { @@ -98,13 +100,21 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype); +std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value, + const bool is_cfg_default, const std::string& name); + static inline Expr Requantize(const Expr& data, const Array& input_shape, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, - const DataType& out_dtype, const std::string& rounding = "UPWARD") { + const DataType& out_dtype, const std::string& rounding = "None", + const std::string& compute_dtype = "None") { auto attrs = make_object(); - attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); + const RequantizeConfig& cfg = RequantizeConfig::Current(); + attrs->rounding = + SelectRequntizeParameter(rounding, cfg->get_rounding(), cfg->is_default, "rounding"); + attrs->compute_dtype = SelectRequntizeParameter(compute_dtype, cfg->get_compute_dtype(), + cfg->is_default, "compute_dtype"); return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point, attrs.operator->(), input_shape, attrs->out_dtype); } diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 16a23a4ba699..7d2657eb04f2 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -565,6 +565,11 @@ inline Expr Round(Expr x) { return Call(op, {x}, Attrs(), {}); } +inline Expr Floor(Expr x) { + static const Op& op = Op::Get("floor"); + return Call(op, {x}, Attrs(), {}); +} + inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { @@ -662,11 +667,31 @@ static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { return Call(op, {condition, x, y}); } +static inline Expr LogicalOr(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("logical_or"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { static const Op& op = Op::Get("greater_equal"); return Call(op, {lhs, rhs}, Attrs(), {}); } +static inline Expr Equal(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("equal"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr Less(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("less"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr IsFinite(const Expr x) { + static const Op& op = Op::Get("isfinite"); + return Call(op, {x}, Attrs(), {}); +} + static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype); } diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 0f512df25cdf..64306476dfe9 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -22,14 +22,15 @@ from tvm.contrib import graph_executor roundings = ["UPWARD", "TONEAREST"] +compute_dtypes = ["float32", "float64", "int64"] -def verify(mod, goldens): +def verify(mod, goldens, target="llvm"): with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, "llvm", params=None) + graph, lib, params = relay.build(mod, target, params=None) golden_data, golden_output = goldens rt_mod = graph_executor.create(graph, lib, device=tvm.cpu(0)) - rt_mod.set_input("quantized_data", golden_data) + rt_mod.set_input("input_data", golden_data) rt_mod.set_input(**params) rt_mod.run() res = rt_mod.get_output(0).numpy() @@ -44,10 +45,11 @@ def get_mod( output_scale, input_zero_point=0, output_zero_point=0, - rounding="TONEAREST", + rounding="None", + compute_dtype="None", axis=0, ): - quantized_data = relay.var("quantized_data", shape=data_shape, dtype=data_dtype) + input_data = relay.var("input_data", shape=data_shape, dtype=data_dtype) if isinstance(input_scale, float): input_scale_expr = relay.const(input_scale, "float32") else: @@ -59,13 +61,14 @@ def get_mod( input_zero_point_expr = relay.const(np.array(input_zero_point).astype("int32")) mod = relay.qnn.op.requantize( - quantized_data, + input_data, input_scale=input_scale_expr, input_zero_point=input_zero_point_expr, output_scale=relay.const(output_scale, "float32"), output_zero_point=relay.const(output_zero_point, "int32"), axis=axis, rounding=rounding, + compute_dtype=compute_dtype, out_dtype=out_dtype, ) @@ -78,327 +81,344 @@ def test_same_scale(): # Have same scales, everything within range golden_data = np.arange(-100, 100, 1).astype("int32") golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(200,), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(200,), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_scalar_same_scale(): # Have same scales, everything within range golden_data = np.array(-10).astype("int32") golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - assert "right_shift" not in mod.astext() - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) def test_downscale(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([0, -1, -2], [9, 16, 7]) - else: - golden_output = np.repeat([0, -1, -2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=4, - rounding=rounding, - ) - - # Try positive values - # 2I corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + rounding=rounding, + compute_dtype=compute_dtype, ) - else: - golden_output = np.repeat( - [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=4, + rounding=rounding, ) - verify(mod, (golden_data, golden_output)) - - # Try uint8 out_dtype - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="uint8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) - - # Try uint8 in_dtyope and uint8 out_dtype - mod = get_mod( - data_shape=(32,), - data_dtype="uint8", - out_dtype="uint8", - input_scale=1, - output_scale=16, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - verify(mod, (golden_data, golden_output)) + + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1] + ) + else: + golden_output = np.repeat( + [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2] + ) + verify(mod, (golden_data, golden_output)) + + # Try uint8 out_dtype + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="uint8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + # Try uint8 in_dtyope and uint8 out_dtype + mod = get_mod( + data_shape=(32,), + data_dtype="uint8", + out_dtype="uint8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) def test_upscale(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=2, - output_scale=1, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(2, golden_data) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=2, + output_scale=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(2, golden_data) + verify(mod, (golden_data, golden_output)) def test_non_power_of_two(): - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=3, - rounding=rounding, - ) - - # Try positive values - golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) - golden_output = np.arange(0, 32, 1) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) - golden_output = np.arange(0, -32, -1) - verify(mod, (golden_data, golden_output)) - - # Try a different scale - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=3, - output_scale=1, - rounding=rounding, - ) - - # Try positive values - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.arange(0, -32, -1).astype("int32") - golden_output = np.multiply(golden_data, 3) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=3, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3) + golden_output = np.arange(0, 32, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3) + golden_output = np.arange(0, -32, -1) + verify(mod, (golden_data, golden_output)) + + # Try a different scale + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=3, + output_scale=1, + rounding=rounding, + ) + + # Try positive values + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.multiply(golden_data, 3) + verify(mod, (golden_data, golden_output)) def test_saturation(): - for rounding in roundings: - mod = get_mod( - data_shape=(16,), - data_dtype="int32", - out_dtype="int8", - input_scale=0.5, - output_scale=0.5, - rounding=rounding, - ) - golden_data = np.arange(0, 16, 1).astype("int32") - golden_data = np.add(120, golden_data) - output = np.array( - [120, 121, 122, 123, 124, 125, 126, 127, 127, 127, 127, 127, 127, 127, 127, 127] - ) - golden_output = output - verify(mod, (golden_data, golden_output)) - - # Try negative numbers - golden_data = np.arange(0, -16, -1).astype("int32") - golden_data = np.add(-120, golden_data) - output = np.array( - [ - -120, - -121, - -122, - -123, - -124, - -125, - -126, - -127, - -128, - -128, - -128, - -128, - -128, - -128, - -128, - -128, - ] - ) - golden_output = output - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(16,), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + compute_dtype=compute_dtype, + ) + golden_data = np.arange(0, 16, 1).astype("int32") + golden_data = np.add(120, golden_data) + output = np.array( + [120, 121, 122, 123, 124, 125, 126, 127, 127, 127, 127, 127, 127, 127, 127, 127] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype("int32") + golden_data = np.add(-120, golden_data) + output = np.array( + [ + -120, + -121, + -122, + -123, + -124, + -125, + -126, + -127, + -128, + -128, + -128, + -128, + -128, + -128, + -128, + -128, + ] + ) + golden_output = output + verify(mod, (golden_data, golden_output)) def test_zero_point(): # Output zero point - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - output_zero_point=1, - rounding=rounding, - ) - - # Try positive values - # 8 corresponds to 0.5, resulting in 1 - golden_data = np.arange(0, 32, 1).astype("int32") - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) - golden_output = np.add(1, golden_output) - verify(mod, (golden_data, golden_output)) - - # Try negative values - # -8 corresponds to -0.5. For UPWARD, this is 0 - golden_data = np.arange(-32, -64, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) - else: - golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) - golden_output = np.add(1, golden_output) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + output_zero_point=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype("int32") + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For UPWARD, this is 0 + golden_data = np.arange(-32, -64, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(mod, (golden_data, golden_output)) # Input zero point - for rounding in roundings: - mod = get_mod( - data_shape=(32,), - data_dtype="int32", - out_dtype="int8", - input_scale=1, - output_scale=16, - input_zero_point=16, - rounding=rounding, - ) - - # Try positive values - golden_data = np.arange(32, 64, 1).astype("int32") - golden_output = np.repeat([2, 3, 4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) - - # Try negative values - golden_data = np.arange(-32, -64, -1).astype("int32") - if rounding == "UPWARD": - golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) - else: - golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) - golden_output = np.subtract(golden_output, 1) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + input_zero_point=16, + rounding=rounding, + compute_dtype=compute_dtype, + ) + + # Try positive values + golden_data = np.arange(32, 64, 1).astype("int32") + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(-32, -64, -1).astype("int32") + if rounding == "UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(mod, (golden_data, golden_output)) def test_per_channel_same_scale(): # Have same scales, everything within range golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = golden_data - - for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.5], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Change axis golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5)) golden_output = golden_data - for rounding in roundings: - mod = get_mod( - data_shape=(2, 2, 5), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.5], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(2, 2, 5), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.5], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) def test_per_channel_different_scale(): @@ -406,17 +426,19 @@ def test_per_channel_different_scale(): golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2)) - for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Change axis golden_data = np.arange(-20, 20, 2).astype("int32").reshape((2, 2, 5)) @@ -424,33 +446,113 @@ def test_per_channel_different_scale(): [-20, -18, -16, -14, -12, -5, -4, -3, -2, -1, 0, 2, 4, 6, 8, 5, 6, 7, 8, 9] ).reshape((2, 2, 5)) - for rounding in roundings: - mod = get_mod( - data_shape=(2, 2, 5), - data_dtype="int32", - out_dtype="int8", - input_scale=[0.5, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(2, 2, 5), + data_dtype="int32", + out_dtype="int8", + input_scale=[0.5, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) # Have input scale > output scale golden_data = np.arange(-5, 5, 1).astype("int32").reshape((5, 2)) golden_output = np.array([-10, -2, -6, -1, -2, 0, 2, 1, 6, 2]).reshape((5, 2)) + for compute_dtype in compute_dtypes: + for rounding in roundings: + mod = get_mod( + data_shape=(5, 2), + data_dtype="int32", + out_dtype="int8", + input_scale=[1.0, 0.25], + output_scale=0.5, + axis=1, + rounding=rounding, + compute_dtype=compute_dtype, + ) + verify(mod, (golden_data, golden_output)) + + +def test_default_cfg_and_no_args(): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + ) + golden_data = np.arange(0, -32, -1).astype("int32") + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + verify(mod, (golden_data, golden_output)) + + +def test_non_default_cfg_and_no_args(): + for rounding_cfg in roundings: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding_cfg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + +def test_default_cfg_and_args(): for rounding in roundings: - mod = get_mod( - data_shape=(5, 2), - data_dtype="int32", - out_dtype="int8", - input_scale=[1.0, 0.25], - output_scale=0.5, - axis=1, - rounding=rounding, - ) - verify(mod, (golden_data, golden_output)) + with relay.qnn.op.requantize_config(rounding="UPWARD"): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + rounding=rounding, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) + + +def test_non_default_cfg_and_args(): + for rounding_arg in roundings: + for rounding_cfg in roundings: + with relay.qnn.op.requantize_config(rounding=rounding_cfg): + mod = get_mod( + data_shape=(32,), + data_dtype="int32", + out_dtype="int8", + input_scale=1, + output_scale=16, + rounding=rounding_arg, + ) + + golden_data = np.arange(0, -32, -1).astype("int32") + + if rounding_arg == "UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(mod, (golden_data, golden_output)) if __name__ == "__main__": @@ -463,3 +565,7 @@ def test_per_channel_different_scale(): test_zero_point() test_per_channel_same_scale() test_per_channel_different_scale() + test_default_cfg_and_no_args() + test_non_default_cfg_and_no_args() + test_default_cfg_and_args() + test_non_default_cfg_and_args() From a0c95f8c52690376bebcdad3baae4728af30e7ad Mon Sep 17 00:00:00 2001 From: Pranav Jonnalagadda <52756956+pranavjon@users.noreply.github.com> Date: Wed, 26 Jan 2022 04:25:15 -0800 Subject: [PATCH 11/49] [Relay][DefuseOps pass] bug fix: To support function body types other than call node (#10069) Co-authored-by: pranav jonnalagadda-SJ1 Eng_ML --- src/relay/transforms/defuse_ops.cc | 18 ++- tests/python/relay/test_pass_defuse_ops.py | 151 ++++++++++++++++++++- 2 files changed, 157 insertions(+), 12 deletions(-) diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc index d7a9bfde57c3..0d97d5a7b75c 100644 --- a/src/relay/transforms/defuse_ops.cc +++ b/src/relay/transforms/defuse_ops.cc @@ -55,17 +55,15 @@ class DefuseOpsMutator : public ExprMutator { if (const auto* call = new_n.as()) { if (const auto* func = call->op.as()) { - if (func->body->IsInstance()) { - std::unordered_map name_to_args; - for (size_t i = 0; i < func->params.size(); ++i) { - const std::string& pname = func->params[i]->name_hint(); - ICHECK(name_to_args.cend() == name_to_args.find(pname)) - << "Found multiple parameters share the same variable name `" << pname - << "` which introduces uncertainty in DefuseOps pass"; - name_to_args[pname] = call->args[i]; - } - return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body); + std::unordered_map name_to_args; + for (size_t i = 0; i < func->params.size(); ++i) { + const std::string& pname = func->params[i]->name_hint(); + ICHECK(name_to_args.cend() == name_to_args.find(pname)) + << "Found multiple parameters share the same variable name `" << pname + << "` which introduces uncertainty in DefuseOps pass"; + name_to_args[pname] = call->args[i]; } + return FuncBodyMutator(std::move(name_to_args)).Mutate(func->body); } } return new_n; diff --git a/tests/python/relay/test_pass_defuse_ops.py b/tests/python/relay/test_pass_defuse_ops.py index 2312b2d9ec47..f123bd582b87 100644 --- a/tests/python/relay/test_pass_defuse_ops.py +++ b/tests/python/relay/test_pass_defuse_ops.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy +import pytest import tvm from tvm import relay from tvm.relay import transform @@ -63,6 +65,151 @@ def before(dshape): assert tvm.ir.structural_equal(x, defused) +def test_defuse_complex(): + """Complex defuse testcase""" + + def fused_conv2d_batch_norm(w): + data = relay.var("data", shape=(1, 224, 224, 3)) + bn_gamma0 = relay.var("bn_gamma0", relay.TensorType((64,), "float32")) + bn_beta0 = relay.var("bn_beta0", relay.TensorType((64,), "float32")) + bn_mmean0 = relay.var("bn_mean0", relay.TensorType((64,), "float32")) + bn_mvar0 = relay.var("bn_var0", relay.TensorType((64,), "float32")) + c0 = relay.nn.conv2d( + data, + w, + strides=(2, 2), + padding=(3, 3, 3, 3), + channels=64, + kernel_size=(7, 7), + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3) + c2 = c1[0] + return relay.Function(relay.analysis.free_vars(c2), c2) + + def fused_conv2d_batch_norm_relu(z): + data2 = relay.var("data2", shape=(1, 56, 56, 64)) + bn_gamma0 = relay.var("bn_gamma0", relay.TensorType((64,), "float32")) + bn_beta0 = relay.var("bn_beta0", relay.TensorType((64,), "float32")) + bn_mmean0 = relay.var("bn_mean0", relay.TensorType((64,), "float32")) + bn_mvar0 = relay.var("bn_var0", relay.TensorType((64,), "float32")) + c0 = relay.nn.conv2d( + data2, + z, + padding=(1, 1, 1, 1), + channels=64, + kernel_size=(3, 3), + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3) + c2 = c1[0] + c3 = relay.nn.relu(data=c2) + return relay.Function(relay.analysis.free_vars(c3), c3) + + def fused_max_pool2d(): + data1 = relay.var("data1", shape=(1, 112, 112, 64)) + a1 = relay.nn.max_pool2d( + data1, + pool_size=(3, 3), + strides=(2, 2), + padding=(1, 1, 1, 1), + layout="NHWC", + out_layout="NHWC", + ) + return relay.Function(relay.analysis.free_vars(a1), a1) + + def fused_add_relu(): + data1 = relay.var("data1", shape=(1, 56, 56, 64)) + data2 = relay.var("data2", shape=(1, 56, 56, 64)) + a0 = relay.add(data1, data2) + a1 = relay.nn.relu(a0) + return relay.Function(relay.analysis.free_vars(a1), a1) + + def before_fused(conv_layer1_weight, conv_layer2_weight): + data = relay.var("data", shape=(1, 3, 224, 224)) + data1 = relay.layout_transform(data, src_layout="NCHW", dst_layout="NHWC") + bn_gamma0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_beta0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_mmean0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_mvar0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + a0 = fused_conv2d_batch_norm(conv_layer1_weight) + a1 = fused_max_pool2d() + a2 = fused_conv2d_batch_norm_relu(conv_layer2_weight) + a3 = fused_add_relu() + y0 = relay.Call(a0, [data1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0]) + y1 = relay.Call(a1, [y0]) + y2 = relay.Call(a2, [y1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0]) + y3 = relay.Call(a3, [y1, y2]) + return relay.Function(relay.analysis.free_vars(y3), y3) + + def golden_defused(conv_layer1_weight, conv_layer2_weight): + data = relay.var("data", shape=(1, 3, 224, 224)) + data1 = relay.layout_transform(data, src_layout="NCHW", dst_layout="NHWC") + bn_gamma0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_beta0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_mmean0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + bn_mvar0 = relay.const(tvm.nd.array(numpy.ndarray(shape=(64,), dtype="float32"))) + c0 = relay.nn.conv2d( + data1, + conv_layer1_weight, + strides=(2, 2), + padding=(3, 3, 3, 3), + channels=64, + kernel_size=(7, 7), + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + c1 = relay.nn.batch_norm(c0, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3) + c2 = c1[0] + c3 = relay.nn.max_pool2d( + c2, + pool_size=(3, 3), + strides=(2, 2), + padding=(1, 1, 1, 1), + layout="NHWC", + out_layout="NHWC", + ) + c4 = relay.nn.conv2d( + c3, + conv_layer2_weight, + padding=(1, 1, 1, 1), + channels=64, + kernel_size=(3, 3), + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + c5 = relay.nn.batch_norm(c4, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0, axis=3) + c6 = c5[0] + c7 = relay.nn.relu(c6) + c8 = relay.add(c3, c7) + c9 = relay.nn.relu(c8) + return relay.Function(relay.analysis.free_vars(c9), c9) + + # creating weight constants for the two convolution layers + # in the input fused model and the golden defused model. + conv_layer1_weight = relay.nn.Constant( + tvm.nd.array(numpy.ndarray(shape=(64, 7, 7, 3), dtype="float32")) + ) + conv_layer2_weight = relay.nn.Constant( + tvm.nd.array(numpy.ndarray(shape=(64, 3, 3, 64), dtype="float32")) + ) + x = before_fused(conv_layer1_weight, conv_layer2_weight) + x = run_opt_pass(x, transform.InferType()) + defused = run_opt_pass(x, transform.DefuseOps()) + + golden1 = golden_defused(conv_layer1_weight, conv_layer2_weight) + golden1 = run_opt_pass(golden1, transform.InferType()) + + assert tvm.ir.structural_equal(defused, golden1), ( + "Actual = \n" + str(defused) + "\nGolden = \n" + str(golden1) + ) + + if __name__ == "__main__": - test_defuse_simple() - test_inception_like() + pytest.main([__file__]) From e7705d709935b6cb3b095bf29cd8807b487151e6 Mon Sep 17 00:00:00 2001 From: ninesheep Date: Wed, 26 Jan 2022 20:27:30 +0800 Subject: [PATCH 12/49] [Fix Bug]fix the bug of tensorflow frontend when parsing Range layer (#9999) Co-authored-by: wangjiuyang --- python/tvm/relay/frontend/tensorflow_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index df8b7438af88..9b36d712e9ec 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -2454,6 +2454,7 @@ def _impl(inputs, attr, params, mod): delta = inputs[2] # if all attributes are constant, evalute the range function and return relay.const + dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) if all( [ isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)), @@ -2461,9 +2462,8 @@ def _impl(inputs, attr, params, mod): isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)), ] ): - return tvm.relay.const(list(range(int(start), int(limit), int(delta)))) + return tvm.relay.const(list(range(int(start), int(limit), int(delta))), dtype=dtype) - dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): start = _expr.const(start, dtype=dtype) if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)): From 887779bc4fdaf3659f206bdeef527537aa9d7830 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Wed, 26 Jan 2022 21:10:51 +0800 Subject: [PATCH 13/49] [MetaSchedule][M4a] Schedule Rule: Multi-Level-Tiling (#10043) * multi level tiling * remove tensor core related code * pylint * fix Co-authored-by: Junru Shao --- include/tvm/meta_schedule/schedule_rule.h | 6 +- include/tvm/tir/stmt.h | 14 + .../meta_schedule/schedule_rule/__init__.py | 1 + .../schedule_rule/multi_level_tiling.py | 84 ++++ .../meta_schedule/testing/schedule_rule.py | 37 ++ .../schedule_rule/multi_level_tiling.cc | 416 ++++++++++++++++++ src/support/array.h | 23 + src/tir/schedule/analysis.h | 14 + src/tir/schedule/analysis/analysis.cc | 27 ++ ...hedule_schedule_rule_multi_level_tiling.py | 280 ++++++++++++ 10 files changed, 898 insertions(+), 4 deletions(-) create mode 100644 python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling.cc create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 3911a5254290..1675bcce05ed 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -137,9 +137,8 @@ class ScheduleRule : public runtime::ObjectRef { * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: * - NullOpt on CPU * - [blockIdx.x, vthread.x, threadIdx.x] on GPU - * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching. + * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. * NullOpt means disable vectorization * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. @@ -147,9 +146,8 @@ class ScheduleRule : public runtime::ObjectRef { */ TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // Optional> tile_binds, // - bool use_tensor_core, // Optional max_innermost_factor, // - Optional vector_load_max_len, // + Optional> vector_load_lens, // Optional> reuse_read, // Optional> reuse_write); /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0a05439b2341..edb789b0bd7f 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1364,6 +1364,20 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; +/*! + * \brief Mark that the loop should be further skip and bound to environment threads to enable + * cooperative fetching. + */ +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_low_inclusive = + "meta_schedule.thread_extent_low_inclusive"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_high_inclusive = + "meta_schedule.thread_extent_high_inclusive"; + /*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ constexpr const char* meta_schedule_random_compute_producer = "meta_schedule.random_compute_producer"; diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index ce66323fd15b..b0fe8c8bdd75 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -19,6 +19,7 @@ from .add_rfactor import AddRFactor from .auto_inline import AutoInline from .cross_thread_reduction import CrossThreadReduction +from .multi_level_tiling import MultiLevelTiling, ReuseType from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py new file mode 100644 index 000000000000..2ff49168d0c6 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Multi-level tiling with reuse.""" +from typing import Any, Dict, List, NamedTuple, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +class ReuseType(NamedTuple): + """Reuse type.""" + + req: str + levels: List[int] + scope: str + + def as_dict(self) -> Dict[str, Any]: + """Return the dict representation of the reuse type.""" + return { + "req": self.req, + "levels": self.levels, + "scope": self.scope, + } + + +@register_object("meta_schedule.MultiLevelTiling") +class MultiLevelTiling(ScheduleRule): + """Multi-level tiling with reuse. + + Parameters + ---------- + structure : str + The tiling structure. Recommended: + - 'SSRSRS' on CPU + - 'SSSRRSRS' on GPU + tile_bind : Optional[List[str]] + For each level of tiles, which thread axis it is bound to. Recommended: + - None on CPU + - [blockIdx.x, vthread.x, threadIdx.x] on GPU + max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + vector_load_lens : Optional[List[int]] + The length of vector lane in vectorized cooperative fetching. + None means disable vectorization + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 464d2496a603..b149f20c52e3 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -19,8 +19,10 @@ AddRFactor, AutoInline, CrossThreadReduction, + MultiLevelTiling, ParallelizeVectorizeUnroll, RandomComputeLocation, + ReuseType, ScheduleRule, ) from tvm.target import Target @@ -65,6 +67,41 @@ def cross_thread_reduction(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") +def multi_level_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling and reuse""" + if target.kind.name == "llvm": + return MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ) + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def random_compute_location(target: Target) -> ScheduleRule: """Default schedule rules for with random-compute-location""" if target.kind.name == "llvm": diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc new file mode 100644 index 000000000000..d0bfff40fcbe --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { +/*! + * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction + * buffers' dimensions as -1 + * \param block_sref The block to be processed + * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers + * \note The method is not designed for generic analysis and relies on assumptions in the scenario + * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header + */ +std::vector GetReadBufferNDims(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + int n = block->reads.size(); + std::vector results(n, -1); + for (int i = 0; i < n; ++i) { + const BufferNode* read_buffer = block->reads[i]->buffer.get(); + if (read_buffer != write_buffer) { + results[i] = read_buffer->shape.size(); + } + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::ExprRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, but no reuse is explored. + * 2) kMustReuse: reuse is allowed and no reuse is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief Which levels are caching stage inserted at */ + std::vector levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The schedule to date */ + Schedule sch; + /*! \brief The block to be tiled */ + BlockRV block_rv; + /*! \brief The loop tiles */ + Array> tiles; + + /*! \brief Default constructor */ + explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, + bool write_cache_is_added = false, Array> tiles = {}) + : sch(sch), block_rv(block_rv), tiles(tiles) {} +}; + +/*! + * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be applied + * \return The list of states after applying the sub-rule + */ +template +std::vector SubRule(std::vector states, FLambda sub_rule) { + std::vector results; + for (auto&& state : states) { + std::vector next = sub_rule(std::move(state)); + results.insert(results.end(), // + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + // SubRule 1. add write cache + inline std::vector AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + inline std::vector TileLoopNest(State state) const; + // SubRule 3. add read cache + inline std::vector AddReadReuse(State state) const; + + // Do nothing; Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + this->max_threads_per_block_ = v.value()->value; + if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + this->thread_warp_size_ = v.value()->value; + } else { + LOG(INFO) << "'thread_warp_size' is not defined in the target"; + } + } + } + + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + + std::vector states{State(sch, block_rv)}; + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + Array results; + for (auto&& state : states) { + results.push_back(std::move(state.sch)); + } + return results; + } + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array tile_binds; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + std::vector vector_load_lens; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector r_indices_; + /*! \brief The size of the thread warp */ + int thread_warp_size_; + /*! \brief The maximum number of threads to be used size of a thread warp */ + int max_threads_per_block_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `vector_load_lens` is not visited + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + // `thread_warp_size_` is not visited + // `max_threads_per_block` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + std::vector results; + if (config.req == ReuseType::kMayReuse) { + // Case 1. If the write cache is already there, we don't need to add another. + Array consumer_rvs = state.sch->GetConsumers(state.block_rv); + if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true); + results.push_back(std::move(new_state)); + } + results.push_back(state); + return results; + } else { + // Case 2. No write cache is added + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv); + new_state.sch->Seed(state.sch->ForkSeed()); + results.emplace_back(std::move(new_state)); + } + } + + // Case 3. Add one write cache + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + new_state.sch->ReverseComputeAt(write_cache, loop_rv, true); + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const { + Schedule& sch = state.sch; + const BlockRV& block_rv = state.block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array loops = sch->GetLoops(block_rv); + std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. For each loop axis, tile it + int64_t spatial_loop_product = 1; + std::vector> tiles(s_indices_.size() + r_indices_.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + LoopRV loop = loops[i]; + const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + if (spatial_loop_product != -1) { + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + spatial_loop_product *= *extent; + } else { + spatial_loop_product = -1; + } + } + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + // Do the split + int n_tiles = idx->size(); + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } + } + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); + // Step 4. Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state.tiles = Array>{tiles.begin(), tiles.end()}; + if (this->thread_warp_size_ != -1) { + int64_t low_inclusive = 1; + int64_t high_inclusive = this->max_threads_per_block_; + if (spatial_loop_product > 2 * this->thread_warp_size_) { + low_inclusive = this->thread_warp_size_; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive, + Integer(low_inclusive)); + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive, + Integer(high_inclusive)); + } + return {state}; +} + +inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state.block_rv; + std::vector results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Annotate cooperative fetching + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + ExprRV vector_load_len = + sch->SampleCategorical(support::AsArray(vector_load_lens), + Array(n, FloatImm(DataType::Float(64), prob))); + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, + vector_load_len); + } + } + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return results; +} + +// Constructor + +ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, + Optional max_innermost_factor, + Optional> vector_load_lens, + Optional> reuse_read, + Optional> reuse_write) { + ObjectPtr n = make_object(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_lens = vector_load_lens.defined() + ? support::AsVector(vector_load_lens.value()) + : std::vector(); + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + n->thread_warp_size_ = -1; + n->max_threads_per_block_ = -1; + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling") + .set_body_typed(ScheduleRule::MultiLevelTiling); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/support/array.h b/src/support/array.h index 95b4f58a2e22..218150f9dba0 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -100,6 +100,29 @@ inline Array AsArray(const ShapeTuple& shape) { return result; } +/*! + * \brief Concatenate a list of arrays into a single array + * \tparam T The type of elements in the arrays + * \tparam Iterator The type of the iterator into the list of arrays + * \param begin The begin iterator to the array list + * \param end The end iterator to the array list + * \return The concatenated array + */ +template +inline Array ConcatArrayList(Iterator begin, Iterator end) { + int size = 0; + for (Iterator it = begin; it != end; ++it) { + size += (*it).size(); + } + Array result; + result.reserve(size); + for (Iterator it = begin; it != end; ++it) { + const auto& item = *it; + result.insert(result.end(), item.begin(), item.end()); + } + return result; +} + /********** Implementation details of AsVector **********/ namespace details { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 636cc7d0a5db..591201312cd2 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -175,6 +175,20 @@ bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Extracts the types of the block vars + * \param block_sref The block to be checked + * \return A vector of types of the block vars + */ +std::vector GetBlockVarTypes(const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be considered as a "write cache" + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a write cache + */ +bool IsWriteCache(const StmtSRef& block_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index be5e55d4ec70..1579f9154fe6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -408,6 +408,33 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::vector results; + results.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + results.push_back(iter_var->iter_type); + } + return results; +} + +bool IsWriteCache(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1) { + return false; + } + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool exists, surjective, injective, ordered, no_const_read, no_shift_read; + std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + AnalyzeReadWritePattern(read_region, write_region); + if (!(injective && ordered)) { + return false; + } + } + return true; +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py new file mode 100644 index 000000000000..c6a63aae7427 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + multi_level_tiling, +) +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.te import create_prim_func +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cpu_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + "b24, = sch.get_consumers(block=b0)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + "b24, = sch.get_consumers(block=b0)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cuda_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", + "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", + "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", + "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", + "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", + "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", + "l30 = sch.fuse(l9, l19)", + 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="vthread.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="threadIdx.x")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', + 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + ] + ] + # pylint: enable=line-too-long + target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", + "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", + "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", + "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", + "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", + "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", + "l30 = sch.fuse(l9, l19)", + 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="vthread.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="threadIdx.x")', + 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + test_cuda_matmul_relu() From aa44d7b84086df3734bde87c2c6ea0ae1afe0f3d Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Wed, 26 Jan 2022 22:35:38 +0800 Subject: [PATCH 14/49] Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" (#10072) Because of the failure of LSTM conversion from Pytorch --- python/tvm/relay/expr.py | 7 +-- python/tvm/relay/frontend/common.py | 53 ------------------ python/tvm/relay/frontend/pytorch.py | 19 ------- python/tvm/relay/frontend/tensorflow.py | 17 +++++- python/tvm/relay/frontend/tensorflow2.py | 17 +++++- python/tvm/relay/frontend/tflite.py | 16 ++---- src/printer/relay_text_printer.cc | 23 +++----- src/printer/text_printer.h | 2 +- src/relay/ir/expr.cc | 4 +- tests/python/frontend/pytorch/test_forward.py | 47 ---------------- .../frontend/tensorflow/test_forward.py | 54 ------------------- .../tensorflow2/test_sequential_models.py | 24 +-------- tests/python/frontend/tflite/test_forward.py | 54 ------------------- 13 files changed, 48 insertions(+), 289 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 598354e1b514..811e205fb2b3 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp): index: int The index. - - span: Optional[tvm.relay.Span] - Span that points to original source code """ - def __init__(self, tuple_value, index, span=None): - self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) + def __init__(self, tuple_value, index): + self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index) @tvm._ffi.register_object("relay.RefCreate") diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index f8c12ff334db..eeede181f6f9 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -25,7 +25,6 @@ from tvm.topi.utils import get_const_tuple from .. import expr as _expr -from ..expr_functor import ExprMutator from .. import function as _function from .. import transform as _transform from .. import op as _op @@ -955,55 +954,3 @@ def try_resolve_var_to_const(x, graph_params): return _op.const(value, dtype) return x - - -def set_span(sym, node_name): - """Set up the span of relay expression(s) while converting OP""" - - class SpanFiller(ExprMutator): - """SpanFiller""" - - def __init__(self, node_name, suffix_str="_PART_"): - ExprMutator.__init__(self) - self.node_name = node_name - self.suffix_str = suffix_str - self.counter = 0 - self.distance_from_leaf = -1 - - def _create_span(self): - if self.distance_from_leaf == 0: - return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0) - self.distance_from_leaf -= 1 - span_str = "{}{}{}".format(self.node_name, self.suffix_str, str(self.counter)) - self.counter += 1 - return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0) - - def visit_call(self, call): - if call.span is None: - self.distance_from_leaf += 1 - new_args = [self.visit(arg) for arg in call.args] - return _expr.Call( - call.op, new_args, call.attrs, call.type_args, self._create_span() - ) - return call - - def visit_tuple(self, tup): - if tup.span is None: - self.distance_from_leaf += 1 - return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span()) - return tup - - def visit_tuple_getitem(self, op): - if op.span is None: - self.distance_from_leaf += 1 - return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span()) - return op - - def fill(self, sym): - if isinstance(sym, _expr.TupleWrapper): - return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) - if isinstance(sym, _expr.RelayExpr): - return self.visit(sym) - return sym - - return SpanFiller(node_name).fill(sym) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b7188370d86e..f7538f0837c6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -45,7 +45,6 @@ from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import lstm_cell, try_infer_value, unbind -from .common import set_span from .pytorch_utils import is_version_greater_than __all__ = ["from_pytorch"] @@ -3276,9 +3275,6 @@ def body(*current_vals): def convert_operators(self, operators, outputs, ret_names): """Convert each Torch IR operators to Relay equivalent""" - # an op node might not belong to any of scope in trace info natively - # use a cunter to prevent from messing up its scope in span - empty_counter = 0 for node_name, op_node in operators: operator = op_node.kind() inputs = _get_op_inputs(op_node, outputs) @@ -3339,9 +3335,6 @@ def _handel_nested_input(inputs): relay_out = relay_op( inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) ) - span_str, empty_counter = self._get_torch_span(op_node, empty_counter) - relay_out = set_span(relay_out, span_str) - self.record_output_type(relay_out) if isinstance(relay_out, tuple): @@ -3355,18 +3348,6 @@ def _handel_nested_input(inputs): return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] - def _get_torch_span(self, node, empty_counter): - # torch span looks like - # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file - # the scope part might not exist - if node.scopeName(): - scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName() - else: - scope_name_str = "warning: no trace info " + str(empty_counter) - empty_counter += 1 - span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str) - return span_str, empty_counter - def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c2aa5a165b3c..d35e0e1c203d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -37,7 +37,6 @@ from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value -from .common import set_span from .tensorflow_ops import _convert_map from .tensorflow_ops import _need_prelude_for_shape_inference @@ -1029,10 +1028,24 @@ def _convert_operator( else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) - sym = set_span(sym, node_name) + sym = self._set_span(sym, node_name) return sym + @staticmethod + def _set_span(sym, node_name): + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call) and sym.span is None: + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call) and tuple_value.span is None: + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + def _licm_construct(self, loop_name, node_name): """Construct a node by considering whether it is loop invariant with the given while loop. If yes, we diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 2c8b7d4e777b..465f530624b9 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -36,7 +36,6 @@ from .. import function as _function from ..loops import while_loop as _while_loop from .common import infer_type as _infer_type -from .common import set_span from .tensorflow_ops import _convert_map as _convert_map_common from .tensorflow_ops import _get_more_static_shape_rank @@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude): return body.checked_type +def set_span(sym, node_name): + """set span of symbol""" + + span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) + if isinstance(sym, _expr.Call): + sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) + elif isinstance(sym, _expr.TupleWrapper): + tuple_value = sym.tuple_value + if isinstance(tuple_value, _expr.Call): + tuple_value = _expr.Call( + tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span + ) + sym = _expr.TupleWrapper(tuple_value, sym.size) + return sym + + def is_tensor_list_constuctor(tf_node): """Check whether is tensor list constructor node.""" return tf_node.op == "TensorListReserve" diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 12296bd50542..b675dd56a7bb 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -32,7 +32,6 @@ from .. import qnn as _qnn from .common import ExprTable from .common import infer_shape as _infer_shape -from .common import set_span from .common import to_int_list from .tflite_flexbuffer import FlexBufferDecoder @@ -240,17 +239,12 @@ def convert_op_to_relay(self): if len(output_tensors) == 1: tensor_idx = output_tensors[0].tensor_idx - curr_output = get_tensor_name(self.subgraph, tensor_idx) - ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) - self.exp_tab.set_expr(curr_output, ret) + self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret) else: - out_names = [] - for output_tensor in output_tensors: - out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx)) - curr_output = ", ".join(out_names) - ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) - for idx, out_name in enumerate(out_names): - self.exp_tab.set_expr(out_name, ret[idx]) + for idx, output_tensor in enumerate(output_tensors): + self.exp_tab.set_expr( + get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx] + ) def get_op_code_str(self, op): """Get TFLite ops string representation""" diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 7654ef17b753..fdc6c37e527a 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { if (op->fields.size() == 1) { doc << ","; } - doc << ")"; - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } - return doc; + return doc << ")"; } Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { Doc doc; - doc << Print(op->tuple) << "." << op->index; - - if (op->span.defined()) { - doc << " /* " << PrintSpan(op->span) << " */"; - } - return doc; + return doc << Print(op->tuple) << "." << op->index; } Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { @@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& return doc; } -Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) { +Doc RelayTextPrinter::PrintSpan(const Span& span) { Doc doc; - if (include_spans) { - const auto* span_node = span.as(); - ICHECK(span_node); - doc << span_node->source_name->name; - } + const auto* span_node = span.as(); + ICHECK(span_node); + doc << span_node->source_name->name; return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index ca46700d9cf5..a4d0ff30fa62 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor, */ Doc PrintMapAsAttributeValue(const Map& map); - Doc PrintSpan(const Span& span, bool include_spans = true); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 64d921efe6a6..73ae3faf7078 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -375,8 +375,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { - return TupleGetItem(tuple, index, span); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { + return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2c07094c1e9f..3fbef494f16d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -247,53 +247,6 @@ def visit(op): torch.cuda.empty_cache() -def verify_span(model_name, input_data=[], custom_convert_map={}): - if isinstance(model_name, str): - baseline_model, baseline_input = load_model(model_name) - elif isinstance(input_data, list): - baseline_model = model_name - baseline_input = input_data - elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0: - baseline_model = model_name - baseline_input = [input_data] - else: - assert False, "Unexpected input format" - - trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input]) - if isinstance(baseline_model, torch.nn.Module): - trace = trace.float().eval() - - if torch.cuda.is_available(): - trace = trace.cuda() - else: - trace = trace.cpu() - - input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] - input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) - mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - - # collect fail cases for the convenience of further improvement - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - -def test_span(): - verify_span("resnet18") - - # Single operator tests @tvm.testing.uses_gpu def test_forward_pixel_shuffle(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c76803b8fb3c..a5a67e149986 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -298,60 +298,6 @@ def is_gpu_available(): return False -def verify_span(mod): - # collect fail cases for the convenience of further improvement - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - -def simple_model(): - input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input") - - shape = tf.shape(input_node) - stack = tf.stack([shape[0], 3, 3], axis=0) - output_node = tf.reshape(input_node, stack, name="output") - return output_node - - -####################################################################### -# Span fill up -# ------- -def test_span_complement_simple_model(): - with tf.Graph().as_default() as graph: - model_graph = simple_model() - graph_def = graph.as_graph_def() - - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - - mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0", (1, 3, 3, 1)}) - verify_span(mod) - - -def test_span_complement_big_model(): - with tf.Graph().as_default() as graph: - graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") - # Call the utility to import the graph definition into default graph. - graph_def = tf_testing.ProcessGraphDefParam(graph_def) - - mod, params = relay.frontend.from_tensorflow( - graph_def, shape={"input_tensor:0", (128, 224, 224, 3)} - ) - verify_span(mod) - - ####################################################################### # Pooling # ------- diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index b76b4a714938..1b5a6342f07d 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -26,25 +26,6 @@ from common import compare_tf_tvm from common import run_tf_code -from tvm.relay.frontend.tensorflow2 import from_tensorflow - - -def verify_span(mod): - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 def run_sequential_model(model_fn, input_shape): @@ -67,10 +48,7 @@ def model_graph(model, input_shape): gdef = f.graph.as_graph_def(add_shapes=True) return gdef, _input, _output - gdef, _input, _output = model_graph(model_fn, input_shape) - mod, _ = from_tensorflow(gdef) - compare_tf_tvm(gdef, _input, _output, runtime="vm") - verify_span(mod) + compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm") def test_dense_model(): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 77acce459fc9..60af94b53a51 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -259,59 +259,6 @@ def run_tflite_graph(tflite_model_buf, input_data): return tflite_output -def run_span_verification( - tflite_model_buf, - input_data, - input_node, - num_output=1, - target="llvm", - out_names=None, - mode="graph_executor", -): - """Generic function to compile on relay and execute on tvm""" - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 - try: - import tflite.Model - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - except AttributeError: - import tflite - - tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) - except ImportError: - raise ImportError("The tflite package must be installed") - - input_data = convert_to_list(input_data) - input_node = convert_to_list(input_node) - - shape_dict = {} - dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype.name - - mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) - verify_span(mod) - - -def verify_span(mod): - fail_cases = [] - mod_main_start = False - for line in str(mod.__str__).split("\n"): - if "@main" in line: - mod_main_start = True - continue - - if mod_main_start == True: - if "}" == line: - break - elif not ("/*" in line and "*/" in line): - fail_cases.append(line) - - print(fail_cases) - assert len(fail_cases) == 0 - - def compare_tflite_with_tvm( in_data, in_name, @@ -4620,7 +4567,6 @@ def test_forward_tflite2_qnn_resnet50(): tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_predictions = np.squeeze(tflite_output) tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] - run_span_verification(tflite_model_buf, np.array(data), "input_1") tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1") tvm_predictions = np.squeeze(tvm_output) tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] From 4a15db2020e9568eab912a121f636b74bec6a0eb Mon Sep 17 00:00:00 2001 From: Ophir Frish Date: Wed, 26 Jan 2022 19:11:22 +0200 Subject: [PATCH 15/49] Improve the tensorflow frontend _test_spop_resource_variables to support tensoflow 2.6 (#9978) On tensorflow 2.4 the test is expected to fail as the generated graph is not forzen. On tensorflow 2.6 the generated graph is identified as frozen, therefore the test is not needed --- tests/python/frontend/tensorflow/test_forward.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index a5a67e149986..19b1fbd0eb91 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -5391,7 +5391,12 @@ def resourceVariablesTest(x, y): def test_forward_spop(): _test_spop_stateful() _test_spop_device_assignment() - _test_spop_resource_variables() + # tensorflow version upgrade support + # This test is expected to fail in TF version >= 2.6 + # as the generated graph will be considered frozen, hence + # not passing the criteria for the test below. + if tf.__version__ < LooseVersion("2.6.1"): + _test_spop_resource_variables() # Placeholder test cases _test_spop_placeholder_without_shape_info() From b1812bbf167559ac0bca0b75dab838433a51d0fa Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 27 Jan 2022 03:55:29 +0800 Subject: [PATCH 16/49] [MetaSchedule] postproc: rewrite_parallel_vectorize_unroll (#10071) Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- python/tvm/meta_schedule/postproc/__init__.py | 1 + .../rewrite_parallel_vectorize_unroll.py | 33 ++ .../rewrite_parallel_vectorize_unroll.cc | 399 ++++++++++++++++++ ...tproc_rewrite_parallel_vectorize_unroll.py | 87 ++++ 4 files changed, 520 insertions(+) create mode 100644 python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py create mode 100644 src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc create mode 100644 tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index eaab8c7bd484..0c914ac809f9 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -17,6 +17,7 @@ """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc from .disallow_dynamic_loop import DisallowDynamicLoop +from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..abe7288acba9 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that applies parallelization, vectorization and auto unrolling +according to the annotation of each block""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteParallelVectorizeUnroll") +class RewriteParallelVectorizeUnroll(Postproc): + """A postprocessor that applies parallelization, vectorization and auto unrolling + according to the annotation of each block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc new file mode 100644 index 000000000000..69e8dfb858bc --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check whether the loop has any annotation + * \param sref The sref of loop + * \return Whether the loop has any annotation + */ +inline bool HasAnnOrBinding(const ForNode* loop) { + return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); +} + +/*! \brief The visitor for extracting the stride of a var in a PrimExpr. */ +class StrideExtractor : public ExprVisitor { + public: + /*! + * \brief Extracting the stride of a var in a PrimExpr. + * e.g the stride of `x` in `(x * 2 + 1) * 3 + 1` is 6 + * \param expr The given PrimExpr. + * \param var The target var. + * \return The stride of the var. + */ + static int64_t Extract(const PrimExpr& expr, const Var& var) { + StrideExtractor extractor(var); + extractor.VisitExpr(expr); + return extractor.strides_[expr.get()]; + } + + private: + explicit StrideExtractor(const Var& var) : var_(var) {} + + void VisitExpr_(const MulNode* node) final { + ExprVisitor::VisitExpr_(node); + + if (const auto* a = node->a.as()) { + if (strides_.count(node->b.get())) { + strides_[node] = strides_[node->b.get()] * a->value; + } + } else if (const auto* b = node->b.as()) { + if (strides_.count(node->a.get())) { + strides_[node] = strides_[node->a.get()] * b->value; + } + } + } + + void VisitExpr_(const AddNode* node) final { + ExprVisitor::VisitExpr_(node); + int64_t stride_a, stride_b; + if (strides_.count(node->a.get())) { + stride_a = strides_[node->a.get()]; + } else { + stride_a = INT64_MAX; + } + if (strides_.count(node->b.get())) { + stride_b = strides_[node->b.get()]; + } else { + stride_b = INT64_MAX; + } + if (stride_a != INT64_MAX || stride_b != INT64_MAX) { + strides_[node] = std::min(stride_a, stride_b); + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_.get()) { + strides_[node] = 1; + } + } + + const Var& var_; + std::unordered_map strides_; +}; + +struct ParsedAnnotation { + int max_parallel_extent; + int max_vectorize_extent; + int unroll_explicit; + int unroll_implicit; + int num_parallel_loops; + int num_vectorize_loops; +}; + +bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { + bool found = false; + *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; + for (const auto& ann : block->annotations) { + if (ann.first == attr::meta_schedule_parallel) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_parallel_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_vectorize) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_vectorize_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_explicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_explicit = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_implicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_implicit = imm->value; + } + } + } + return found; +} + +void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { + if (parsed.max_parallel_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_parallel); + } + if (parsed.max_vectorize_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + } + if (parsed.unroll_explicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + } + if (parsed.unroll_implicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + } +} + +void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, + const Array& loop_rvs, ParsedAnnotation* parsed) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { + return; + } + int n_loops = loop_rvs.size(); + if (n_loops == 0) { + parsed->max_parallel_extent = -1; + parsed->max_vectorize_extent = -1; + return; + } + // Extract loop_srefs, and calculate the iterator types + Array loop_srefs; + std::vector loop_types; + { + loop_srefs.reserve(n_loops); + loop_types.reserve(n_loops); + for (const LoopRV& loop_rv : loop_rvs) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + loop_types.push_back(GetLoopIterType(loop_srefs.back())); + } + } + // check the maximal number of axes that are vectorizable (contiguous memory access) + BlockRealize realize = GetBlockRealize(sch->state(), block_sref); + Array buffer_access(realize->block->reads); + buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), + realize->block->writes.end()); + std::unordered_map binding_map; + for (size_t i = 0; i < realize->iter_values.size(); i++) { + binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; + } + int max_fusible = INT32_MAX; + // for each block read/write, get the strides of the loop vars and find the fusible + // (vectorizable) axes + for (const BufferRegion& access : buffer_access) { + int fusible = 0; + std::vector strides; + // get strides for each loop var + for (const StmtSRef& loop_sref : loop_srefs) { + int64_t stride = 0, buffer_stride = 1; + const auto* var = loop_sref->StmtAs(); + arith::Analyzer analyzer; + for (int i = access->region.size() - 1; i >= 0; i--) { + PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + int64_t coef = StrideExtractor::Extract(idx, var->loop_var); + if (coef != 0) { + stride = coef * buffer_stride; + break; + } + buffer_stride *= access->buffer->shape[i].as()->value; + } + strides.push_back(stride); + } + int prev_used_iter = -1; + // check the number of fusible loops + for (int i = strides.size() - 1; i >= 0; i--) { + if (strides[i] == 0) { + // not used in the buffer access, safe to fuse + fusible++; + continue; + } else if (prev_used_iter == -1) { + // the stride of last axis is not 1 means the memory access is not contiguous + if (strides[i] != 1) { + break; + } + fusible++; + prev_used_iter = i; + } else { + // contiguous memory access + const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); + int64_t prev_used_iter_extent = prev_loop->extent.as()->value; + if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + fusible++; + prev_used_iter = i; + } else { + break; + } + } + } + max_fusible = std::min(max_fusible, fusible); + } + // Calculate the parallelize extent + if (parsed->max_parallel_extent != -1) { + int max_extent = parsed->max_parallel_extent; + int& num_fusible = parsed->num_parallel_loops = 0; + int64_t prod_extent = 1; + for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Then we can fuse it in + ++num_fusible; + // Check if we need to break + prod_extent *= *extent; + if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { + break; + } + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Calculate the vectorize extent + if (parsed->max_vectorize_extent != -1) { + int max_extent = parsed->max_vectorize_extent; + int& num_fusible = parsed->num_vectorize_loops = 0; + int64_t prod_extent = 1; + for (int i = n_loops - 1; + i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Cannot vectorize reduce axis + if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { + break; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Check if the extent is still in a good range + prod_extent *= *extent; + if (prod_extent > max_extent) { + break; + } + ++num_fusible; + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { + parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // + n_loops - parsed->num_vectorize_loops); + } +} + +bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + Block block = Downcast(prim_func->body)->block; + if (ParseAnnotation(block, parsed)) { + *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + RemoveParsedAnn(sch, *root_rv, *parsed); + return true; + } + } + } + return false; +} + +void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { + ICHECK_LE(n, loop_rvs->size()); + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); + sch->Parallel(fused); + for (size_t i = 0; i < n; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteVectorize(const Schedule& sch, size_t n, Array* loop_rvs) { + size_t n_loops = loop_rvs->size(); + ICHECK_LE(n, n_loops); + LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); + sch->Vectorize(fused); + for (size_t i = n_loops - n; i < n_loops; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { + if (max_step > 0) { + sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + } +} + +} // namespace tir + +namespace meta_schedule { + +using tir::Schedule; + +class RewriteParallelVectorizeUnrollNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const Schedule& sch) final { + tir::ParsedAnnotation parsed_root; + tir::BlockRV root_rv{nullptr}; + while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { + for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + tir::ParsedAnnotation parsed = parsed_root; + tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); +}; + +Postproc Postproc::RewriteParallelVectorizeUnroll() { + ObjectPtr n = + make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") + .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..9988e874b81d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Move_PUV: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1) + T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + +@T.prim_func +def Move_PUV0(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + for i0_j0_fused in T.parallel(0, 8192): + for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): + for k1_fused in T.vectorized(0, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1_fused) + T.where( + i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024 + and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024 + and k0 * 32 + k1_fused < 1024 + ) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): + postproc = RewriteParallelVectorizeUnroll() + sch = Schedule(Move_PUV) + assert postproc.apply(sch) + print(sch.mod["main"].script()) + mod = tvm.tir.transform.Simplify()(sch.mod) + tvm.ir.assert_structural_equal(mod["main"], Move_PUV0) + + +if __name__ == "__main__": + test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize() From 04aa0717c4603506b3b2c9288457b5163d81e3d4 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 12:57:30 -0800 Subject: [PATCH 17/49] add pytest condition to pass CI. rename test name to be consistent. --- ...hedule_byoc_trt.py => test_meta_schedule_byoc_tensorrt.py} | 4 ++++ 1 file changed, 4 insertions(+) rename tests/python/unittest/{test_meta_schedule_byoc_trt.py => test_meta_schedule_byoc_tensorrt.py} (99%) diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py similarity index 99% rename from tests/python/unittest/test_meta_schedule_byoc_trt.py rename to tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index ca38a7d118a7..9204e151344a 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_trt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -52,6 +52,7 @@ not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" ) + # conv2d+relu network def get_conv2d_relu( data_shape, @@ -224,6 +225,9 @@ def test_conv2d_relu(): verify_meta_schedule_with_tensorrt(mod, params, data_shape) +@tvm.testing.requires_cuda +@has_tensorrt_codegen +@has_tensorrt_runtime @pytest.mark.parametrize( "model_name", ["resnet-50", "mobilenet"], From 394ec0177745e358951f22ac2a913680d93860db Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Thu, 27 Jan 2022 01:00:32 +0300 Subject: [PATCH 18/49] Clear warnings when building with MSVC. (#10059) * Fix warning "unsafe mix of type 'const int64_t' and type 'bool' in operation" occurring in tvm::tir::HasAnn * Suppress warning "destructor never returns, potential memory leak" occurring in tvm::runtime::detail::LogFatal::~LogFatal --- include/tvm/ir/expr.h | 6 ++++++ include/tvm/runtime/logging.h | 7 +++++++ src/tir/schedule/utils.h | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c7d0f58e3d9f..0e43abb54b93 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -374,6 +374,12 @@ inline Bool operator&&(const Bool& a, const Bool& b) { return Bool(a.operator bool() && b.operator bool()); } +inline bool operator==(const Bool& a, bool b) { return a.operator bool() == b; } +inline bool operator==(bool a, const Bool& b) { return a == b.operator bool(); } +inline bool operator==(const Bool& a, const Bool& b) { + return a.operator bool() == b.operator bool(); +} + /*! * \brief Container of constant int that adds more constructors. * diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index a951264b9706..25e70289118c 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -302,7 +302,14 @@ TVM_DLL void LogMessageImpl(const std::string& file, int lineno, const std::stri class LogFatal { public: LogFatal(const std::string& file, int lineno) : file_(file), lineno_(lineno) {} +#ifdef _MSC_VER +#pragma disagnostic push +#pragma warning(disable : 4722) +#endif ~LogFatal() TVM_THROW_EXCEPTION { LogFatalImpl(file_, lineno_, stream_.str()); } +#ifdef _MSC_VER +#pragma disagnostic pop +#endif std::ostringstream& stream() { return stream_; } private: diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 673813b0f140..2de8ef6e0c93 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -374,7 +374,7 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& an */ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { Optional result = GetAnn(sref, ann_key); - return result.defined() && result.value()->value == ann_val; + return result.defined() && result.value() == ann_val; } /********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ From 478896ba223e2cdf6bb94ee3663d4308d5690e56 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 26 Jan 2022 16:00:48 -0600 Subject: [PATCH 19/49] [Makefile] Fixed error in "make clean" (#10048) The top-level makefile should delegate `make clean` to the cmake folder of each enabled build, similar to the existing delegation of `make all` and `make runtime`. --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index eba5d5710fdd..05bd7245f830 100644 --- a/Makefile +++ b/Makefile @@ -78,7 +78,7 @@ FORCE: # Since the pattern stem is already being used for the directory name, # cannot also have it refer to the command passed to cmake. # Therefore, explicitly listing out the delegated. -CMAKE_TARGETS = all runtime vta cpptest crttest +CMAKE_TARGETS = all runtime vta cpptest crttest clean define GEN_CMAKE_RULE %/$(CMAKE_TARGET): %/CMakeCache.txt FORCE @@ -174,7 +174,7 @@ jvminstall: -Dcurrent_libdir="$(TVM_BUILD_PATH)" $(JVM_TEST_ARGS)) # Final cleanup rules, delegate to more specific rules. -clean: cmake_clean cyclean webclean +clean: $(addsuffix /clean,$(TVM_BUILD_PATH)) cyclean webclean docs: python3 tests/scripts/ci.py docs From 5bee3c0349b14d78e71b966c3df259e28e2d7349 Mon Sep 17 00:00:00 2001 From: yuanfz <42092999+FZYUAN-1@users.noreply.github.com> Date: Wed, 26 Jan 2022 23:02:58 +0100 Subject: [PATCH 20/49] [Relay] QLinearMatMul allows 1D weight_scale, weight_zero_point inputs (#10047) * fix after cr * fix after cr 2 * emptycommit * emptycommit 2nd try --- python/tvm/relay/frontend/onnx.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 234beec244ba..aee4bb3e060e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3804,18 +3804,16 @@ def _impl_v10(cls, inputs, attr, params): # # This function attempts to present 'x' in a form that meets both of those # requirements. - def try_resolve_to_const_scalar(x, dtype_override=None): + def try_resolve_to_const(x, dtype_override=None): x2 = try_resolve_var_to_const(x, params) - x3 = ensure_scalar_shape(x2) - + num_elem = np.prod(infer_shape(x)) + if num_elem == 1: + x2 = ensure_scalar_shape(x2) x_dtype = infer_type(x).checked_type.dtype if (dtype_override is not None) and (dtype_override != x_dtype): - x4 = _op.cast(x3, dtype_override) - else: - x4 = x3 - - x5 = fold_constant(x4) - return x5 + x2 = _op.cast(x2, dtype_override) + x3 = fold_constant(x2) + return x3 # Unpack the inputs and obtain some type info... a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs @@ -3855,14 +3853,14 @@ def try_resolve_to_const_scalar(x, dtype_override=None): ) # _qnn.op.dense requires the zero-point values to have dtype int32. - a_scale_scalar = try_resolve_to_const_scalar(a_scale) - a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32") + a_scale_scalar = try_resolve_to_const(a_scale) + a_zp_scalar = try_resolve_to_const(a_zp, "int32") - b_scale_scalar = try_resolve_to_const_scalar(b_scale) - b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32") + b_scale_scalar = try_resolve_to_const(b_scale) + b_zp_scalar = try_resolve_to_const(b_zp, "int32") - y_scale_scalar = try_resolve_to_const_scalar(y_scale) - y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32") + y_scale_scalar = try_resolve_to_const(y_scale) + y_zp_scalar = try_resolve_to_const(y_zp, "int32") # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with # the '_qnn.op.dense' instance below. From 87dd34962da0b13d4c9abf7fb4f5232efd902246 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 26 Jan 2022 17:03:21 -0600 Subject: [PATCH 21/49] Don't explicitly link libgcc.a into libtvm_runtime.so on Android (#10052) Setting Android toolchain via CMAKE_TOOLCHAIN_FILE also causes necessary flags to be added. Also, newer versions of the Android NDK no longer ship libgcc.a, so this takes care of that as well. --- CMakeLists.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b60258206bb8..01d667f7545e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,12 +241,6 @@ if(NOT BUILD_FOR_HEXAGON) list(APPEND TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) endif() -if(BUILD_FOR_ANDROID) - # EmuTLS on Android is in libgcc. Without it linked in, libtvm_runtime.so - # won't load on Android due to missing __emutls_XXX symbols. - list(APPEND TVM_RUNTIME_LINKER_LIBS "gcc") -endif() - # add source group tvm_file_glob(GLOB_RECURSE GROUP_SOURCE "src/*.cc") tvm_file_glob(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h") From c09d1c70a829686332fa3b17a4a9563afa1f3727 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 26 Jan 2022 15:47:45 -0800 Subject: [PATCH 22/49] Change function constructors to WithFields (#9690) * Change function constructors to WithFields Get rid of std::moves, they were causing problems * Fix bad rebase * flaky * try to trigger ci * try again --- python/tvm/ir/json_compact.py | 1 + .../contrib/cmsisnn/extract_constants.cc | 7 +++--- .../backend/contrib/cmsisnn/relay_to_tir.cc | 9 ++----- src/relay/backend/contrib/ethosu/codegen.cc | 8 ++----- .../example_target_hooks/relay_to_tir.cc | 9 ++----- src/relay/backend/te_compiler.cc | 12 ++++++++-- src/relay/backend/vm/lambda_lift.cc | 6 ++--- src/relay/ir/expr.cc | 1 + src/relay/quantize/annotate.cc | 2 +- src/relay/quantize/calibrate.cc | 9 +++++-- src/relay/transforms/annotate_target.cc | 2 +- src/relay/transforms/convert_sparse_conv2d.cc | 4 ++-- src/relay/transforms/convert_sparse_dense.cc | 4 ++-- src/relay/transforms/de_duplicate.cc | 9 +++---- src/relay/transforms/defunctionalization.cc | 7 +++--- src/relay/transforms/eta_expand.cc | 3 +-- src/relay/transforms/first_order_gradient.cc | 5 ++-- src/relay/transforms/higher_order_gradient.cc | 19 ++++++++------- src/relay/transforms/inline.cc | 5 ++-- src/relay/transforms/partial_eval.cc | 24 +++++++++---------- src/relay/transforms/partition_graph.cc | 19 +++++++-------- src/relay/transforms/pass_utils.h | 2 +- src/relay/transforms/simplify_fc_transpose.cc | 4 ++-- src/relay/transforms/to_a_normal_form.cc | 4 ++-- src/relay/transforms/to_cps.cc | 13 +++++----- 25 files changed, 94 insertions(+), 94 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9666475b8039..a6bcc28dad43 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -86,6 +86,7 @@ def _initialize_virtual_device(item, _): "relay.RefRead": _initialize_virtual_device, "relay.RefWrite": _initialize_virtual_device, "relay.Match": _initialize_virtual_device, + "relay.Constant": _initialize_virtual_device, } return create_updater(node_map, "0.8", "0.9") diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 61f215a7d88c..9b724034ccf2 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -67,8 +67,8 @@ class ExtractConstantsMutator : public MixedModeMutator { auto new_body = VisitExpr(func->body); functions_.pop_back(); if (function_to_constants_[func].size()) { - func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), - func->attrs); + func = WithFields(func, FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); } return std::move(func); } @@ -159,8 +159,7 @@ IRModule ExtractConstants(const IRModule& mod) { auto new_main_body = extract_constants.VisitExpr(main_func->body); if (!new_main_body.same_as(main_func->body)) { auto main_var = mod->GetGlobalVar("main"); - auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, - main_func->type_params, main_func->attrs); + Function new_main_func = WithFields(main_func, main_func->params, new_main_body); mod->Update(main_var, new_main_func); } return mod; diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index b8744247e9a6..f366e4ab2635 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -46,13 +46,8 @@ class RelayToTIRVisitor : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index d618a4971189..0fdbb7063e3f 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator { IRModule operator()() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - Function main_func = Downcast(ir_module_->Lookup(main_global_var)); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 89b325f51a0c..6794594b5ba4 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = GetRef(ir_module_->Lookup(main_global_var).as()); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 3ff6076473f1..3000ef9640f3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -100,6 +100,7 @@ class TECompilerImpl : public TECompilerNode { } IRModule GetLoweredFunctions() { + VLOG(1) << "GetLoweredFunctions"; IRModule mod; // Extract lowered functions from the cache for (const auto& it : cache_) { @@ -164,8 +165,15 @@ class TECompilerImpl : public TECompilerNode { for (const auto& kv2 : kv1.second->cached_func->funcs->functions) { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - Function function(function_node->params, function_node->body, function_node->ret_type, - function_node->type_params, /*attrs=*/{}, function_node->span); + + // Unfortuantely, Optional() is indistinguishable from + // NullValue(), and DictAttrs() is nullptr, so to erase the attributes, we + // need pass in DictAttrs()), which is a DictAttrs containing no + // attributes. + Function function = + WithFields(GetRef(function_node), function_node->params, + function_node->body, function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); // Mark function as 'extern' using the "ExternalSymbol" attribute. function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); module->Add(kv2.first, function); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 0457459b3847..f2bd9e6b9a8a 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -102,8 +102,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { if (function_nesting() == 1) { // We don't need to lift global functions. - return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type, - func_node->type_params, func_node->attrs, func_node->span); + return WithFields(GetRef(func_node), func_node->params, VisitExpr(func_node->body)); } auto name = GenerateName(func); @@ -188,8 +187,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. size_t before_arity = body->params.size(); - auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, - func->type_params, func->attrs, func->span); + auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); lifted_func = diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 73ae3faf7078..fc76577bd7c0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -45,6 +45,7 @@ using namespace tvm::runtime; Constant::Constant(runtime::NDArray data, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 3def616e9423..c704bcbc466b 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -98,7 +98,7 @@ Pass QuantizeAnnotate() { for (const auto& x : FreeVars(func)) { new_params.push_back(x); } - return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + return WithFields(func, new_params); }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 0ac445295496..21ed61187c38 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -152,8 +152,13 @@ class StatsCollector : private ExprMutator { const FunctionNode* func = new_e.as(); ICHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); - return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + Function ret_func = WithFields(GetRef(func), FreeVars(new_body), new_body); + + // We are changing the function's ret_type to an empty type. Unfortunately, Optional() is + // indistinguishable from NullValue(), so we can't express "update to nullptr" in + // WithFields. + ret_func.CopyOnWrite()->ret_type = NullValue(); + return ret_func; } private: diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 6e4ab88ea326..3f1985b7ddfa 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -295,7 +295,7 @@ class AnnotateTargetRewriter : public ExprRewriter { func = Downcast(post); new_body = InsertCompilerEndAndPropogateTarget(func->body); } - return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); + return WithFields(func, func->params, new_body); } Expr Rewrite_(const LetNode* op, const Expr& post) override { diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 3f2c25e988f9..f2af290f3e22 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -292,12 +292,12 @@ Pass Conv2dToSparse(const Array& weight_name, const Array(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, sparse_params); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 26a4d487196d..faba366eca49 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -135,12 +135,12 @@ Pass DenseToSparse(const Array& weight_name, // Remove FreeVar warnings auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, sparse_params); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index b3e88376abcb..23e147d5d4c4 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -82,16 +82,17 @@ Expr DeDup(const Expr& e) { Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } - Expr VisitExpr_(const FunctionNode* op) final { + Expr VisitExpr_(const FunctionNode* func_node) final { tvm::Array type_params; - for (const TypeVar& type_param : op->type_params) { + for (const TypeVar& type_param : func_node->type_params) { type_params.push_back(Fresh(type_param)); } tvm::Array params; - for (const Var& param : op->params) { + for (const Var& param : func_node->params) { params.push_back(Fresh(param)); } - return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); + return WithFields(GetRef(func_node), params, VisitExpr(func_node->body), + VisitType(func_node->ret_type), type_params); } Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 5255a672a856..38e403a8d9b0 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -283,7 +283,7 @@ class DefuncMutator : public ExprMutator { auto apply_gv = GetApplyFunction(ft); auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); - AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + AddApplyCase(apply_gv, ft, c, WithFields(GetRef(fn), fn->params, body), pattern_vars); return Call(c, call_args); @@ -380,7 +380,7 @@ class DefuncMutator : public ExprMutator { map.Set(f->type_params[i], type_args[i]); } // copy with typevars removed - auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); + auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array()), map); return Downcast(copy); } @@ -410,7 +410,8 @@ class DefuncMutator : public ExprMutator { } auto bind = Downcast(Bind(f, var_bind_map)); - return Function(params, this->VisitExpr(bind->body), bind->ret_type, {}); + return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type, + /* erase type params */ Array()); } }; diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index 4023c9dafef4..40b0a54ba38c 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -129,8 +129,7 @@ class EtaExpander : public ExprMutator { params.push_back(var); args.push_back(var); } - - return Function(args, Call(gvar, params), func->ret_type, func->type_params); + return WithFields(func, args, Call(gvar, params)); } else { return std::move(gvar); } diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index d695c6dc491d..f530d61e0d99 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -307,8 +307,9 @@ Pass FirstOrderGradient() { }); return Pair(res.forward, grad_tuple); }); - ad_mod->Update(pr.first, - Function(func->params, body, GradRetType(GetRef(func)), {})); + ad_mod->Update(pr.first, WithFields(GetRef(func), func->params, body, + GradRetType(GetRef(func)), + /* erase type params */ Array())); } return ad_mod; diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 202275626d5d..1cf7cb86692c 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -341,28 +341,28 @@ struct ReverseAD : ExprMutator { GlobalVar gv(op->name_hint + "_grad"); (*ad_gvars)[orig_gv] = gv; Function orig_f = Downcast(DeDup(mod.value()->Lookup(orig_gv))); - std::vector params; + Array params; for (const auto& p : orig_f->params) { params.push_back(Downcast(VisitExpr(p))); } params.push_back(bp); - Expr body = VisitExpr(orig_f->body); - Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs); + Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type)); std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; mod.value()->Add(gv, f); } return ad_gvars->at(orig_gv); } - Expr VisitExpr_(const FunctionNode* op) final { - std::vector params; - for (const auto& var : op->params) { + Expr VisitExpr_(const FunctionNode* func_node) final { + Array params; + for (const auto& var : func_node->params) { params.push_back(Downcast(VisitExpr(var))); } auto new_bp = Var("bp", bpt); params.push_back(new_bp); - return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), - VisitType(op->ret_type), op->type_params, op->attrs); + return WithFields(GetRef(func_node), params, + ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), + VisitType(func_node->ret_type)); } Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } @@ -456,7 +456,8 @@ Expr Gradient(const Expr& re, const Optional& mod) { }; return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); - auto ret = Function(f->params, body, GradRetType(GetRef(f)), {}); + Function ret = WithFields(GetRef(f), f->params, body, GradRetType(GetRef(f)), + /* erase type params */ Array()); CheckFeature(ret, FeatureSet::All() - fGraph); return std::move(ret); } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index f1492b9f1258..a6e26364bbc4 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -91,8 +91,7 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + return WithFields(func, func->params, VisitExpr(func->body)); } private: @@ -131,6 +130,8 @@ class Inliner : ExprMutator { const auto* fn = base_func.as(); ICHECK(fn) << "Expected to work on a Relay function."; + // There is an inconsistency here, the function itself gets shallow-copied but the body is not + // shallow-copied. auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 28d1aa5532bf..fc9922ca03ef 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -827,18 +827,18 @@ class PartialEvaluator : public ExprFunctor Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), - func->ret_type, func->type_params, func->attrs); + return WithFields( + func, func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + })); }); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index d1b9b563e932..bc1ed518d473 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -213,9 +213,8 @@ class Partitioner : public MixedModeMutator { auto glob_funcs = module_->functions; for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + Function func = GetRef(fn); + func = WithFields(func, func->params, VisitExpr(func->body)); module_->Update(pair.first, func); module_ = transform::InferType()(module_); } @@ -429,7 +428,7 @@ IRModule RemoveDefaultAnnotations(IRModule module) { auto func = GetRef(fn); DefaultRemover remover; auto removed = PostOrderRewrite(func->body, &remover); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + func = WithFields(func, func->params, removed); module->Update(pair.first, func); module = relay::transform::InferType()(module); } @@ -482,10 +481,10 @@ IRModule FlattenTupleOutputs(IRModule module) { module.CopyOnWrite(); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + Function func = GetRef(fn); TupleOutFlattener to_flattener; auto removed = PostOrderRewrite(func->body, &to_flattener); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + func = WithFields(func, func->params, removed); module->Update(pair.first, func); module = relay::transform::InferType()(module); } @@ -527,12 +526,12 @@ class NameMangleExtFuncs : public MixedModeMutator { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - DictAttrs(new_dict)); + func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, + func->type_params, DictAttrs(new_dict)); + new_module->Add(mangled_gvars_[pair.first->name_hint], func); } else { - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + func = WithFields(func, func->params, VisitExpr(func->body)); new_module->Add(pair.first, func); } } diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index 317ac17f83c8..b14a93f02b55 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -106,7 +106,7 @@ bool IsDataDependent(const CallNode* call); */ inline Expr TransformF(const std::function& func, const Expr& e) { if (const FunctionNode* f = e.as()) { - return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs); + return WithFields(GetRef(f), f->params, func(f->body)); } else { return func(e); } diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc index b5090e7e6fe4..ad38ea6cb8df 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -128,12 +128,12 @@ Pass SimplifyFCTranspose(const Array& target_weights) { // Remove FreeVar warning auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); Array wt_params = FreeVars(f0); - auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, wt_params); Array params = FreeVars(f1); for (const auto& var : wt_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index f6d5ac9cf8bb..a0841ec44fae 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -298,8 +298,8 @@ class Fill : ExprFunctor, private transform::Lexi PushBoundVar(f->params[i], GetFunctionParamVirtualDevice(f, i)); } EnterFunctionBody(); - ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); + ret = WithFields(GetRef(f), f->params, + GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); // We are done with this function. ExitFunctionBody(); for (size_t i = 0; i < f->params.size(); ++i) { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index c5d17fbfbef7..6d8fe67847f6 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -272,8 +272,8 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), - answer, f->type_params, f->attrs); + return WithFields(f, new_params, + mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), answer); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { @@ -299,7 +299,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { Function ret = ToCPS(f, m, cm, &var, answer); auto new_type_params = ret->type_params; new_type_params.push_back(answer); - return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); + return WithFields(ret, ret->params, ret->body, ret->ret_type, new_type_params); } Function ToCPS(const Function& f, const IRModule& m) { @@ -311,7 +311,7 @@ Function ToCPS(const Function& f, const IRModule& m) { Function UnCPS(const Function& f) { CheckFeature(f, FeatureSet::All() - fGraph); ICHECK_GT(f->params.size(), 0); - std::vector new_params; + Array new_params; for (const auto& p : f->params) { new_params.push_back(Var(p->name_hint(), p->checked_type())); } @@ -319,7 +319,7 @@ Function UnCPS(const Function& f) { new_params.pop_back(); ICHECK_EQ(cont_type->arg_types.size(), 1); auto new_ret_type = Type(cont_type->arg_types[0]); - std::vector new_type_params; + Array new_type_params; for (const auto& tp : f->type_params) { new_type_params.push_back(TypeVar(tp->name_hint, tp->kind)); } @@ -339,8 +339,7 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, - f->attrs); + return WithFields(f, new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") From e2ad90761734a5f5862ecec2b11e9ef6034f4be1 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 19:25:28 -0800 Subject: [PATCH 23/49] add pyteset decorator to pass CI --- tests/python/unittest/test_meta_schedule_byoc_tensorrt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 9204e151344a..24b6094af97c 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -201,6 +201,9 @@ def eval_func(rt_mod, device, evaluator_config, repeated_args): ).evaluate() +@tvm.testing.requires_cuda +@has_tensorrt_codegen +@has_tensorrt_runtime def test_conv2d_relu(): data_shape = (1, 1280, 14, 14) out_channels = 256 From c6ae12480f648905fe7d3ecba5416a5c88a23be8 Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Thu, 27 Jan 2022 04:01:36 +0000 Subject: [PATCH 24/49] Document missing qnn operators (#10077) The following qnn operators were missing from the relay documentation. --- docs/reference/langref/relay_op.rst | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/docs/reference/langref/relay_op.rst b/docs/reference/langref/relay_op.rst index 3e797fc93b31..8788eb52ae0d 100644 --- a/docs/reference/langref/relay_op.rst +++ b/docs/reference/langref/relay_op.rst @@ -231,5 +231,18 @@ This level supports dialect operators. .. autosummary:: :nosignatures: - tvm.relay.qnn.op.requantize + tvm.relay.qnn.op.add + tvm.relay.qnn.op.batch_matmul + tvm.relay.qnn.op.concatenate tvm.relay.qnn.op.conv2d + tvm.relay.qnn.op.conv2d_transpose + tvm.relay.qnn.op.dense + tvm.relay.qnn.op.dequantize + tvm.relay.qnn.op.mul + tvm.relay.qnn.op.quantize + tvm.relay.qnn.op.requantize + tvm.relay.qnn.op.rsqrt + tvm.relay.qnn.op.simulated_dequantize + tvm.relay.qnn.op.simulated_quantize + tvm.relay.qnn.op.subtract + tvm.relay.qnn.op.transpose_conv2d From 8735349e4a12062009898f31fe14420058394817 Mon Sep 17 00:00:00 2001 From: Jacob Bohlin Date: Thu, 27 Jan 2022 09:16:53 +0000 Subject: [PATCH 25/49] Add temp git dir to test_cc_reviewers test case (#10058) This decouples the test_cc_reviewers test case from the user's git configuration. The implementation reuses the TempGit structure from test_skip_ci to always use a fresh git environment. --- tests/python/unittest/test_ci.py | 61 +++++++++++++++++--------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/tests/python/unittest/test_ci.py b/tests/python/unittest/test_ci.py index 0c80617985ee..dfd1fd5cde17 100644 --- a/tests/python/unittest/test_ci.py +++ b/tests/python/unittest/test_ci.py @@ -26,16 +26,31 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent -def test_cc_reviewers(): +class TempGit: + def __init__(self, cwd): + self.cwd = cwd + + def run(self, *args): + proc = subprocess.run(["git"] + list(args), cwd=self.cwd) + if proc.returncode != 0: + raise RuntimeError(f"git command failed: '{args}'") + + +def test_cc_reviewers(tmpdir_factory): reviewers_script = REPO_ROOT / "tests" / "scripts" / "github_cc_reviewers.py" def run(pr_body, expected_reviewers): + git = TempGit(tmpdir_factory.mktemp("tmp_git_dir")) + git.run("init") + git.run("checkout", "-b", "main") + git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") proc = subprocess.run( [str(reviewers_script), "--dry-run"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"PR": json.dumps({"number": 1, "body": pr_body})}, encoding="utf-8", + cwd=git.cwd, ) if proc.returncode != 0: raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}") @@ -53,36 +68,26 @@ def run(pr_body, expected_reviewers): ) -def test_skip_ci(): +def test_skip_ci(tmpdir_factory): skip_ci_script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci.py" - class TempGit: - def __init__(self, cwd): - self.cwd = cwd - - def run(self, *args): - proc = subprocess.run(["git"] + list(args), cwd=self.cwd) - if proc.returncode != 0: - raise RuntimeError(f"git command failed: '{args}'") - def test(commands, should_skip, pr_title, why): - with tempfile.TemporaryDirectory() as dir: - git = TempGit(dir) - # Jenkins git is too old and doesn't have 'git init --initial-branch' - git.run("init") - git.run("checkout", "-b", "main") - git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") - git.run("config", "user.name", "ci") - git.run("config", "user.email", "email@example.com") - git.run("commit", "--allow-empty", "--message", "base commit") - for command in commands: - git.run(*command) - pr_number = "1234" - proc = subprocess.run( - [str(skip_ci_script), "--pr", pr_number, "--pr-title", pr_title], cwd=dir - ) - expected = 0 if should_skip else 1 - assert proc.returncode == expected, why + git = TempGit(tmpdir_factory.mktemp("tmp_git_dir")) + # Jenkins git is too old and doesn't have 'git init --initial-branch' + git.run("init") + git.run("checkout", "-b", "main") + git.run("remote", "add", "origin", "https://github.com/apache/tvm.git") + git.run("config", "user.name", "ci") + git.run("config", "user.email", "email@example.com") + git.run("commit", "--allow-empty", "--message", "base commit") + for command in commands: + git.run(*command) + pr_number = "1234" + proc = subprocess.run( + [str(skip_ci_script), "--pr", pr_number, "--pr-title", pr_title], cwd=git.cwd # dir + ) + expected = 0 if should_skip else 1 + assert proc.returncode == expected, why test( commands=[], From f931fe1f589490f2ffe8e41384049350447a1f8b Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Thu, 27 Jan 2022 09:58:09 +0000 Subject: [PATCH 26/49] [CI] Fix Rust permissions for wasmtime and sccache (#10015) Previously this was ran as part of `ubuntu_install_rust.sh`, as we now have multiple scripts which write as the container build user we have to fix up each time to ensure future users don't break. --- docker/install/ubuntu_install_rust.sh | 2 +- docker/install/ubuntu_install_sccache.sh | 3 +++ docker/install/ubuntu_install_wasmtime.sh | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh index 58f8256b03b3..d16b3dc5e71e 100755 --- a/docker/install/ubuntu_install_rust.sh +++ b/docker/install/ubuntu_install_rust.sh @@ -26,5 +26,5 @@ export PATH=$CARGO_HOME/bin:$PATH rustup component add rustfmt rustup component add clippy -# make rust usable by all users +# make rust usable by all users after install during container build chmod -R a+w /opt/rust diff --git a/docker/install/ubuntu_install_sccache.sh b/docker/install/ubuntu_install_sccache.sh index 79ce1535c71e..e4a458a4288e 100644 --- a/docker/install/ubuntu_install_sccache.sh +++ b/docker/install/ubuntu_install_sccache.sh @@ -26,3 +26,6 @@ cargo install sccache mkdir /opt/sccache ln "$(which sccache)" /opt/sccache/cc ln "$(which sccache)" /opt/sccache/c++ + +# make rust usable by all users after install during container build +chmod -R a+w /opt/rust diff --git a/docker/install/ubuntu_install_wasmtime.sh b/docker/install/ubuntu_install_wasmtime.sh index d1285b36b429..d85cb49b0a6e 100644 --- a/docker/install/ubuntu_install_wasmtime.sh +++ b/docker/install/ubuntu_install_wasmtime.sh @@ -24,3 +24,6 @@ export WASMTIME_HOME=/opt/wasmtime curl https://wasmtime.dev/install.sh -sSf | bash export PATH="${WASMTIME_HOME}/bin:${PATH}" rustup target add wasm32-wasi + +# make rust usable by all users after install during container build +chmod -R a+w /opt/rust From 27901564202ddbd9b21263aed7ba201f38886b89 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Thu, 27 Jan 2022 03:12:42 -0800 Subject: [PATCH 27/49] [EZ][Typo] Correct gather, scatter type rel error message (#10023) --- src/ir/error.cc | 4 ++-- src/relay/op/tensor/transform.cc | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/ir/error.cc b/src/ir/error.cc index 0089f55a4da8..f0e78b954a41 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -64,13 +64,13 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { ICHECK(has_errs != this->node_to_error_.end()); - const auto& error_indicies = has_errs->second; + const auto& error_indices = has_errs->second; std::stringstream err_msg; err_msg << rang::fg::red; err_msg << " "; - for (auto index : error_indicies) { + for (auto index : error_indices) { err_msg << this->errors_[index].what() << "; "; } err_msg << rang::fg::reset; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9e469f373131..25d836d51160 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1105,7 +1105,7 @@ bool ScatterRel(const Array& types, int num_inputs, const Attrs& attrs, if (updates == nullptr) { return false; } - ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + ICHECK(indices->dtype.is_int()) << "indices of scatter must be tensor of integer"; const auto param = attrs.as(); ICHECK(param != nullptr); reporter->Assign(types[3], TensorType(data->shape, data->dtype)); @@ -1125,7 +1125,7 @@ RELAY_REGISTER_OP("scatter") R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE) .set_num_inputs(3) .add_argument("data", "Tensor", "The input data tensor.") - .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("indices", "Tensor", "The indices location tensor.") .add_argument("updates", "Tensor", "The values to update the input with.") .add_type_rel("Scatter", ScatterRel) .set_attr("TOpIsStateful", false) @@ -1172,7 +1172,7 @@ RELAY_REGISTER_OP("scatter_add") R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE) .set_num_inputs(3) .add_argument("data", "Tensor", "The input data tensor.") - .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("indices", "Tensor", "The indices location tensor.") .add_argument("updates", "Tensor", "The values to update the input with.") .add_type_rel("ScatterAdd", ScatterAddRel) .set_attr("TOpIsStateful", false) @@ -3318,7 +3318,8 @@ bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, << "Gather: expect indices type to be TensorType but get " << types[1]; return false; } - ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) + << "indices of gather must be tensor of integer"; const auto param = attrs.as(); ICHECK(param != nullptr); ICHECK(param->axis.defined()); From 0740bc3890aaf58ff376c91e7714d65ff41cba79 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Thu, 27 Jan 2022 09:10:56 -0800 Subject: [PATCH 28/49] [microTVM][tvmc] Add TVMC Micro tutorial for Zephyr (#10024) --- docs/conf.py | 1 + .../work_with_microtvm/micro_autotune.py | 7 +- .../how_to/work_with_microtvm/micro_tvmc.py | 198 ++++++++++++++++++ tests/micro/common/test_tvmc.py | 2 +- 4 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 gallery/how_to/work_with_microtvm/micro_tvmc.py diff --git a/docs/conf.py b/docs/conf.py index 2f650a88c936..702821c470f4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -316,6 +316,7 @@ def git_describe_version(original_version): "micro_reference_vm.py", "micro_tflite.py", "micro_ethosu.py", + "micro_tvmc.py", ], } diff --git a/gallery/how_to/work_with_microtvm/micro_autotune.py b/gallery/how_to/work_with_microtvm/micro_autotune.py index 394a946cf3d5..c6516563fac3 100644 --- a/gallery/how_to/work_with_microtvm/micro_autotune.py +++ b/gallery/how_to/work_with_microtvm/micro_autotune.py @@ -18,7 +18,7 @@ """ .. _tutorial-micro-autotune: -Autotuning with micro TVM +Autotuning with microTVM ========================= **Authors**: `Andrew Reusch `_, @@ -28,11 +28,10 @@ """ import numpy as np -import subprocess import pathlib import tvm -from tvm.relay.backend import Executor, Runtime +from tvm.relay.backend import Runtime #################### # Defining the model @@ -67,7 +66,7 @@ params = {"weight": weight_sample} ####################### -# Defining the target # +# Defining the target ####################### # Now we define the TVM target that describes the execution environment. This looks very similar # to target definitions from other microTVM tutorials. Alongside this we pick the C Runtime to code diff --git a/gallery/how_to/work_with_microtvm/micro_tvmc.py b/gallery/how_to/work_with_microtvm/micro_tvmc.py new file mode 100644 index 000000000000..423e0f1dde37 --- /dev/null +++ b/gallery/how_to/work_with_microtvm/micro_tvmc.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tutorial-micro-tvmc: + +Executing a Tiny Model with TVMC Micro +====================================== +**Author**: `Mehrdad Hessar `_ + +This tutorial explains how to compile a tiny model for a micro device, +build a program on Zephyr platform to execute this model, flash the program +and run the model all using `tvmc micro` command. +""" + +###################################################################### +# .. note:: +# This tutorial is explaining using TVMC Mirco on Zephyr platform. You need +# to install Zephyr dependencies before processing with this tutorial. Alternatively, +# you can run this tutorial in one of the following ways which has Zephyr depencencies already installed. +# +# * Use `microTVM Reference Virtual Machines `_. +# * Use QEMU docker image provided by TVM. Following these you will download and login to the docker image: +# +# .. code-block:: bash +# +# cd tvm +# ./docker/bash.sh tlcpack/ci-qemu +# + + +############################################################ +# Using TVMC Micro +############################################################ +# +# TVMC is a command-line tool which is installed as a part of TVM Python packages. Accessing this +# package varies based on your machine setup. In many cases, you can use the ``tvmc`` command directly. +# Alternatively, if you have TVM as a Python module on your ``$PYTHONPATH``, you can access this +# driver with ``python -m tvm.driver.tvmc`` command. This tutorial will use TVMC command as +# ``tvmc`` for simplicity. +# +# To check if you have TVMC command installed on your machine, you can run: +# +# .. code-block:: bash +# +# tvmc --help +# +# To compile a model for microtvm we use ``tvmc compile`` subcommand. The output of this command +# is used in next steps with ``tvmc micro`` subcommands. You can check the availability of TVMC Micro using: +# +# .. code-block:: bash +# +# tvmc micro --help +# +# The main tasks that you can perform using ``tvmc micro`` are ``create``, ``build`` and ``flash``. +# To read about specific options under a givern subcommand, use +# ``tvmc micro --help``. We will use each subcommand in this tutorial. +# + +############################################################ +# Obtain a Tiny Model +############################################################ +# +# For this tutorial, we will use Magic Wand model from tflite micro. Magic Wand is a +# Depthwise Convolution Layer model which recognizes gestures with an accelerometer. +# +# For this tutorial we will be using the model in tflite format. +# +# .. code-block:: bash +# +# wget https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/magic_wand/magic_wand.tflite +# + +############################################################ +# Compiling a TFLite model to a Model Library Format +############################################################ +# +# Model Library Format (MLF) is an output format that TVM provides for micro targets. MLF is a tarball +# containing a file for each piece of the TVM compiler output which can be used on micro targets outside +# TVM environment. Read more about `Model Library Format `_. +# +# Here, we generate a MLF file for ``qemu_x86`` Zephyr board. To generate MLF output for the ``magic_wand`` tflite model: +# +# .. code-block:: bash +# +# tvmc compile magic_wand.tflite \ +# --target='c -keys=cpu -link-params=0 -model=host' \ +# --runtime=crt \ +# --runtime-crt-system-lib 1 \ +# --executor='graph' \ +# --executor-graph-link-params 0 \ +# --output model.tar \ +# --output-format mlf \ +# --pass-config tir.disable_vectorize=1 \ +# --disabled-pass=AlterOpLayout +# +# This will generate a ``model.tar`` file which contains TVM compiler output files. To run this command for +# a different Zephyr device, you need to update ``target``. For instance, for ``nrf5340dk_nrf5340_cpuapp`` board +# the target is ``--target='c -keys=cpu -link-params=0 -model=nrf5340dk'``. +# + + +############################################################ +# Create a Zephyr Project Using Model Library Format +############################################################ +# +# To generate a Zephyr project we use TVM Micro subcommand ``create``. We pass the MLF format and the path +# for the project to ``create`` subcommand along with project options. Project options for each +# platform (Zephyr/Arduino) are defined in their Project API server file. To generate Zephyr project, run: +# +# .. code-block:: bash +# +# tvmc micro create \ +# project \ +# model.tar \ +# zephyr \ +# --project-option project_type=host_driven zephyr_board=qemu_x86 +# +# This will generate a ``Host-Driven`` Zephyr project for ``qemu_x86`` Zephyr board. In Host-Driven template project, +# the Graph Executor will run on host and perform the model execution on Zephyr device by issuing commands to the +# device using an RPC mechanism. Read more about `Host-Driven Execution `_. +# +# To get more information about TVMC Micro ``create`` subcommand: +# +# .. code-block:: bash +# +# tvmc micro create --help +# + +############################################################ +# Build and Flash Zephyr Project Using TVMC Micro +############################################################ +# +# Next step is to build the Zephyr project which includes TVM generated code for running the tiny model, Zephyr +# template code to run a model in Host-Driven mode and TVM runtime source/header files. To build the project: +# +# .. code-block:: bash +# +# tvmc micro build \ +# project \ +# zephyr \ +# --project-option zephyr_board=qemu_x86 +# +# This will build the project in ``project`` directory and generates binary files under ``project/build``. To build +# Zephyr project for a different Zephyr board, change ``zephyr_board`` project option. +# +# Next, we flash the Zephyr binary file to Zephyr device. For ``qemu_x86`` Zephyr board this step does not +# actually perform any action since QEMU will be used, however you need this step for physical hardware. +# +# .. code-block:: bash +# +# tvmc micro flash \ +# project \ +# zephyr \ +# --project-option zephyr_board=qemu_x86 +# + +############################################################ +# Run Tiny Model on Micro Target +############################################################ +# +# After flashing the device, the compiled model and TVM RPC server are programmed on the device. +# The Zephyr board is waiting for host to open a communication channel. MicroTVM devices typicall communicate +# using a serial communication (UART). To run the flashed model on the device using TVMC, we use ``tvmc run`` subcommand +# and pass ``--device micro`` to specify the device type. This command will open a communication channel, set input +# values using ``Graph Executor`` on host and run full model on the device. Then it gets output from the device. +# +# .. code-block:: bash +# +# tvmc run \ +# --device micro \ +# project \ +# --project-option zephyr_board=qemu_x86 \ +# --fill-mode ones +# --print-top 4 +# # Output: +# # +# # INFO:__main__:b'[100%] [QEMU] CPU: qemu32,+nx,+pae\n' +# # remote: microTVM Zephyr runtime - running +# # INFO:__main__:b'[100%] Built target run\n' +# # [[3. 1. 2. 0. ] +# # [0.47213247 0.41364592 0.07525456 0.03896701]] +# +# Specifically, this command sets the input of the model to all ones and shows the four values of the output with their indices. diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py index eb0b3a628442..a3e8a9e8b5a4 100644 --- a/tests/micro/common/test_tvmc.py +++ b/tests/micro/common/test_tvmc.py @@ -80,7 +80,7 @@ def test_tvmc_model_build_only(board, output_dir): shutil.rmtree(out_dir_temp) os.mkdir(out_dir_temp) - model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") + model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") tar_path = str(output_dir / "model.tar") project_dir = str(output_dir / "project") From f61b172e2003df17e03344192872598bcfdfa98b Mon Sep 17 00:00:00 2001 From: lhutton1 <35535092+lhutton1@users.noreply.github.com> Date: Thu, 27 Jan 2022 19:45:05 +0000 Subject: [PATCH 29/49] [CI][Fix] Remove additional qnn.op.transpose_conv2d from docs (#10083) Fixes CI after #10077, and replaces misuse elsewhere. Change-Id: I095fc8ea2b8d268b09538832cba1f5482a73a9d9 --- docs/reference/langref/relay_op.rst | 1 - python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/relay/qnn/op/qnn.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/reference/langref/relay_op.rst b/docs/reference/langref/relay_op.rst index 8788eb52ae0d..8bc24b9ab865 100644 --- a/docs/reference/langref/relay_op.rst +++ b/docs/reference/langref/relay_op.rst @@ -245,4 +245,3 @@ This level supports dialect operators. tvm.relay.qnn.op.simulated_dequantize tvm.relay.qnn.op.simulated_quantize tvm.relay.qnn.op.subtract - tvm.relay.qnn.op.transpose_conv2d diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index abd3e28bc3eb..b8a31370658e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -482,7 +482,7 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw), name="conv2d_transpose_nchw.generic", ) - else: # group_transpose_conv2d + else: # group_conv2d_transpose strategy.add_implementation( wrap_compute_conv2d_transpose(topi.nn.group_conv2d_transpose_nchw, has_groups=True), wrap_topi_schedule(topi.generic.schedule_group_conv2d_transpose_nchw), diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index aef514d81cc1..85629a9b5a5a 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -527,8 +527,8 @@ def conv2d_transpose( kernel_scale: tvm.relay.Expr The scale for the weight tensor. The scale for the weight tensor is stored for access to this during relay. This information is not - needed in the pass pipeline after qnn.transpose_conv2d is lowered to the - sequence of steps as in nn.transpose_conv2d. See also input_scale in Requantize. + needed in the pass pipeline after qnn.conv2d_transpose is lowered to the + sequence of steps as in nn.conv2d_transpose. See also input_scale in Requantize. strides : Tuple[int], optional The strides of convolution. From b1bf8718135455d243925f3f086d2c7e04df358b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 28 Jan 2022 00:10:24 -0800 Subject: [PATCH 30/49] [PyTorch] Fix rsub type (#10090) * [PyTorch] Fix rsub type * fix --- python/tvm/relay/frontend/pytorch.py | 5 +---- tests/python/frontend/pytorch/test_forward.py | 7 +++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f7538f0837c6..16dd0c447124 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1972,10 +1972,7 @@ def stack(self, inputs, input_types): return self.tensor_array_stack(inputs, input_types) def rsub(self, inputs, input_types): - data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2]) - - # TODO (t-vi): should this also be part of the type promotion? - alpha = _expr.const(float(inputs[2])) + data0, data1, alpha = self.pytorch_promote_types(inputs, input_types) # note: rsub means data0 and data1 swap places return get_relay_op("subtract")(data1, alpha * data0) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3fbef494f16d..a02701b5278a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2691,6 +2691,13 @@ def forward(self, *args): verify_model(Rsub2().float().eval(), input_data=[d1, d2]) verify_model(Rsub2().float().eval(), input_data=[d1, d3]) + d1 = torch.rand([1, 3]).half() + d2 = torch.rand([1, 3]).half() + verify_model(Rsub1().half().eval(), input_data=[d1, d2]) + verify_model(Rsub1().half().eval(), input_data=[d1, d3]) + verify_model(Rsub2().half().eval(), input_data=[d1, d2]) + verify_model(Rsub2().half().eval(), input_data=[d1, d3]) + @tvm.testing.uses_gpu def test_forward_embedding(): From 9f0cc50d43df9354a816f60b323e6cae59b73611 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Fri, 28 Jan 2022 15:07:37 +0000 Subject: [PATCH 31/49] [microNPU] Removing constant args from PrimFunc (#9951) Before this commit, microNPU creates PrimFunc as if it accepts constants from the callee. This commit changes the PrimFunc to remove the constants as an argument to PrimFunc as they are not provided from the main function. --- .../relay/backend/contrib/ethosu/codegen.py | 14 +- .../backend/contrib/ethosu/tir/compiler.py | 1 + .../backend/contrib/ethosu/tir/passes.py | 37 ++ .../contrib/ethosu/tir_to_cs_translator.py | 6 +- .../contrib/test_ethosu/test_compiler.py | 55 ++- .../test_ethosu/test_encode_constants.py | 200 +++++------ .../test_ethosu/test_remove_concatenates.py | 37 +- .../test_ethosu/test_replace_conv2d.py | 152 ++++----- .../contrib/test_ethosu/test_replace_copy.py | 32 +- .../contrib/test_ethosu/test_scheduler.py | 20 +- .../test_ethosu/test_tir_to_cs_translator.py | 316 ++++++++++-------- 11 files changed, 460 insertions(+), 410 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 98ee41f428b2..54312f6c8d6f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -309,8 +309,8 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc: # scratch memory size. tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants()) - for idx in const_dict.keys(): - const_dict[idx] = tvm.nd.array(const_dict[idx]) + for param in const_dict.keys(): + const_dict[param] = tvm.nd.array(const_dict[param]) primfunc = tir_mod["main"] primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"]) @@ -341,11 +341,9 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact tir_mod = tvm.IRModule() tir_mod[symbol] = primfunc - const_dict_with_int_keys = dict() - for idx in const_dict.keys(): - const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy() + const_dict_np = dict() + for buffer_var in const_dict.keys(): + const_dict_np[buffer_var] = const_dict[buffer_var].numpy() - cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate( - tir_mod, const_dict_with_int_keys - ) + cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np) return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index bcd785ddbbd8..ee35da4cab61 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -90,6 +90,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) mod = ethosu_passes.AnnotateAllocates()(mod) + mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod) return mod, const_dict diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index fbc9bf3ff41c..c2fff8abb9b0 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -687,3 +687,40 @@ def _ftransform(f, mod, ctx): return tvm.tir.transform.prim_func_pass( _ftransform, opt_level=0, name="tir.ethosu.remove_concatenates" ) + + +def CreatePrimFuncWithoutConstants(const_dict): + """ + This pass will remove arguments that are constants + from PrimFunc Args. These should be replaced properly + with tir.allocate_const when it becomes available. + + It also modifies the constant dictionary to + rewrite the keys as the actual tir.Vars that are params + rather than the index because this pass removes PrimFunc + arguments that represent constants. + """ + + new_const_dict = dict() + + def _ftransform(f, mod, ctx): + new_params = list() + new_buffer_map = dict() + for param_idx in const_dict.keys(): + # We are using buffer_var to key the constants as + # PrimFunc params of constants will be removed. + new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx] + for i in range(len(f.params)): + if i not in const_dict.keys(): + new_params.append(f.params[i]) + new_buffer_map[f.params[i]] = f.buffer_map[f.params[i]] + return tvm.tir.PrimFunc(new_params, f.body, f.ret_type, new_buffer_map, f.attrs, f.span) + + def _create_primfunc_without_constants(mod): + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.CreatePrimFuncWithoutConstants" + ) + mod = transform_func(mod) + return mod, new_const_dict + + return _create_primfunc_without_constants diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index d7254511ebfc..ecea6eb28f09 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -208,7 +208,7 @@ def extract_buffer_info( ---------- mod : tvm.IRModule The NPU TIR IRModule. - param_dict : Dict[int, np.ndarray] + param_dict : Dict[tvm.tir.Var, np.ndarray] A dictionary containing param idx --> const numpy.NDArray Returns @@ -222,8 +222,7 @@ def extract_buffer_info( assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] - for idx, const_data in param_dict.items(): - param = primfunc.params[idx] + for param, const_data in param_dict.items(): buffer_info[param] = BufferInfo( const_data, const_data.shape, const_data.dtype, BufferType.constant ) @@ -257,7 +256,6 @@ def populate_allocate_buffer_info(stmt): ) tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) - return buffer_info diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py index e1688b8aa512..0e31be86becb 100644 --- a/tests/python/contrib/test_ethosu/test_compiler.py +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -20,27 +20,46 @@ import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from . import infra -def test_lower_to_tir(): - data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") - weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") - p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") - conv = relay.nn.conv2d( - data, - weight, - kernel_size=(1, 1), - data_layout="NHWC", - kernel_layout="HWIO", - out_dtype="int32", - ) - tile = relay.tile(p2, reps=(1, 1, 1, 1001)) - subtract = relay.subtract(conv, tile) - func = subtract - expr = relay.Function(relay.analysis.free_vars(func), func) - mod = tvm.IRModule.from_expr(expr) +def _create_single_conv2d(): + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv1), conv1) + return func + + +def _create_double_conv2d(): + ifm = relay.var("x", shape=(1, 8, 8, 4), dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + conv2 = infra.make_ethosu_conv2d(conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + return func + + +def _create_non_linear_conv2d(): + shape = (1, 8, 8, 4) + ifm1 = relay.var("x", shape=shape, dtype="int8") + ifm2 = relay.var("y", shape=shape, dtype="int8") + conv1 = infra.make_ethosu_conv2d(ifm1, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + conv2 = infra.make_ethosu_conv2d(ifm2, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1)) + add = infra.make_ethosu_binary_elementwise(conv1, conv2, shape[3], shape[3], "ADD", "int8") + func = relay.Function(relay.analysis.free_vars(add), add) + return func + + +@pytest.mark.parametrize( + "relay_function, arg_count", + [(_create_single_conv2d, 2), (_create_double_conv2d, 2), (_create_non_linear_conv2d, 3)], +) +def test_lower_to_tir_arg_count(relay_function, arg_count): + mod = tvm.IRModule() + mod["main"] = relay_function() mod = relay.transform.InferType()(mod) - lower_to_tir(mod["main"]) + tir_mod = lower_to_tir(mod["main"])[0] + primfunc = tir_mod["main"] + assert len(primfunc.params) == arg_count if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 7f5eeb1121af..315712996ac8 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -34,34 +34,32 @@ @tvm.script.ir_module class WeightStreamOnly: @T.prim_func - def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = T.match_buffer(placeholder_5, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_9 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = T.match_buffer(placeholder_3, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = T.match_buffer(placeholder_1, [128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") + buffer_4 = T.buffer_var("uint8", "") + buffer_5 = T.buffer_var("uint8", "") + buffer_6 = T.buffer_var("uint8", "") + buffer_7 = T.buffer_var("uint8", "") # body - placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -100,34 +98,29 @@ def _get_func(): reference_mod = WeightStreamOnly tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = {2: 128, 3: 32, 4: 112, 5: 32, 6: 112, 7: 32, 8: 112, 9: 32} - test_const_sizes = {} - for key, value in consts.items(): - test_const_sizes[key] = len(value) - - assert reference_const_sizes == test_const_sizes + reference_const_sizes = [128, 32, 112, 32, 112, 32, 112, 32] + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module class RereadWeights: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") - buffer = T.match_buffer(placeholder_1, [304], dtype="uint8") - buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8") - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8") + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -166,31 +159,26 @@ def _get_func(): reference_mod = RereadWeights tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = {1: 304, 2: 80} - test_const_sizes = {} - for key, value in consts.items(): - test_const_sizes[key] = len(value) - - assert reference_const_sizes == test_const_sizes + reference_const_sizes = [304, 80] + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module class DirectReadOnly: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body - ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 160, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -228,50 +216,45 @@ def _get_func(): reference_mod = DirectReadOnly tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = {1: 592, 2: 160, 3: 160, 4: 80} - test_const_sizes = {} - for key, value in consts.items(): - test_const_sizes[key] = len(value) - - assert reference_const_sizes == test_const_sizes + reference_const_sizes = [592, 160, 160, 80] + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = T.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_11 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = T.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") + buffer_4 = T.buffer_var("uint8", "") + buffer_5 = T.buffer_var("uint8", "") + buffer_6 = T.buffer_var("uint8", "") + buffer_7 = T.buffer_var("uint8", "") + buffer_8 = T.buffer_var("uint8", "") + buffer_9 = T.buffer_var("uint8", "") # body - ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin": True}) - placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -320,23 +303,20 @@ def _get_func(): reference_mod = MixedRead tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = { - 1: 592, - 2: 160, - 4: 80, - 5: 32, - 6: 80, - 7: 32, - 8: 80, - 9: 32, - 10: 80, - 11: 32, - } - test_const_sizes = {} - for key, value in consts.items(): - test_const_sizes[key] = len(value) - - assert reference_const_sizes == test_const_sizes + reference_const_sizes = [ + 592, + 160, + 80, + 32, + 80, + 32, + 80, + 32, + 80, + 32, + ] + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size def test_constant_as_input(): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index 3b767da98ef4..f6e0e2d855cd 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -19,7 +19,7 @@ pytest.importorskip("ethosu.vela") import tvm import tvm.script -from tvm.script import tir +from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -29,27 +29,24 @@ # fmt: off @tvm.script.ir_module class ReferenceModule: - @tir.prim_func - def main(placeholder: tir.handle, placeholder_1: tir.handle, placeholder_2: tir.handle, placeholder_3: tir.handle, placeholder_4: tir.handle, placeholder_5: tir.handle, placeholder_6: tir.handle, placeholder_7: tir.handle, placeholder_8: tir.handle, placeholder_9: tir.handle, T_concat: tir.handle) -> None: + @T.prim_func + def main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[(1, 8, 10, 16), "int8"], T_concat: T.Buffer[(1, 8, 32, 16), "int8"]) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_2, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_4, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_10 = tir.match_buffer(placeholder_1, [1, 8, 10, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_9, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_8, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = tir.match_buffer(placeholder_5, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = tir.match_buffer(placeholder_6, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - T_concat_1 = tir.match_buffer(T_concat, [1, 8, 32, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_11 = tir.match_buffer(placeholder, [1, 8, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = tir.match_buffer(placeholder_7, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") + buffer_4 = T.buffer_var("uint8", "") + buffer_5 = T.buffer_var("uint8", "") + buffer_6 = T.buffer_var("uint8", "") + buffer_7 = T.buffer_var("uint8", "") # body - T_concat_2 = tir.allocate([2816], "int8", "global", annotations={"disable_lower_builtin": True}) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", placeholder_10.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_2, 192), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 2992, 12, tir.load("uint8", buffer_4.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_2, 192), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_1.data, 352), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2992, 12, tir.load("uint8", buffer_2.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, tir.load("int8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, tir.load("int8", T_concat_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 2992, 12, tir.load("uint8", buffer_7.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, tir.load("int8", T_concat_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, tir.load("int8", T_concat_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_5.data, 0), 2992, 12, tir.load("uint8", buffer_6.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, 12, T.load("uint8", buffer_3, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, 12, T.load("uint8", buffer_5, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, 12, T.load("uint8", buffer_7, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 7b09fb255663..2136b9f6d1b3 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -245,86 +245,78 @@ def _visit(stmt): @tvm.script.ir_module class Conv2dDoubleCascade1: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dDoubleCascade2: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_3, [2608], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dDoubleCascade3: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buffer[(1, 20, 4, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dDoubleCascade4: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Buffer[(1, 8, 2, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_1, [1456], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 2, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -440,30 +432,26 @@ def _get_func( @tvm.script.ir_module class Conv2dInlineCopy1: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 848, 12, T.load("uint8", buffer_1.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 848, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineCopy2: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer[(1, 3, 5, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [656], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 656, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 656, 12, T.load("uint8", buffer, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -499,64 +487,56 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): @tvm.script.ir_module class Conv2dInlineReshape1: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [4, 6, 8, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineReshape2: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 24, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineReshape3: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @tvm.script.ir_module class Conv2dInlineReshape4: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8, 6, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index cce414c4c8f7..7aee57d548fe 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -31,18 +31,16 @@ @tvm.script.ir_module class ReferenceModule: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = T.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") # body placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -77,23 +75,21 @@ def _get_func(): @tvm.script.ir_module class WeightStream: @T.prim_func - def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None: + def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buffer[(1, 16, 16, 16), "int8"]) -> None: # function attr dict T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8") - buffer = T.match_buffer(placeholder_1, [416], dtype="uint8") - buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8") - buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") - buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8") + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") # body placeholder_global = T.allocate([416], "uint8", "global", annotations={"disable_lower_builtin": True}) placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 20595465e32e..11dde6fb256d 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -180,25 +180,23 @@ def test_schedule_cache_reads(): @tvm.script.ir_module class DiamondGraphTir: @T.prim_func - def main(placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, ethosu_write_1: T.handle): + def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffer[(1, 56, 56, 24), "int8"]) -> None: T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - input_buffer = T.match_buffer(placeholder_1, [1, 56, 56, 96], "int8") - weight_buffer = T.match_buffer(placeholder_2, [2608], "uint8") - bias_buffer = T.match_buffer(placeholder_3, [240], "uint8") - weight_buffer2 = T.match_buffer(placeholder_4, [736], "uint8") - bias_buffer2 = T.match_buffer(placeholder_5, [240], "uint8") - output_buffer = T.match_buffer(ethosu_write_1, [1, 56, 56, 24], "int8") + weight_buffer = T.buffer_var("uint8", "") + bias_buffer = T.buffer_var("uint8", "") + weight_buffer2 = T.buffer_var("uint8", "") + bias_buffer2 = T.buffer_var("uint8", "") placeholder_global = T.allocate([2608], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([240], "uint8", "global", annotations={"disable_lower_builtin":True}) featuremap_buffer = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) featuremap_buffer2 = T.allocate([75264], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer.data, 0), 2608, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer.data, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer, 0), 2608, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, T.load("int8", input_buffer.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 2608, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer2.data, 0), 736, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer2.data, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer2, 0), 736, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer2, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 736, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", output_buffer.data, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 0cadf96e7a18..add8021083c6 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -33,15 +33,13 @@ @tvm.script.ir_module class SingleEthosUConv2D: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_conv2d: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder_1, [1, 1, 3, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_conv2d_1 = T.match_buffer(ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = T.buffer_var("uint8", "") + placeholder_5 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4.data, 0), 0, 12, T.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -50,22 +48,20 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle @tvm.script.ir_module class MultiEthosUConv2D: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_conv2d: T.handle) -> None: + def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffer[(1, 8, 8, 8), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_9 = T.match_buffer(placeholder_3, [1, 1, 32, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_conv2d_1 = T.match_buffer(ethosu_conv2d, [1, 8, 8, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_7 = T.match_buffer(placeholder_1, [1, 1, 3, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_6 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_8 = T.match_buffer(placeholder_2, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_4, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = T.buffer_var("uint8", "") + placeholder_7 = T.buffer_var("uint8", "") + placeholder_8 = T.buffer_var("uint8", "") + placeholder_5 = T.buffer_var("uint8", "") # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9.data, 0), 0, 12, T.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7.data, 0), 0, 12, T.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9.data, 0), 0, 12, T.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -74,18 +70,16 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle @tvm.script.ir_module class MultiEthosUCopy: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_conv2d: T.handle) -> None: + def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_conv2d_1 = T.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = T.match_buffer(placeholder_2, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) - placeholder_4 = T.match_buffer(placeholder_1, [8, 1, 1, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.buffer_var("uint8", "") + placeholder_4 = T.buffer_var("uint8", "") # body placeholder_global = T.allocate([256], "uint8", "global") placeholder_d_global = T.allocate([8], "int32", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4.data, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5.data, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) # fmt: on @@ -95,34 +89,41 @@ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle @tvm.script.ir_module class WeightStreamOnly: @T.prim_func - def main(placeholder: T.handle, ethosu_conv2d: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") + buffer_4 = T.buffer_var("uint8", "") + buffer_5 = T.buffer_var("uint8", "") + buffer_6 = T.buffer_var("uint8", "") + buffer_7 = T.buffer_var("uint8", "") # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - buffer_4 = T.match_buffer(placeholder_5, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = T.match_buffer(placeholder_7, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = T.match_buffer(placeholder_1, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = T.match_buffer(placeholder_6, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_conv2d_1 = T.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_3, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_2, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_9 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, + "global_symbol": "main", "tir.noalias": True, + "constants": {buffer.name: buffer, + buffer_1.name: buffer_1, + buffer_2.name: buffer_2, + buffer_3.name: buffer_3, + buffer_4.name: buffer_4, + buffer_5.name: buffer_5, + buffer_6.name: buffer_6, + buffer_7.name: buffer_7}}) # body - placeholder_global = T.allocate([144], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7.data, 0), 144, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 144, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) + placeholder_global = T.allocate([128], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -132,38 +133,47 @@ def main(placeholder: T.handle, ethosu_conv2d: T.handle, placeholder_1: T.handle @tvm.script.ir_module class MixedRead: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, ethosu_conv2d: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle) -> None: + def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[(1, 16, 16, 8), "int8"]) -> None: + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("uint8", "") + buffer_3 = T.buffer_var("uint8", "") + buffer_4 = T.buffer_var("uint8", "") + buffer_5 = T.buffer_var("uint8", "") + buffer_6 = T.buffer_var("uint8", "") + buffer_7 = T.buffer_var("uint8", "") + buffer_8 = T.buffer_var("uint8", "") + buffer_9 = T.buffer_var("uint8", "") # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - buffer_5 = T.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = T.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = T.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_9 = T.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_conv2d_1 = T.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer = T.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = T.match_buffer(placeholder_10, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_11 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = T.match_buffer(placeholder_6, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = T.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_8 = T.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, + "global_symbol": "main", "tir.noalias": True, + "constants": {buffer.name: buffer, + buffer_1.name: buffer_1, + buffer_2.name: buffer_2, + buffer_3.name: buffer_3, + buffer_4.name: buffer_4, + buffer_5.name: buffer_5, + buffer_6.name: buffer_6, + buffer_7.name: buffer_7, + buffer_8.name: buffer_8, + buffer_9.name: buffer_9}}) # body - ethosu_conv2d_2 = T.allocate([4096], "uint8", "global") - placeholder_global = T.allocate([80], "uint8", "global") - placeholder_d_global = T.allocate([20], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_5.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 20, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) + ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -174,16 +184,14 @@ def test_buffer_info_extraction(): # Stimulus "tir_module": SingleEthosUConv2D, "param_dict": { - 1: np.random.randint( + tvm.tir.Var("placeholder_4", "uint8"): np.random.randint( np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 16], "uint8" ), - 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + tvm.tir.Var("placeholder_5", "uint8"): np.random.randint( + np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32" + ), }, # Reference Outputs - "constants": { - "placeholder_4": 1, - "placeholder_5": 2, - }, "data_buffers": { "placeholder_3": ( [1, 8, 8, 3], @@ -200,22 +208,20 @@ def test_buffer_info_extraction(): { "tir_module": MultiEthosUConv2D, "param_dict": { - 1: np.random.randint( + tvm.tir.Var("placeholder_7", "uint8"): np.random.randint( np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 32], "uint8" ), - 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32"), - 3: np.random.randint( + tvm.tir.Var("placeholder_8", "uint8"): np.random.randint( + np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32" + ), + tvm.tir.Var("placeholder_8", "uint8"): np.random.randint( np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 32, 8], "uint8" ), - 4: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + tvm.tir.Var("placeholder_5", "uint8"): np.random.randint( + np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32" + ), }, # Reference Outputs - "constants": { - "placeholder_5": 4, - "placeholder_7": 1, - "placeholder_8": 2, - "placeholder_9": 3, - }, "data_buffers": { "placeholder_6": ( [1, 8, 8, 3], @@ -248,16 +254,12 @@ def test_buffer_info_extraction(): tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod) buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): - buffer_name = buffer_var.name - if buffer_name in test_case["constants"].keys(): - assert ( - info.values == test_case["param_dict"][test_case["constants"][buffer_name]] - ).all() - assert ( - info.dtype == test_case["param_dict"][test_case["constants"][buffer_name]].dtype - ) + if buffer_var in test_case["param_dict"].keys(): + assert (info.values == test_case["param_dict"][buffer_var]).all() + assert info.dtype == test_case["param_dict"][buffer_var].dtype info.btype == tir_to_cs_translator.BufferType.constant else: + buffer_name = buffer_var.name assert info.btype == test_case["data_buffers"][buffer_name][2] @@ -663,23 +665,25 @@ def populate_ethosu_copy_calls(stmt): @tvm.script.ir_module class MixedConstantDatatypes: @T.prim_func - def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle, placeholder_3: T.handle) -> None: + def main(placeholder_4: T.Buffer[(1, 8, 16, 16), "int8"], ethosu_write_1: T.Buffer[(1, 1, 1, 16), "int8"]) -> None: + buffer = T.buffer_var("uint8", "") + buffer_1 = T.buffer_var("uint8", "") + buffer_2 = T.buffer_var("int16", "") # function attr dict - T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - placeholder_4 = T.match_buffer(placeholder, [1, 8, 16, 16], dtype="int8") - buffer = T.match_buffer(placeholder_1, [160], dtype="uint8") - placeholder_5 = T.match_buffer(placeholder_2, [1, 1, 1, 1], dtype="int16") - ethosu_write_1 = T.match_buffer(ethosu_write, [1, 1, 1, 16], dtype="int8") - buffer_1 = T.match_buffer(placeholder_3, [272], dtype="uint8") + T.func_attr({"from_legacy_te_schedule": True, + "global_symbol": "main", "tir.noalias": True, + "constants": {buffer.name: buffer, + buffer_1.name: buffer_1, + buffer_2.name: buffer_2}}) # body placeholder_global = T.allocate([272], "uint8", "global") placeholder_d_global = T.allocate([160], "uint8", "global") ethosu_write_2 = T.allocate([16], "int16", "global") placeholder_d_global_1 = T.allocate([1], "int16", "global") - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8")) - T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8")) T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 0, T.load("uint8", placeholder_d_global, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16")) - T.evaluate(T.call_extern("ethosu_copy", T.load("int16", placeholder_5.data, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16")) + T.evaluate(T.call_extern("ethosu_copy", T.load("int16", buffer_2, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, T.load("int16", placeholder_d_global_1, 0), 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL", dtype="int8")) # fmt: on @@ -690,39 +694,81 @@ def test_assign_addresses(): # Stimulus "tir_module": WeightStreamOnly, "param_dict": { - 2: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), - 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), - 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 6: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), - 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), - 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + WeightStreamOnly["main"].attrs["constants"]["buffer"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [128], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_1"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_2"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [112], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_3"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_4"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [112], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_5"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_6"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [112], "uint8" + ), + WeightStreamOnly["main"].attrs["constants"]["buffer_7"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), }, }, { # Stimulus "tir_module": MixedRead, "param_dict": { - 1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [592], "uint8"), - 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"), - 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), - 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 6: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), - 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), - 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), - 10: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), - 11: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + MixedRead["main"].attrs["constants"]["buffer"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [592], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_1"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_2"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_3"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_4"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_5"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_6"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_7"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_8"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8" + ), + MixedRead["main"].attrs["constants"]["buffer_9"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [32], "uint8" + ), }, }, { # Stimulus "tir_module": MixedConstantDatatypes, "param_dict": { - 1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"), - 2: np.random.randint(np.iinfo("int16").min, np.iinfo("int16").max, [1], "int16"), - 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [272], "uint8"), + MixedConstantDatatypes["main"].attrs["constants"]["buffer"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8" + ), + MixedConstantDatatypes["main"].attrs["constants"]["buffer_2"]: np.random.randint( + np.iinfo("int16").min, np.iinfo("int16").max, [1], "int16" + ), + MixedConstantDatatypes["main"].attrs["constants"]["buffer_1"]: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [272], "uint8" + ), }, }, ] From 7a355e874c1984c030f65ec22ce59868e8482871 Mon Sep 17 00:00:00 2001 From: Altan Haan <3124994+altanh@users.noreply.github.com> Date: Fri, 28 Jan 2022 09:13:00 -0800 Subject: [PATCH 32/49] [Relay] fix incorrect binding of Lets in ANF conversion (#10078) * fix incorrect binding of lets in ANF conversion * add test case * remove really weird auto-import from debugging * address comments --- src/relay/transforms/to_a_normal_form.cc | 11 +++++ .../relay/test_pass_to_a_normal_form.py | 40 ++++++++++++++----- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index a0841ec44fae..2f6efb9cef9a 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -223,6 +223,17 @@ class Fill : ExprFunctor, private transform::Lexi bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { return annotated_expr; + } else if (const LetNode* let = AsIgnoringOnDevice(now)) { + // Instead of making a nested binding "let var = (let x = ...; bindings...; body)", we push + // the inner bindings into the outer scope and bind body to var, giving + // "let x = ...; bindings...; let var = body;" as the resulting bindings. + Expr e = GetRef(let); + while (const LetNode* inner_let = AsIgnoringOnDevice(e)) { + GetScope(orig)->let_list->Push(inner_let->var, inner_let->value); + e = inner_let->body; + } + Expr annotated_body = MaybeOnDeviceFixed(e, GetVirtualDevice(orig)); + return GetScope(orig)->let_list->Push(var, annotated_body); } else { return GetScope(orig)->let_list->Push(var, annotated_expr); } diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index cd2e5d2fd249..f44f2a99258b 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest +import sys import numpy as np import tvm from tvm import te @@ -94,6 +96,34 @@ def test_if(): assert tvm.ir.structural_equal(anf, expected_output) +def test_let_as_subexpr(): + def on_cpu(x): + return relay.annotation.on_device(x, tvm.device("cpu"), constrain_result=True) + + x = relay.Var("x", relay.IncompleteType()) + c = relay.const(1) + l = relay.Let(x, on_cpu(c + c), x) + body = l * l + + anf = run_opt_pass(body, [transform.ToANormalForm(), transform.InferType()]) + + v0 = relay.Var("v0", relay.IncompleteType()) + v1 = relay.Var("v1", relay.IncompleteType()) + v2 = relay.Var("v2", relay.IncompleteType()) + expected_output = relay.Let( + v0, + on_cpu(c), + relay.Let( + x, + on_cpu(v0 + v0), + relay.Let(v1, x, relay.Let(v2, v1 * v1, v2)), + ), + ) + expected_output = run_opt_pass(expected_output, transform.InferType()) + + tvm.ir.assert_structural_equal(anf, expected_output) + + # make sure we dont infinite loop. # it is too large so we wont check for the exact program. def test_recursion(): @@ -198,12 +228,4 @@ def test_gradient_if(): if __name__ == "__main__": - test_explicit_bound() - test_order() - test_if() - test_recursion() - test_ref() - test_let() - test_nat_add() - test_function() - test_gradient_if() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 552fefba66a6a1c1797431be9e8c52353d6c8234 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Fri, 28 Jan 2022 17:30:07 +0000 Subject: [PATCH 33/49] [microTVM] Update Zephyr to 2.7 (#10094) This supports the reference system added in #9853 --- docker/install/ubuntu_install_zephyr.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/install/ubuntu_install_zephyr.sh b/docker/install/ubuntu_install_zephyr.sh index ff716b68fddb..566d3a5761e4 100644 --- a/docker/install/ubuntu_install_zephyr.sh +++ b/docker/install/ubuntu_install_zephyr.sh @@ -47,7 +47,7 @@ pip3 install west # To keep in sync with the version # defined in apps/microtvm/zephyr/template_project/microtvm_api_server.py # We use `-branch` tag since it tracks the same version with extra patches for bugs. -ZEPHYR_VERSION="v2.5-branch" +ZEPHYR_VERSION="v2.7-branch" ZEPHYR_PROJECT_PATH=/opt/zephyrproject ZEPHYR_INIT_SCRIPT=$(find -name "ubuntu_init_zephyr_project.sh") bash ${ZEPHYR_INIT_SCRIPT} ${ZEPHYR_PROJECT_PATH} ${ZEPHYR_VERSION} @@ -69,10 +69,10 @@ chmod o+rwx zephyr/.cache #/opt/west/bin/pip3 install -r /opt/zephyrproject/zephyr/scripts/requirements.txt pip3 install -r /opt/zephyrproject/zephyr/scripts/requirements.txt -ZEPHYR_SDK_VERSION=0.12.3 +ZEPHYR_SDK_VERSION=0.13.2 ZEPHYR_SDK_FILE=zephyr-sdk-linux-setup.run wget --no-verbose -O $ZEPHYR_SDK_FILE \ - https://github.com/zephyrproject-rtos/sdk-ng/releases/download/v${ZEPHYR_SDK_VERSION}/zephyr-sdk-${ZEPHYR_SDK_VERSION}-x86_64-linux-setup.run + https://github.com/zephyrproject-rtos/sdk-ng/releases/download/v${ZEPHYR_SDK_VERSION}/zephyr-sdk-${ZEPHYR_SDK_VERSION}-linux-x86_64-setup.run chmod +x $ZEPHYR_SDK_FILE "./$ZEPHYR_SDK_FILE" -- -d /opt/zephyr-sdk rm "$ZEPHYR_SDK_FILE" From 91abbf8f6e0058b023a0e61558618977ee085e89 Mon Sep 17 00:00:00 2001 From: Hua Jiang Date: Fri, 28 Jan 2022 12:36:31 -0800 Subject: [PATCH 34/49] [Runtime][PipelineExecutor] Pipeline Executor Sequential execution (#10082) * [Runtime][PipelineExecutor] Pipeline Executor Sequential execution In the first, adding the "get output" logic. Secondly, adding the the sequential executing logic of pipeline executor. In the last, testing the pipeline executor interface and checking the output data. * Address review comments. Co-authored-by: Cody Yu * trigger build. Co-authored-by: Cody Yu --- python/tvm/contrib/pipeline_executor.py | 10 ++ src/runtime/pipeline/pipeline_executor.cc | 34 ++++-- src/runtime/pipeline/pipeline_executor.h | 12 ++ src/runtime/pipeline/pipeline_scheduler.cc | 80 +++++++++++++ src/runtime/pipeline/pipeline_scheduler.h | 25 ++++ src/runtime/pipeline/pipeline_struct.h | 82 ++++++++++++- tests/python/relay/test_pipeline_executor.py | 120 ++++++++++++++++--- 7 files changed, 340 insertions(+), 23 deletions(-) diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 6e991f0c8d7a..b858a209db83 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -121,6 +121,7 @@ def __init__(self, module): self._set_param = self.module["set_param"] self._set_input = self.module["set_input"] self._get_input = self.module["get_input"] + self._get_output = self.module["get_output"] self._get_num_outputs = self.module["get_num_outputs"] self._get_input_pipeline_map = self.module["get_input_pipeline_map"] @@ -203,6 +204,15 @@ def get_input(self, key): """ return self._get_input(key) + def get_output(self): + """Get the output. + Returns + ------- + data : Array[NDArray] + A list of output data. + """ + return self._get_output() + @property def num_outputs(self): """Get the number of outputs. diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 30c09514480f..2cad8cf3b060 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -74,6 +74,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, LOG(FATAL) << "Function only support the input name value in the form of string"; } }); + } else if (name == "get_output") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(); }); } else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(args[0]); }); } else if (name == "stop") { @@ -111,6 +114,14 @@ NDArray PipelineExecutor::GetInput(std::string input_name) { } return runtimes_[indexs.first]->GetInput(indexs.second); } +/*! + * \brief Getting a module index via a input parameters group name. + * \param name The parameters group name. + * \return int The module index. + */ +int PipelineExecutor::GetParamModuleIndex(const std::string& name) { + return param_connection_config[name]; +} /*! * \brief Using the global input name to get the index, and also get the input interface name of corresponding subgraph from the input connection configuration. @@ -136,14 +147,16 @@ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) { * \param serialized_mode Whether run the pipeline executor in serialized mode. */ void PipelineExecutor::Run(bool serialized_mode) { - // TODO(huajsj): Run the pipeline executor. + pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_, serialized_mode); } +/*! + * \brief return A list of global output data. + */ +Array PipelineExecutor::GetOutput(void) { return pipeline_scheduler_.PipelineGetOutput(); } /*! * \brief Stop the pipeline executor. */ -void PipelineExecutor::Stop() { - // TODO(huajsj): Stop the pipeline executor. -} +void PipelineExecutor::Stop() { pipeline_scheduler_.PipelineStop(); } /*! * \brief Use the mod_config information to create a graph runtime list. @@ -208,9 +221,16 @@ std::vector PipelineExecutor::CreateGraphModules(const ModuleConfig& mod */ void PipelineExecutor::SetParam(std::string param_group_name, std::string param_key_name, DLTensor* data_in) { - // Get the module index from the param name. - int module_index = this->GetParamsGroupPipelineMap(param_group_name); - // TODO(huajsj): set the parameters into runtime module. + // Get the module index via the parameters group name. + int module_index = this->GetParamModuleIndex(param_group_name); + ICHECK(module_index >= 0 && module_index < static_cast(runtimes_.size())) + << "Parameter group name " << param_group_name << " does not exist."; + auto runtime = runtimes_[module_index]; + // Get the parameter index via the param key name + int index = runtime->GetInputIndex(param_key_name); + ICHECK(index >= 0) << "Parameter name " << param_key_name << " does not exist in module " + << module_index; + runtime->SetInput(index, data_in); } /*! * \brief Return the input index and module index for a given input name. diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index 7dc5baf17ee1..7b9f5eadf92b 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -118,6 +118,18 @@ class TVM_DLL PipelineExecutor : public ModuleNode { * \brief Stop the pipeline executor. */ void Stop(); + /*! + * \brief Get a list output data. + * \return A list of output data. + */ + Array GetOutput(); + /*! + * \brief A pipeline params with a specific name correspond with the params of a specific + * backend module, this function return the module index for the params name. + * \param name The parameters group name. + * \return Return backend runtime module index. + */ + int GetParamModuleIndex(const std::string& name); /*! * \brief A pipeline input with a specific name correspond with a input of a specific * backend module, this function return a module index and a input index in "pair" diff --git a/src/runtime/pipeline/pipeline_scheduler.cc b/src/runtime/pipeline/pipeline_scheduler.cc index 499d75784a15..4a3368e32391 100644 --- a/src/runtime/pipeline/pipeline_scheduler.cc +++ b/src/runtime/pipeline/pipeline_scheduler.cc @@ -18,6 +18,7 @@ */ #include "pipeline_scheduler.h" +#include #include #include namespace tvm { @@ -35,7 +36,86 @@ std::vector> PipelineScheduler::PipelineInit( auto runItem = std::make_shared(graph_modules_[i], i); runtimes.push_back(runItem); } + // Creating a list of NDArray in order to storage the outputs data. + auto global_output_map = pipeline_config.GetGlobalConfigOutputBindings(); + for (size_t i = 0; i < global_output_map.size(); i++) { + if (global_output_map.find(i) == global_output_map.end()) { + LOG(FATAL) << "Not find global output index " << i; + } + ModuleOutputPair& output_pair = global_output_map[i]; + NDArray output = runtimes[output_pair.first]->CreateFromOutput(output_pair.second); + output_arrays_.push_back(output); + } return runtimes; } +/*! + * \brief Running the pipeline logic in the sequential mode. + * \param runtimes A list of backend runtime modules. + * \param pipeline_config The dependent configuration of each runtime module. + */ +void PipelineScheduler::PipelineRunSequential( + const std::vector>& runtimes, + ConfigPipelineExecution pipeline_config) { + for (size_t i = 0; i < runtimes.size(); i++) { + // The "runtimes" is a list of runtime sorted by the runtime index which should be + // contiguous ascend. + if (static_cast(i) != runtimes[i]->GetModuleIndex()) { + LOG(FATAL) << "Runtime index " << runtimes[i]->GetModuleIndex() + << " is not as same as vector offset value " << i; + } + + if (!pipeline_config.FindModuleInConfig(i)) { + LOG(FATAL) << "Not find the configuration for the module " << i; + } + + runtimes[i]->Run(); + // Getting the output then forwarding into other module once it is configured as input of + // another module or storaging into the "output_array" when the output is a global one. + int outputs_num = runtimes[i]->NumOutputs(); + for (int j = 0; j < outputs_num; j++) { + ConfigBindings& out_binding = pipeline_config[i][j]; + std::unordered_map& input_connections = out_binding.Get(); + NDArray output = runtimes[i]->GetOutput(j); + for (auto bind : input_connections) { + // "bind.first < 0" means the bind is a global bind, by pass the forwarding for + // a global bind. + if (bind.first < 0) continue; + // Setting the output as an input data into the runtime module. + runtimes[bind.first]->SetInput(bind.second, const_cast(output.operator->())); + } + // Store the output. + if (out_binding.IsGlobalOutput()) { + int global_idx = out_binding.GetGlobalOutputIndex(); + TVMArrayCopyFromTo(const_cast(output.operator->()), + const_cast(output_arrays_[global_idx].operator->()), nullptr); + } + } + } +} +/*! + * \brief Running pipeline logic. + * \param runtimes A list of backend runtime modules. + * \param pipeline_config The dependency configuration of each runtime module. + * \param sequential_mode Whether the execution is in a sequential mode. + */ +void PipelineScheduler::PipelineRun(const std::vector>& runtimes, + ConfigPipelineExecution pipeline_config, bool sequential_mode) { + if (!sequential_mode) { + // TODO(huajsj) remove this check after all of pipeline features in. + LOG(FATAL) << "Currently only supports sequential mode."; + } else { + PipelineRunSequential(runtimes, pipeline_config); + } +} +/*! + * \brief Stop the pipeline exection. + */ +void PipelineScheduler::PipelineStop() { + // TODO(huajsj) Add stop logic. +} +/*! + * \brief Get a list of output. + */ +Array PipelineScheduler::PipelineGetOutput() { return output_arrays_; } } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_scheduler.h b/src/runtime/pipeline/pipeline_scheduler.h index 02c44420bd51..6075747a6c7f 100644 --- a/src/runtime/pipeline/pipeline_scheduler.h +++ b/src/runtime/pipeline/pipeline_scheduler.h @@ -43,10 +43,35 @@ class PipelineScheduler { */ std::vector> PipelineInit( const std::vector& modules, const ConfigPipelineExecution& pipeline_config); + /*! + * \brief Running the pipeline logic. + * \param runtimes A list of backend runtime modules. + * \param pipeline_config The dependency configuration of each runtime module. + * \param sequential_mode Whether the execution is in a sequential mode. + */ + void PipelineRun(const std::vector>& runtimes, + ConfigPipelineExecution pipeline_config, bool sequential_mode = false); + /*! + * \brief Running the pipeline logic in the sequential mode. + * \param runtimes A list of backend runtime modules. + * \param pipeline_config The dependent configuration of each runtime module. + */ + void PipelineRunSequential(const std::vector>& runtimes, + ConfigPipelineExecution pipeline_config); + /*! + * \brief Stop the pipeline exection. + */ + void PipelineStop(); + /*! + * \brief Get a list of outputs. + */ + Array PipelineGetOutput(); private: /*!\brief The list of graph executors.*/ std::vector graph_modules_; + /*!\brief A list of NDArray used to storage outputs.*/ + Array output_arrays_; }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 40628e989a90..4002885f6702 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -31,6 +31,17 @@ #include namespace tvm { namespace runtime { +#define GLOBAL_MODULE_INDEX -1 +/*! + *\brief The pair includes the module output index and the global output index. + * The first 'int' is the module output index, and the second 'int' is the global output index. + */ +using GlobalOutputPair = std::pair; +/*! + *\brief The pair includes the module index and the module output index. + * The first 'int' is the module index, and the second 'int' is the module output index. + */ +using ModuleOutputPair = std::pair; /*! * \brief All binding information of a output interface. */ @@ -38,7 +49,10 @@ class ConfigBindings { public: /*!\brief Whether this binding is bound to the PipelineExecutor output interface.*/ bool IsGlobalOutput() const { return global_output_index_ > -1; } - + /*!\brief Getting the global output index in the current binding.*/ + int GetGlobalOutputIndex() const { return global_output_index_; } + /*!\brief Returning the binding configuration.*/ + std::unordered_map& Get() { return bindings_; } /*! * \brief Create a module interface map from JSONReader. * \param reader JSON reader. @@ -123,6 +137,19 @@ class ConfigOutputBindings { } return num_output; } + /*! + *\brief Getting the map which includes the global outputs and the current module outputs. + *\return A list of "GlobalOutputPair". + */ + std::vector GetGlobalConfigOutputBindings(void) const { + std::vector ret; + for (auto bindings : output_binding_map_) { + if (bindings.second.IsGlobalOutput()) { + ret.push_back(GlobalOutputPair(bindings.first, bindings.second.GetGlobalOutputIndex())); + } + } + return ret; + } /*! * \brief Create a output binding map from JSONReader. * \param reader Json reader. @@ -158,11 +185,19 @@ class ConfigOutputBindings { */ class ConfigPipelineExecution { public: + ConfigOutputBindings& operator[](int key) { + ICHECK(config_.find(key) != config_.end()); + return config_[key]; + } /* *!\brief This function is used to verify whether config is loaded successfully. * \return Return "true" to indicate that this class has not been successfully loaded. */ bool Empty() { return config_.empty(); } + /*! + *\brief Check if the module index existing in the "config". + */ + bool FindModuleInConfig(int mod_idx) { return config_.find(mod_idx) != config_.end(); } /*! * \brief Getting the number of global outputs. * \return The number of outputs in the entire pipeline. @@ -174,6 +209,31 @@ class ConfigPipelineExecution { } return num_output; } + /* + *!\brief Get the map of global outputs and module outputs. + */ + std::unordered_map GetGlobalConfigOutputBindings(void) const { + return global_output_map_; + } + /* + *!\brief Parsing the configuration. + */ + void ParseConfiguration(const std::unordered_map& config) { + if (config.empty()) { + LOG(FATAL) << "The Configuration loading not finish yet."; + } + for (auto mod_output : config) { + // Using the global output index as the key to create a map including global index and + // module output index. + const std::vector& global_output = + mod_output.second.GetGlobalConfigOutputBindings(); + + for (auto output : global_output) { + global_output_map_[output.second] = ModuleOutputPair(mod_output.first, output.first); + } + } + return; + } /*! * \brief Create a pipeline config from JSONReader. * \param reader Json reader. @@ -203,6 +263,8 @@ class ConfigPipelineExecution { // Build the mapping of mod_idx and "ConfigOutputBindings". config_[mod_idx] = output; } + // Doing the configuration parsing after the loading finished. + ParseConfiguration(config_); } private: @@ -211,6 +273,11 @@ class ConfigPipelineExecution { * information. */ std::unordered_map config_; + /* + *\brief The key is the global output index, and the map is including global outputs index and + * the module outputs pair. + */ + std::unordered_map global_output_map_; }; struct InputConnectionConfig { @@ -314,9 +381,11 @@ class BackendRuntime { /*!\brief The packed functions.*/ tvm::runtime::PackedFunc set_input_; tvm::runtime::PackedFunc get_input_; + tvm::runtime::PackedFunc get_output_; tvm::runtime::PackedFunc get_num_output_; tvm::runtime::PackedFunc get_num_inputs_; tvm::runtime::PackedFunc get_input_index_; + tvm::runtime::PackedFunc run_; /*! * \brief Copying from a given tensor and using 'CPU' as the device. */ @@ -367,6 +436,8 @@ class BackendRuntime { get_num_inputs_ = module_.GetFunction("get_num_inputs"); set_input_ = module_.GetFunction("set_input"); get_input_ = module_.GetFunction("get_input"); + get_output_ = module_.GetFunction("get_output"); + run_ = module_.GetFunction("run"); } BackendRuntime(void) {} ~BackendRuntime() { @@ -374,6 +445,11 @@ class BackendRuntime { TVMArrayFree(data.second); } } + /*!\brief Creating a NDArray containing same shape and data type with a module output. */ + NDArray CreateFromOutput(int idx) { + NDArray data = get_output_(idx); + return CreateNDArrayFromDLTensor(const_cast(data.operator->())); + } /*!\brief Return the index of the current module.*/ int GetModuleIndex() { return runtime_idx_; } /*!\brief Return the number of output*/ @@ -395,6 +471,10 @@ class BackendRuntime { NDArray GetInput(int index) const { return get_input_(index); } /*!\bief Getting the input data via the input name.*/ int GetInputIndex(const std::string& name) { return get_input_index_(name); } + /*!\brief Using the output index to get the module output.*/ + NDArray GetOutput(int index) { return get_output_(index); } + /*!\brief Running the runtime.*/ + void Run() { run_(); } }; /*! * \brief The information used to initialize the graph executor module, the information diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 99c24ef93b80..0851d377fe6a 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -136,7 +136,70 @@ def recreate_parameters(mod): for key, value in lib.params.items(): new_value = value.numpy() + np.full(value.shape, 10).astype(value.dtype) mod_customized_params[key] = tvm.nd.array(new_value) - return mod_customized_params + return mod_customized_params, mod + + +def run_modules( + mod_configs, + dev, + target, + global_input_name, + global_input_data, + mod_set_input, + input_name, + input_data, + params_mod=None, + params=None, +): + # Running modules in serialized model. The returnning data are used to verify the pipeline + # executor result. + mod_input = {} + final_output = {} + idx = 0 + for mod in mod_configs: + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target) + + m = graph_executor.GraphModule(lib["default"](dev)) + # Getting the input data then setting the input data into the module. + if idx in mod_input: + for input in mod_input[idx]: + input = mod_input[idx][input] + m.set_input(input["index"], input["data"]) + else: + m.set_input(global_input_name, global_input_data) + + # Setting the "input_data" into the module. + if mod == mod_set_input: + m.set_input(input_name, input_data) + # If the module is "params_mod" then setting the parameters to this module. + if params_mod == mod: + m.set_input(None, None, **params) + + m.run() + n = m.get_num_outputs() + # Setting current output data as the input of next module. + mconfig = mod_configs[mod] + for output in mconfig["pipeline"]["output"]: + output_data = m.get_output(output["output_idx"]).numpy() + for dep in output["dependencies"]: + is_global = False + if "global_output_index" in dep: + is_global = True + name = dep["global_output_index"] + else: + mod_idx = dep["mod_idx"] + name = dep["input_name"] + if is_global: + final_output[name] = output_data + else: + if mod_idx in mod_input: + mod_input[mod_idx][name] = {"index": name, "data": output_data} + else: + mod_input[mod_idx] = {name: {"index": name, "data": output_data}} + idx = idx + 1 + + return final_output def test_pipe_runtime_error_check(): @@ -188,7 +251,7 @@ def test_pipe_runtime_error_check(): with tvm.transform.PassContext(opt_level=3): pipeline_mod_factory = pipeline_executor.build(pipe_config) pipeline_module = pipeline_executor.PipelineModule(pipeline_mod_factory) - customized_parameters = recreate_parameters(mod1) + customized_parameters, _ = recreate_parameters(mod1) # Checking the pipeline executor runtime errors. with pytest.raises(RuntimeError): @@ -212,9 +275,10 @@ def test_pipeline(): pipe_config = pipeline_executor.PipelineConfig() - customized_parameters = recreate_parameters(mod2) + customized_parameters, customized_parameters_mod = recreate_parameters(mod1) + assert customized_parameters_mod == mod1 # The global parameters group named "param_0" will be connected to "mod1" as parameters. - pipe_config["param_group"]["param_0"].connect(pipe_config[mod2]["param"]) + pipe_config["param_group"]["param_0"].connect(pipe_config[mod1]["param"]) # The pipeline input named "data_0" will be connected to a input named "data_0" # of mod1. pipe_config["input"]["data_a"].connect(pipe_config[mod1]["input"]["data_0"]) @@ -237,7 +301,6 @@ def test_pipeline(): # The mod3 output[0] will be connected to pipeline output[1]. pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"]) - print(pipe_config) # Print configueration (print(pipe_config)), the result looks like following. # # Inputs @@ -291,18 +354,45 @@ def test_pipeline(): input_map = pipeline_module_test.get_input_pipeline_map("data_a") assert input_map[0] == "0" and input_map[1] == "data_0" module_index = pipeline_module_test.get_params_group_pipeline_map("param_0") - assert module_index == 1 + assert module_index == 0 # Using the parameters group name to set parameters. pipeline_module_test.set_params("param_0", customized_parameters) - # Getting the result from the pipeline executor - data_a = np.full(dshape, 1).astype("float32") - data_b = np.full(dshape, 2).astype("float32") - pipeline_module_test.set_input("data_a", data_a) - pipeline_module_test.set_input("data_b", data_b) - input_data = pipeline_module_test.get_input("data_b") - tvm.testing.assert_allclose(data_b, input_data.numpy()) - input_data = pipeline_module_test.get_input("data_a") - tvm.testing.assert_allclose(data_a, input_data.numpy()) + for data in datas: + # Getting the result without setting customized parameters. + wrong_output = run_modules( + mconfig["module_connection"], + tvm.cpu(), + "llvm", + "data_0", + data, + mod2, + "data_1", + data, + ) + # Getting the result with setting customized parameters. + normal_output = run_modules( + mconfig["module_connection"], + tvm.cpu(), + "llvm", + "data_0", + data, + mod2, + "data_1", + data, + customized_parameters_mod, + customized_parameters, + ) + pipeline_module_test.set_input("data_a", data) + pipeline_module_test.set_input("data_b", data) + input_data = pipeline_module_test.get_input("data_a") + tvm.testing.assert_allclose(data, input_data.numpy()) + # Running the pipeline executor in sequential mode. + pipeline_module_test.run(True) + outputs = pipeline_module_test.get_output() + for i in range(len(outputs)): + tvm.testing.assert_allclose(normal_output[i], outputs[i].numpy()) + assert not (normal_output[i] == wrong_output[i]).all() + pipeline_module_test.stop() if __name__ == "__main__": From d3c0f4046e5665298ca4fa26b119903e59e0e38e Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sat, 29 Jan 2022 04:52:21 +0800 Subject: [PATCH 35/49] [MetaSchedule][M4a] Mutator: Mutate Parallel (#10096) --- include/tvm/tir/schedule/instruction.h | 3 + python/tvm/meta_schedule/mutator/__init__.py | 1 + .../meta_schedule/mutator/mutate_parallel.py | 33 ++ src/meta_schedule/mutator/mutate_parallel.cc | 312 ++++++++++++++++++ src/tir/schedule/analysis.h | 20 ++ src/tir/schedule/analysis/analysis.cc | 31 ++ src/tir/schedule/instruction.cc | 5 + ...t_meta_schedule_mutator_mutate_parallel.py | 113 +++++++ 8 files changed, 518 insertions(+) create mode 100644 python/tvm/meta_schedule/mutator/mutate_parallel.py create mode 100644 src/meta_schedule/mutator/mutate_parallel.cc create mode 100644 tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 5a9e687dc8c7..1af5ab07e67c 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object { // not visited: f_attrs_from_json } + /*! \brief Checks if the instruction kind is EnterPostproc */ + bool IsPostproc() const; + static constexpr const char* _type_key = "tir.InstructionKind"; TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); }; diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index 85deb7253e86..af3485b679f1 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -21,4 +21,5 @@ """ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation +from .mutate_parallel import MutateParallel from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py new file mode 100644 index 000000000000..c66dddb825f4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the parallel extent""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateParallel") +class MutateParallel(Mutator): + """Mutator that mutates the parallel extent""" + + def __init__(self, max_jobs_per_core: int) -> None: + """Mutator that mutates the parallel extent""" + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateParallel, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + ) diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc new file mode 100644 index 000000000000..7c973879f2cc --- /dev/null +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if the instruction is annotation with `meta_schedule_parallel` + * \param inst The instruction to be checked + * \return Whether the instruction is annotation with `meta_schedule_parallel` + */ +bool IsAnnotateWithParallel(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_parallel; +} + +/*! + * \brief Replace the annotation value + * \param inst The instruction to be replaced + * \param ann_val The new annotation value + * \return The replaced instruction + */ +Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { + ICHECK_EQ(inst->inputs.size(), 2); + return Instruction(/*kind=*/inst->kind, // + /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // + /*attrs=*/inst->attrs, + /*outputs=*/inst->outputs); +} + +/*! + * \brief Get the output of the instruction Get-Block + * \param inst The instruction to be checked + * \return The output of the instruction Get-Block + */ +const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { + static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); + if (!inst->kind.same_as(inst_get_block)) { + return nullptr; + } + ICHECK_EQ(inst->outputs.size(), 1); + const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + return block; +} + +/*! + * \brief Analyze the parallel structure + * \param self The schedule state + * \param block_name The name of the root block + * \param func_name The name of the PrimFunc + * \param limit The uplimit of the parallelism + * \return The parallel structure + */ +std::vector> AnalyzeParallel(const ScheduleState& self, + const String& block_name, const String& func_name, + int64_t limit) { + Array block_srefs = tir::GetBlocks(self, block_name, func_name); + ICHECK_EQ(block_srefs.size(), 1); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + std::vector> results; + results.reserve(info.realizes.size()); + for (const BlockRealize& realize : info.realizes) { + // Step 1. Extract static loop extents for spatial loops + std::vector loop_extents; + const ForNode* loop = nullptr; + for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent; + (loop = loop_sref->StmtAs()) != nullptr; // + loop_sref = loop_sref->parent) { + int64_t loop_extent = -1; + if (const auto* ext = GetLoopIntExtent(loop)) { + if (!info.non_spatial_vars.count(loop->loop_var.get())) { + loop_extent = *ext; + } + } + if (loop_extent != -1) { + loop_extents.push_back(loop_extent); + } else { + loop_extents.clear(); + } + } + // Step 2. Take the prefix product of loop extents + if (!loop_extents.empty()) { + results.emplace_back(); + std::vector& result = results.back(); + result.reserve(loop_extents.size()); + int64_t prod_extent = 1; + for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) { + result.push_back(prod_extent *= *it); + if (prod_extent >= limit) { + break; + } + } + } + } + return results; +} + +/*! + * \brief Get the number of parallelizable loops for each subtree + * \param loop_extent_prods The parallel structure for each subtree + * \param limit The uplimit of the parallelism + * \return The number of parallelizable loops for each subtree + */ +std::vector GetNumFusedLoops(const std::vector>& loop_extent_prods, + int64_t limit) { + std::vector results; + results.reserve(loop_extent_prods.size()); + for (const std::vector& prods : loop_extent_prods) { + int n = prods.size(); + int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin(); + if (i > 0 && prods[i - 1] == limit) { + --i; + } + if (i != n) { + ++i; + } + results.push_back(i); + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates the parallel extent */ +class MutateParallelNode : public MutatorNode { + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! \brief The number of cores in CPU. */ + int max_parallel_extent_; + /*! \brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + // `max_parallel_extent_` is not visited. + // `json_mod` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.MutateParallel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core; + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief The candidate to be mutated */ +struct MutateParallelNode::Candidate { + /*! \brief The annotation instruction */ + Instruction inst; + /*! \brief The current parallel extent */ + int64_t parallel_extent; + /*! \brief The name of the root block */ + String block_name; + /*! \brief The name of the PrimFunc */ + String func_name; +}; + +/*! + * \brief Get an instruction that annotates the maximum parallel extent + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidate The candidate to be mutated + * \return Whether a decision is found + */ +bool FindParallelDecision(const Trace& trace, TRandState* rand_state, + MutateParallelNode::Candidate* candidate) { + using tir::BlockRVNode; + using tir::InstructionNode; + std::unordered_map get_block_insts; + std::vector ann_insts; + get_block_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (tir::IsAnnotateWithParallel(inst)) { + ann_insts.push_back(inst.get()); + } + if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { + get_block_insts[block_rv] = inst.get(); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const InstructionNode* get_block_inst = + get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); + ICHECK_EQ(get_block_inst->attrs.size(), 2); + candidate->inst = GetRef(ann_inst); + candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); + return true; +} + +Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { + // Step 1. Find a parallel decision. + Candidate candidate; + if (!FindParallelDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + // Step 2. Replay the instructions to recover loop extents + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + // Step 3. Find all possible parallel plans + std::vector> loop_extent_prods = tir::AnalyzeParallel( + sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_); + std::unordered_map> limit2plan; + std::map, int64_t> plan2limit; + for (const std::vector& prods : loop_extent_prods) { + for (int64_t limit : prods) { + if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) { + std::vector plan = tir::GetNumFusedLoops(loop_extent_prods, limit); + limit2plan[limit] = plan; + plan2limit[plan] = limit; + } + } + } + // Step 4. Remove the original plan and remove it + std::vector original_plan = + tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent); + auto it = plan2limit.find(original_plan); + if (it != plan2limit.end()) { + plan2limit.erase(it); + } + // Step 5. Pick a new plan + int n_plans = plan2limit.size(); + if (n_plans == 0) { + return NullOpt; + } + it = plan2limit.begin(); + for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) { + ++it; + } + int64_t limit = it->second; + // Step 6. Assemble a new trace + Array insts; + insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst.same_as(candidate.inst)) { + insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit)); + } else if (inst->kind->IsPostproc()) { + break; + } else { + insts.push_back(inst); + } + } + return Trace(insts, trace->decisions); +} + +Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + return Mutator(n); +} + +TVM_REGISTER_NODE_TYPE(MutateParallelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 591201312cd2..cdbb70bef6dd 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline, bool require_subtree_compact_dataflow); +/*! + * \brief The information of a block scope, including the leaf blocks, + * as well as the loop types (spatial, reduction) for each loop in the scope. + */ +struct ScopeBlockLoopInfo { + /*! \brief A list of the leaf blocks, from left to right */ + std::vector realizes; + /*! \brief The loop vars bound to spatial block iters */ + std::unordered_set spatial_vars; + /*! \brief The loop vars bound to non-spatial block iters */ + std::unordered_set non_spatial_vars; +}; + +/*! + * \brief Inspect the scope of the given sref + * \param scope_block The root block of the scope + * \return The information of the scope + */ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); + /*! * \brief Checks whether the block is a complete block under the scope * \param self The schedule state diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 1579f9154fe6..afdff9d5f832 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -150,6 +150,37 @@ Definition of a scope that is a stage pipeline: return scope_root_sref; } +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { + struct Collector : public StmtVisitor { + void VisitStmt_(const BlockRealizeNode* realize) final { + result.realizes.push_back(GetRef(realize)); + const Array& iter_vars = realize->block->iter_vars; + const Array& iter_values = realize->iter_values; + ICHECK_EQ(iter_vars.size(), iter_values.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = iter_vars[i]; + const PrimExpr& iter_value = iter_values[i]; + std::unordered_set* vars = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + vars = &result.spatial_vars; + } else { + vars = &result.non_spatial_vars; + } + PostOrderVisit(iter_value, [vars](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + vars->insert(var); + } + }); + } + } + + ScopeBlockLoopInfo result; + } visitor; + visitor(scope_block->body); + return std::move(visitor.result); +} + /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index af721767c32f..cedba4b96095 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -21,6 +21,11 @@ namespace tvm { namespace tir { +bool InstructionKindNode::IsPostproc() const { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return this == inst_enter_postproc.get(); +} + Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, Array outputs) { ObjectPtr n = make_object(); diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py new file mode 100644 index 000000000000..e263114ef60f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateParallel, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: + mutator = MutateParallel(max_jobs_per_core) + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_parallel_matmul(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + max_jobs_per_core=256, + ) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ann_val=64, + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + ann_val = int(trace.insts[-1].inputs[1]) + results.add(ann_val) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {4, 32, 4096} + + +if __name__ == """__main__""": + test_mutate_parallel_matmul() From d752e4aef529c4983abc5d52b6b638c3dc5b292b Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Fri, 28 Jan 2022 14:55:22 -0800 Subject: [PATCH 36/49] [Hexagon] Update hexagon API build instruction and cleanup hexagon_proxy_rpc (#10068) * Fix hexagon api build and Update Readme * Cleanup hexagon_proxy_rpc * Target Hack * Remove hack * address @cconvey comments * remove the rest of proxy rpc --- apps/hexagon_api/CMakeLists.txt | 2 + apps/hexagon_proxy_rpc/Readme.md | 82 ----- apps/hexagon_proxy_rpc/cmake/HexagonRPC.cmake | 57 --- .../cmake/android/CMakeLists.txt | 104 ------ .../cmake/hexagon/CMakeLists.txt | 81 ----- apps/hexagon_proxy_rpc/common.h | 59 ---- apps/hexagon_proxy_rpc/hexagon_core.cc | 204 ----------- apps/hexagon_proxy_rpc/hexagon_proxy_rpc.idl | 35 -- apps/hexagon_proxy_rpc/rpc_env.cc | 326 ------------------ cmake/modules/Hexagon.cmake | 15 +- python/tvm/contrib/hexagon/build.py | 2 +- .../test_hexagon/proxy_rpc/__init__.py | 18 - .../test_hexagon/proxy_rpc/test_matmul.py | 73 ---- .../contrib/test_hexagon/rpc/test_launcher.md | 45 ++- 14 files changed, 37 insertions(+), 1066 deletions(-) delete mode 100644 apps/hexagon_proxy_rpc/Readme.md delete mode 100644 apps/hexagon_proxy_rpc/cmake/HexagonRPC.cmake delete mode 100644 apps/hexagon_proxy_rpc/cmake/android/CMakeLists.txt delete mode 100644 apps/hexagon_proxy_rpc/cmake/hexagon/CMakeLists.txt delete mode 100644 apps/hexagon_proxy_rpc/common.h delete mode 100644 apps/hexagon_proxy_rpc/hexagon_core.cc delete mode 100644 apps/hexagon_proxy_rpc/hexagon_proxy_rpc.idl delete mode 100644 apps/hexagon_proxy_rpc/rpc_env.cc delete mode 100644 tests/python/contrib/test_hexagon/proxy_rpc/__init__.py delete mode 100644 tests/python/contrib/test_hexagon/proxy_rpc/test_matmul.py diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index 3c5eb616f1da..557dcfb85045 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -88,3 +88,5 @@ ExternalProject_Add_Step(hexagon_tvm_runtime_rpc copy_binaries DEPENDEES install ) +configure_file("${TVM_SOURCE_DIR}/src/runtime/hexagon/rpc/android_bash.sh.template" + ${HEXAGON_API_BINARY_DIR} COPYONLY) diff --git a/apps/hexagon_proxy_rpc/Readme.md b/apps/hexagon_proxy_rpc/Readme.md deleted file mode 100644 index d7b577b2b378..000000000000 --- a/apps/hexagon_proxy_rpc/Readme.md +++ /dev/null @@ -1,82 +0,0 @@ - - - - - - - - - - - - - - - - -# Hexagon Proxy RPC server - -The proxy RPC server for Hexagon is a wrapper which takes standard TVM RPC calls from a python host -to a remote Android device and forwards them across FastRPC to Hexagon. This RPC flow will be replaced -by running a minimal RPC server directly on Hexagon. For now we provide a prototype forwarding RPC server -for host driven execution on Hexagon. - -## Compilation - -Project inventory: -* Android - * libtvm_runtime.so (containing HexagonHostDeviceAPI src/runtime/Hexagon/proxy_rpc/device_api.cc) - * tvm_rpc (C++ RPC server) - * librpc_env (Hexagon specific RPC proxy environment) - -* Hexagon - * libhexagon_proxy_rpc_skel.so (Hexagon device code containing FastRPC endpoints for the Hexagon Proxy RPC server) - -All Android and Hexagon device artifacts will be placed in `apps_hexagon_proxy_rpc` from which they can be pushed -to an attached `adb` device. - -### Prerequisites - -1. Android NDK version r19c or later. -2. Hexagon SDK version 4.0.0 or later. - -Android NDK can be downloaded from https://developer.android.com/ndk. -Hexagon SDK is available at //developer.qualcomm.com/software/Hexagon-dsp-sdk. - -### Compilation with TVM - -Building the Hexagon Proxy RPC as a component of the main TVM build -used for Hexagon codegen can be achieved by setting `USE_HEXAGON_PROXY_RPC=ON`. -A minimal example invocation for compiling TVM along with the Hexagon Proxy RPC server -is included below: - -``` -cmake -DCMAKE_C_COMPILER=/path/to/clang \ - -DCMAKE_CXX_COMPILER=/path/to/clang++ \ - -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ - -DCMAKE_CXX_STANDARD=14 \ - -DUSE_RPC=ON \ - -DUSE_LLVM=/path/to/llvm/bin/llvm-config \ - -DUSE_HEXAGON_PROXY_RPC=ON \ - -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-28 \ - -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ - -DUSE_HEXAGON_ARCH=v65|v66|v68 \ - -DUSE_HEXAGON_SDK=/path/to/Hexagon/SDK \ - -DUSE_HEXAGON_TOOLCHAIN=/path/to/Hexagon/toolchain/ .. -``` - -where `v65|v66|v68` means "one of" these architecture versions. -The Hexagon proxy RPC application (tvm_rpc) is an android binary and thus requires the use -of an android toolchain for compilation. Similarly, the Hexagon tvm runtime -requires the use of the Hexagon toolchain and depends on the Hexagon SDK. The -resulting Hexagon launcher binaries can be found in the `apps_Hexagon_launcher` -subdirectory of the cmake build directory. The above command -will build support for Hexagon codegen in the TVM library that requires -`USE_LLVM` to be set to an llvm-config that has the Hexagon target built in. - - -# Disclaimer - -The Hexagon proxy RPC is intended for use with prototyping and does not utilize any -performance acceleration, as such the measured performance may be very poor. diff --git a/apps/hexagon_proxy_rpc/cmake/HexagonRPC.cmake b/apps/hexagon_proxy_rpc/cmake/HexagonRPC.cmake deleted file mode 100644 index 3ae6c8a7e664..000000000000 --- a/apps/hexagon_proxy_rpc/cmake/HexagonRPC.cmake +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -if(NOT DEFINED USE_HEXAGON_SDK) - message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") -endif() -if (NOT DEFINED USE_HEXAGON_ARCH) - message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") -endif() - -set(TVM_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../") - -include(ExternalProject) -include("${TVM_SOURCE_DIR}/cmake/utils/Utils.cmake") -include("${TVM_SOURCE_DIR}/cmake/modules/HexagonSDK.cmake") - -find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") - -include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_REMOTE_ROOT}) - -set(QAIC_EXE "${HEXAGON_QAIC_EXE}") -foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) - list(APPEND QAIC_FLAGS "-I${INCDIR}") -endforeach() - -set(HEXAGON_PROXY_RPC_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../../") -set(CMAKE_SKIP_RPATH TRUE) - -# Qaic for the domain header. -# -# Don't add paths to these filenames, or otherwise cmake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(HEXAGON_PROXY_RPC_IDL "hexagon_proxy_rpc.idl") -set(HEXAGON_PROXY_RPC_H "hexagon_proxy_rpc.h") -set(HEXAGON_PROXY_RPC_SKEL_C "hexagon_proxy_rpc_skel.c") -set(HEXAGON_PROXY_RPC_STUB_C "hexagon_proxy_rpc_stub.c") - -include_directories( - "${HEXAGON_PROXY_RPC_SRC}" - "${TVM_SOURCE_DIR}/include" - "${TVM_SOURCE_DIR}/3rdparty/dlpack/include" - "${TVM_SOURCE_DIR}/3rdparty/dmlc-core/include" -) diff --git a/apps/hexagon_proxy_rpc/cmake/android/CMakeLists.txt b/apps/hexagon_proxy_rpc/cmake/android/CMakeLists.txt deleted file mode 100644 index 869456cce7e7..000000000000 --- a/apps/hexagon_proxy_rpc/cmake/android/CMakeLists.txt +++ /dev/null @@ -1,104 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonAndroidRPC C CXX) - -include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonRPC.cmake") - -add_custom_command( - OUTPUT ${HEXAGON_PROXY_RPC_STUB_C} ${HEXAGON_PROXY_RPC_H} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${HEXAGON_PROXY_RPC_SRC}/${HEXAGON_PROXY_RPC_IDL}" - MAIN_DEPENDENCY "${HEXAGON_PROXY_RPC_SRC}/${HEXAGON_PROXY_RPC_IDL}" -) - -include_directories(SYSTEM - "${HEXAGON_SDK_INCLUDES}" - "${HEXAGON_RPCMEM_ROOT}/inc" - "${CMAKE_CURRENT_BINARY_DIR}" # Output of qaic will go here -) - -link_directories(${HEXAGON_REMOTE_ROOT}) - -add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - -set(TVM_RPC_ENV_SOURCES - ${HEXAGON_PROXY_RPC_SRC}/rpc_env.cc -) - -add_library(rpc_env SHARED - ${TVM_RPC_ENV_SOURCES} - ${HEXAGON_PROXY_RPC_H} - ${HEXAGON_PROXY_RPC_STUB_C} -) - -ExternalProject_Add(android_tvm_runtime - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime - CMAKE_ARGS - "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" - "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DANDROID_ABI=${ANDROID_ABI}" - "-DCMAKE_CXX_STANDARD=14" - "-DUSE_LIBBACKTRACE=OFF" - "-DUSE_LLVM=OFF" - "-DUSE_RPC=ON" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) -ExternalProject_Add_Step(android_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/libtvm_runtime.so - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDEES install -) - -add_dependencies(rpc_env android_tvm_runtime) -add_library(a_tvm_runtime SHARED IMPORTED) -set_target_properties(a_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") - -target_link_libraries(rpc_env cdsprpc log a_tvm_runtime) - -# TVM CPP RPC build -set(TVM_RPC_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../cpp_rpc") - - -set(TVM_RPC_SOURCES - ${TVM_RPC_DIR}/main.cc - ${TVM_RPC_DIR}/rpc_server.cc -) - -# Set output to same directory as the other TVM libs -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) -add_executable(tvm_rpc ${TVM_RPC_SOURCES}) - - -target_include_directories( - tvm_rpc - PUBLIC "${TVM_RPC_DIR}../../include" - PUBLIC "${TVM_RPC_DIR}../../3rdparty/dlpack" - PUBLIC "${TVM_RPC_DIR}../../3rdparty/dmlc-core" -) - -add_dependencies(rpc_env android_tvm_runtime) -target_link_libraries(rpc_env a_tvm_runtime) - -add_dependencies(tvm_rpc android_tvm_runtime rpc_env) -target_link_libraries(tvm_rpc a_tvm_runtime rpc_env) diff --git a/apps/hexagon_proxy_rpc/cmake/hexagon/CMakeLists.txt b/apps/hexagon_proxy_rpc/cmake/hexagon/CMakeLists.txt deleted file mode 100644 index 525212bab3b3..000000000000 --- a/apps/hexagon_proxy_rpc/cmake/hexagon/CMakeLists.txt +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonRPCSkel C CXX) - -include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonRPC.cmake") - -add_custom_command( - OUTPUT ${HEXAGON_PROXY_RPC_SKEL_C} ${HEXAGON_PROXY_RPC_H} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${HEXAGON_PROXY_RPC_SRC}/${HEXAGON_PROXY_RPC_IDL}" - MAIN_DEPENDENCY "${HEXAGON_PROXY_RPC_SRC}/${HEXAGON_PROXY_RPC_IDL}" -) - -include_directories(SYSTEM - ${HEXAGON_QURT_INCLUDES} - ${CMAKE_CURRENT_BINARY_DIR} # Output of qaic will go here -) - -link_directories(${HEXAGON_QURT_LIBS}) - -add_definitions(-D_MACH_I32=int) -add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0) -add_definitions(-DDMLC_USE_LOGGING_LIBRARY=) - -# Extra compile flags (both C and C++). -set(EXTRA_COMP_FLAGS - "-O3" - "-m${USE_HEXAGON_ARCH}" -) -string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") -set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") -set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") - -set(SKEL_SRCS - "${HEXAGON_PROXY_RPC_SRC}/hexagon_core.cc" -) - -add_library(hexagon_proxy_rpc_skel SHARED - "${HEXAGON_PROXY_RPC_H}" - "${HEXAGON_PROXY_RPC_SKEL_C}" - "${SKEL_SRCS}" -) - -ExternalProject_Add(static_hexagon_tvm_runtime - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime - CMAKE_ARGS - "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" - "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DCMAKE_CXX_STANDARD=14" - "-DUSE_LIBBACKTRACE=OFF" - "-DUSE_LLVM=OFF" - "-DUSE_RPC=OFF" - "-DBUILD_STATIC_RUNTIME=ON" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) - -add_dependencies(hexagon_proxy_rpc_skel static_hexagon_tvm_runtime) -add_library(h_tvm_runtime STATIC IMPORTED) -set_target_properties(h_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") - -target_link_libraries(hexagon_proxy_rpc_skel -Wl,--whole-archive h_tvm_runtime -Wl,--no-whole-archive) diff --git a/apps/hexagon_proxy_rpc/common.h b/apps/hexagon_proxy_rpc/common.h deleted file mode 100644 index d93c90a6278c..000000000000 --- a/apps/hexagon_proxy_rpc/common.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_PROXY_RPC_COMMON_H_ -#define TVM_RUNTIME_HEXAGON_PROXY_RPC_COMMON_H_ - -#include -#include -#include -#include -#include -#include - -#include -#include - -struct HandlePacket { - int ndim; - uint32_t handles[]; - int size() const { return size(ndim); } - static int size(int ndim) { return sizeof(HandlePacket) + ndim * sizeof(uint32_t); } -}; - -struct tensor_meta { - int ndim; - DLDataType dtype; - int64_t shape[]; - - int meta_size() const { return meta_size(ndim); } - int data_size() const { - int size = tvm::runtime::DataType(dtype).bytes(); - for (int d = 0; d != ndim; ++d) { - size *= shape[d]; - } - return size; - } - - static int meta_size(int ndim) { return sizeof(tensor_meta) + ndim * sizeof(int64_t); } - - std::string to_string() const; -}; - -#endif // TVM_RUNTIME_HEXAGON_PROXY_RPC_COMMON_H_ diff --git a/apps/hexagon_proxy_rpc/hexagon_core.cc b/apps/hexagon_proxy_rpc/hexagon_core.cc deleted file mode 100644 index e45bc24c30bf..000000000000 --- a/apps/hexagon_proxy_rpc/hexagon_core.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern "C" { -#include -#include -#include -#include -#include -#include -} - -#include -#include - -#include -#include -#include - -#include "common.h" -#include "hexagon_proxy_rpc.h" - -template -T* DeserializeToPointerType(unsigned int module) { - return reinterpret_cast(module); -} - -template -unsigned int SerializeFromPointerType(T* pointer) { - return *reinterpret_cast(&pointer); -} - -tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::runtime::PackedFunc loader = - *tvm::runtime::Registry::Get("runtime.module.loadfile_hexagon"); - tvm::runtime::TVMRetValue rv = loader(file_name); - if (rv.type_code() == kTVMModuleHandle) { - return rv.operator tvm::runtime::Module(); - } - return tvm::runtime::Module(); -} - -int __QAIC_HEADER(hexagon_proxy_rpc_open)(const char* uri, remote_handle64* handle) { - FARF(ALWAYS, "[hexagon_proxy_rpc_open] FastRPC connection established"); - *handle = 0; - const tvm::runtime::PackedFunc api = *tvm::runtime::Registry::Get("device_api.hexagon.v2"); - tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api); - return AEE_SUCCESS; -} - -int __QAIC_HEADER(hexagon_proxy_rpc_close)(remote_handle64 handle) { - // Comment to stop clang-format from single-lining this function. - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_load)(remote_handle64 handle, const char* module_path, - unsigned int* module) { - auto* mod_ptr = new tvm::runtime::Module(load_module(module_path)); - *module = SerializeFromPointerType(mod_ptr); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_unload)(remote_handle64 handle, unsigned int module) { - tvm::runtime::Module* mod_ptr = DeserializeToPointerType(module); - delete mod_ptr; - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_get_function)(remote_handle64 handle, const char* name, - unsigned int module, unsigned int* func) { - tvm::runtime::Module* mod_ptr = DeserializeToPointerType(module); - std::string fname(name); - tvm::runtime::PackedFunc f = (*mod_ptr)->GetFunction(fname); - auto* f_ptr = new tvm::runtime::PackedFunc(f); - *func = SerializeFromPointerType(f_ptr); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_release_function)(remote_handle64 handle, - unsigned int func) { - tvm::runtime::PackedFunc* f_ptr = DeserializeToPointerType(func); - delete f_ptr; - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_invoke)(remote_handle64 handle, unsigned int func, - const unsigned char* handles, int nhandles) { - tvm::runtime::PackedFunc* f_ptr = DeserializeToPointerType(func); - const auto* meta = reinterpret_cast(handles); - std::vector values; - std::vector type_codes; - for (size_t i = 0; i < meta->ndim; i++) { - tvm::runtime::NDArray* array = - DeserializeToPointerType(meta->handles[i]); - type_codes.push_back(kTVMDLTensorHandle); - values.emplace_back(); - const DLTensor* dltensor = array->operator->(); - values.back().v_handle = const_cast(static_cast(dltensor)); - } - - { - int res = qurt_hvx_reserve(QURT_HVX_RESERVE_ALL_AVAILABLE); - switch (res) { - case QURT_HVX_RESERVE_NOT_SUPPORTED: - case QURT_HVX_RESERVE_NOT_SUCCESSFUL: - FARF(ERROR, "error reserving HVX: %u", res); - return AEE_EFAILED; - default: - break; - } - // Lock HVX. - int lck = qurt_hvx_lock(QURT_HVX_MODE_128B); - if (lck != 0) { - FARF(ERROR, "error locking HVX: %u", lck); - return AEE_EFAILED; - } - } - tvm::runtime::TVMRetValue rv; - f_ptr->CallPacked(tvm::runtime::TVMArgs(values.data(), type_codes.data(), values.size()), &rv); - { - int unl = qurt_hvx_unlock(); - if (unl != 0) { - FARF(ERROR, "error unlocking HVX: %u", unl); - return AEE_EFAILED; - } - // Release HVX. - int rel = qurt_hvx_cancel_reserve(); - if (rel != 0) { - FARF(ERROR, "error canceling HVX reservation: %u", rel); - return AEE_EFAILED; - } - } - - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_allocate)(remote_handle64 handle, - const unsigned char* input_meta, - int input_meta_size, const char* mem_scope, - unsigned int* tensor) { - const auto* meta = reinterpret_cast(input_meta); - auto device = tvm::Device{static_cast(kDLHexagon), 0}; - tvm::runtime::Optional scope; - if (*mem_scope) { - scope = mem_scope; - } - auto* array = new tvm::runtime::NDArray(std::move(tvm::runtime::NDArray::Empty( - tvm::ShapeTuple(meta->shape, meta->shape + meta->ndim), meta->dtype, device, scope))); - *tensor = SerializeFromPointerType(array); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_read)(remote_handle64 handle, unsigned char* dst_ptr, - int nbytes, unsigned int src) { - tvm::runtime::NDArray* src_ptr = DeserializeToPointerType(src); - const DLTensor* t = src_ptr->operator->(); - tvm::ShapeTuple shape(t->shape, t->shape + t->ndim); - auto* container = new tvm::runtime::NDArray::Container( - static_cast(dst_ptr), shape, src_ptr->operator->()->dtype, tvm::Device{kDLCPU, 0}); - container->SetDeleter([](tvm::Object* container) { - delete static_cast(container); - }); - tvm::runtime::NDArray dst(GetObjectPtr(container)); - dst.CopyFrom(*src_ptr); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_write)(remote_handle64 handle, unsigned int dst, - const unsigned char* src_ptr, int nbytes) { - tvm::runtime::NDArray* dst_ptr = DeserializeToPointerType(dst); - const DLTensor* t = dst_ptr->operator->(); - tvm::ShapeTuple shape(t->shape, t->shape + t->ndim); - auto* container = - new tvm::runtime::NDArray::Container(const_cast(src_ptr), shape, - dst_ptr->operator->()->dtype, tvm::Device{kDLCPU, 0}); - container->SetDeleter([](tvm::Object* container) { - delete static_cast(container); - }); - tvm::runtime::NDArray src(GetObjectPtr(container)); - dst_ptr->CopyFrom(src); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(hexagon_proxy_rpc_release)(remote_handle64 handle, unsigned int array) { - tvm::runtime::NDArray* array_ptr = DeserializeToPointerType(array); - delete array_ptr; - return AEE_SUCCESS; -} diff --git a/apps/hexagon_proxy_rpc/hexagon_proxy_rpc.idl b/apps/hexagon_proxy_rpc/hexagon_proxy_rpc.idl deleted file mode 100644 index 0badf382d943..000000000000 --- a/apps/hexagon_proxy_rpc/hexagon_proxy_rpc.idl +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "remote.idl" -#include "AEEStdDef.idl" - -typedef sequence buffer; - -interface hexagon_proxy_rpc : remote_handle64 { - AEEResult load(in string module_path, rout unsigned long mod); - AEEResult unload(in unsigned long mod); - AEEResult get_function(in string name, in unsigned long mod, rout unsigned long func); - AEEResult release_function(in unsigned long func); - AEEResult invoke(in unsigned long func, in buffer handles); - AEEResult allocate(in buffer template_tensor, in string mem_scope, rout unsigned long tensor); - AEEResult read(rout buffer dst_ptr, in unsigned long src); - AEEResult write(in unsigned long dst, in buffer src_ptr); - AEEResult release(in unsigned long array); -}; diff --git a/apps/hexagon_proxy_rpc/rpc_env.cc b/apps/hexagon_proxy_rpc/rpc_env.cc deleted file mode 100644 index 911ca580ba4f..000000000000 --- a/apps/hexagon_proxy_rpc/rpc_env.cc +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file rpc_env.cc - * \brief Server environment of the RPC. - */ -#include "../cpp_rpc/rpc_env.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../../src/support/utils.h" -#include "common.h" -#include "hexagon_proxy_rpc.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief CleanDir Removes the files from the directory - * \param dirname THe name of the directory - */ -void CleanDir(const std::string& dirname); - -namespace hexagon { -using FastRPCHandle = remote_handle64; -using Handle = uint32_t; - -AEEResult enable_unsigned_pd(bool enable) { - remote_rpc_control_unsigned_module data; - data.domain = CDSP_DOMAIN_ID; - data.enable = static_cast(enable); - AEEResult rc = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); - if (rc != AEE_SUCCESS) { - std::cout << "error " << (enable ? "enabling" : "disabling") << " unsigned PD\n"; - } - return rc; -} - -AEEResult set_remote_stack_size(int size) { - remote_rpc_thread_params data; - data.domain = CDSP_DOMAIN_ID; - data.prio = -1; - data.stack_size = size; - AEEResult rc = remote_session_control(FASTRPC_THREAD_PARAMS, &data, sizeof(data)); - if (rc != AEE_SUCCESS) { - std::cout << "error setting remote stack size: " << std::hex << rc << '\n'; - } - return rc; -} - -class FastRPCChannel { - public: - explicit FastRPCChannel(const std::string& uri) { - enable_unsigned_pd(true); - set_remote_stack_size(128 * 1024); - - int rc = hexagon_proxy_rpc_open(uri.c_str(), &handle_); - if (rc != AEE_SUCCESS) { - handle_ = std::numeric_limits::max(); - } - } - - ~FastRPCChannel() { - if (handle_ == std::numeric_limits::max()) { - return; - } - - hexagon_proxy_rpc_close(handle_); - handle_ = std::numeric_limits::max(); - } - - FastRPCHandle GetHandle() { return handle_; } - - private: - FastRPCHandle handle_ = std::numeric_limits::max(); -}; - -class HexagonModuleNode : public ModuleNode { - public: - HexagonModuleNode() = delete; - HexagonModuleNode(FastRPCHandle h, std::string file_name) : handle_(h), mod_{0} { - AEEResult rc = hexagon_proxy_rpc_load(handle_, file_name.c_str(), &mod_); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error loading module\n"; - } - } - ~HexagonModuleNode() { - AEEResult rc = hexagon_proxy_rpc_unload(handle_, mod_); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error unloading module\n"; - } - for (Handle func : packed_func_handles_) { - AEEResult rc = hexagon_proxy_rpc_release_function(handle_, func); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error releasing function\n"; - } - } - } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - hexagon::Handle func; - AEEResult rc = hexagon_proxy_rpc_get_function(handle_, name.c_str(), mod_, &func); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error calling get_function\n"; - } - packed_func_handles_.push_back(func); - return PackedFunc([handle = this->handle_, func, name](TVMArgs args, TVMRetValue* rv) { - std::vector handles; - for (size_t i = 0; i < args.size(); i++) { - ICHECK_EQ(args.type_codes[i], kTVMDLTensorHandle); - DLTensor* tensor = args[i]; - auto f = runtime::Registry::Get("runtime.hexagon.GetHandle"); - int32_t thandle = (*f)(tensor->data); - handles.push_back(thandle); - } - auto* packet = reinterpret_cast(rpcmem_alloc( - RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, HandlePacket::size(args.size()))); - packet->ndim = args.size(); - std::copy(handles.begin(), handles.end(), packet->handles); - AEEResult rc = hexagon_proxy_rpc_invoke( - handle, func, reinterpret_cast(packet), packet->size()); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error invoking function: " << name; - } - rpcmem_free(packet); - }); - } - const char* type_key() const { return "HexagonModule"; } - - private: - FastRPCHandle handle_; - Handle mod_; - std::vector packed_func_handles_; -}; -} // namespace hexagon - -RPCEnv::RPCEnv(const std::string& wd) { - if (wd != "") { - base_ = wd + "/.cache"; - mkdir(wd.c_str(), 0777); - mkdir(base_.c_str(), 0777); - } else { - char cwd[PATH_MAX]; - auto cmdline = fopen("/proc/self/cmdline", "r"); - fread(cwd, 1, sizeof(cwd), cmdline); - fclose(cmdline); - std::string android_base_ = "/data/data/" + std::string(cwd) + "/cache"; - struct stat statbuf; - // Check if application data directory exist. If not exist, usually means we run tvm_rpc from - // adb shell terminal. - if (stat(android_base_.data(), &statbuf) == -1 || !S_ISDIR(statbuf.st_mode)) { - // Tmp directory is always writable for 'shell' user. - android_base_ = "/data/local/tmp"; - } - base_ = android_base_ + "/rpc"; - mkdir(base_.c_str(), 0777); - } - - static hexagon::FastRPCChannel hexagon_proxy_rpc(hexagon_proxy_rpc_URI CDSP_DOMAIN); - if (hexagon_proxy_rpc.GetHandle() == -1) { - LOG(FATAL) << "Error opening FastRPC channel\n"; - } - - TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPath(args[0]); - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") - .set_body([this, handle = hexagon_proxy_rpc.GetHandle()](TVMArgs args, TVMRetValue* rv) { - std::string file_name = this->GetPath(args[0]); - auto n = make_object(handle, file_name); - *rv = Module(n); - LOG(INFO) << "Load module from " << file_name << " ..."; - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.hexagon.allocate") - .set_body([handle = hexagon_proxy_rpc.GetHandle()](TVMArgs args, TVMRetValue* rv) { - DLTensor* ext_tensor = args[0]; - Optional mem_scope = args[1]; - - auto* input_meta = reinterpret_cast(rpcmem_alloc( - RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, tensor_meta::meta_size(ext_tensor->ndim))); - input_meta->ndim = ext_tensor->ndim; - input_meta->dtype = ext_tensor->dtype; - std::copy(ext_tensor->shape, ext_tensor->shape + ext_tensor->ndim, input_meta->shape); - - hexagon::Handle hexagon_buffer; - const char* scope = mem_scope.defined() ? mem_scope.value().c_str() : ""; - AEEResult rc = - hexagon_proxy_rpc_allocate(handle, reinterpret_cast(input_meta), - input_meta->meta_size(), scope, &hexagon_buffer); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error allocating hexagon ndrray\n"; - } - rpcmem_free(input_meta); - *rv = static_cast(hexagon_buffer); - return rc == AEE_SUCCESS; - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.hexagon.read_to_host") - .set_body([handle = hexagon_proxy_rpc.GetHandle()](TVMArgs args, TVMRetValue* rv) { - void* host_ptr = static_cast(args[0]); - size_t nbytes = args[1]; - hexagon::Handle hexagon_buffer = static_cast(args[2]); - AEEResult rc = hexagon_proxy_rpc_read(handle, static_cast(host_ptr), - static_cast(nbytes), hexagon_buffer); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error reading from hexagon buffer\n"; - } - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.hexagon.write_from_host") - .set_body([handle = hexagon_proxy_rpc.GetHandle()](TVMArgs args, TVMRetValue* rv) { - hexagon::Handle hexagon_buffer = static_cast(args[0]); - void* host_ptr = static_cast(args[1]); - size_t nbytes = args[2]; - AEEResult rc = hexagon_proxy_rpc_write( - handle, hexagon_buffer, static_cast(host_ptr), static_cast(nbytes)); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error writing to hexagon buffer\n"; - } - }); - - TVM_REGISTER_GLOBAL("tvm.rpc.hexagon.release") - .set_body([handle = hexagon_proxy_rpc.GetHandle()](TVMArgs args, TVMRetValue* rv) { - hexagon::Handle hexagon_buffer = static_cast(args[0]); - AEEResult rc = hexagon_proxy_rpc_release(handle, hexagon_buffer); - if (rc != AEE_SUCCESS) { - LOG(FATAL) << "Error writing to hexagon buffer\n"; - } - }); -} - -/*! - * \brief GetPath To get the work path from packed function - * \param file_name The file name - * \return The full path of file. - */ -std::string RPCEnv::GetPath(const std::string& file_name) const { - // we assume file_name has "/" means file_name is the exact path - // and does not create /.rpc/ - return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name; -} -/*! - * \brief Remove The RPC Environment cleanup function - */ -void RPCEnv::CleanUp() const { - CleanDir(base_); - const int ret = rmdir(base_.c_str()); - if (ret != 0) { - LOG(WARNING) << "Remove directory " << base_ << " failed"; - } -} - -/*! - * \brief ListDir get the list of files in a directory - * \param dirname The root directory name - * \return vector Files in directory. - */ -std::vector ListDir(const std::string& dirname) { - std::vector vec; - DIR* dp = opendir(dirname.c_str()); - if (dp == nullptr) { - int errsv = errno; - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); - } - dirent* d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - std::string f = dirname; - if (f[f.length() - 1] != '/') { - f += '/'; - } - f += d->d_name; - vec.push_back(f); - } - } - closedir(dp); - return vec; -} - -/*! - * \brief CleanDir Removes the files from the directory - * \param dirname The name of the directory - */ -void CleanDir(const std::string& dirname) { - auto files = ListDir(dirname); - for (const auto& filename : files) { - std::string file_path = dirname + "/"; - file_path += filename; - const int ret = std::remove(filename.c_str()); - if (ret != 0) { - LOG(WARNING) << "Remove file " << filename << " failed"; - } - } -} -} // namespace runtime -} // namespace tvm diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index d4dfaf22d698..a990101bdecf 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -34,7 +34,7 @@ function(find_hexagon_toolchain) set(TRY_PATH "${USE_HEXAGON_SDK}") endif() message(STATUS "Looking for Hexagon toolchain in ${TRY_PATH}") - tvm_file_glob(GLOB_RECURSE HEXAGON_CLANG "${TRY_PATH}/*/hexagon-clang++") + file(GLOB_RECURSE HEXAGON_CLANG "${TRY_PATH}/*/hexagon-clang++") if(HEXAGON_CLANG) # The path is ${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++. get_filename_component(HEXAGON_TMP0 "${HEXAGON_CLANG}" DIRECTORY) @@ -98,16 +98,9 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") message(SEND_ERROR "USE_HEXAGON_LAUNCHER is deprecated, please build apps separately") endif() -if(USE_HEXAGON_PROXY_RPC STREQUAL "ON") - message(SEND_ERROR "USE_HEXAGON_PROXY_RPC is deprecated, please build apps separately") -endif() - # find_hexagon_sdk_root has been called at this point. if(USE_HEXAGON_RPC) - set(HEXAGON_RPC_OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/hexagon_rpc") - file(MAKE_DIRECTORY ${HEXAGON_RPC_OUTPUT}) - set(TVMRT_SOURCE_DIR "${CMAKE_SOURCE_DIR}/src/runtime") set(QAIC_EXE "${HEXAGON_QAIC_EXE}") foreach(INCDIR IN LISTS HEXAGON_SDK_INCLUDES HEXAGON_REMOTE_ROOT) @@ -131,10 +124,6 @@ if(USE_HEXAGON_RPC) tvm_file_glob(GLOB RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon/rpc/android/*.cc") list(APPEND RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_stub.c") - # copy android_bash template file - configure_file("${TVMRT_SOURCE_DIR}/hexagon/rpc/android_bash.sh.template" - ${HEXAGON_RPC_OUTPUT} COPYONLY) - elseif(BUILD_FOR_HEXAGON) # Hexagon part find_hexagon_toolchain() @@ -154,8 +143,6 @@ if(USE_HEXAGON_RPC) SYSTEM PRIVATE "${TVMRT_SOURCE_DIR}/hexagon/rpc" ) endif() - - set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${HEXAGON_RPC_OUTPUT}") endif() if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index e640aad89231..def9ea17ace0 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -49,7 +49,7 @@ def get_hexagon_rpc_dir() -> pathlib.Path: global HEXAGON_RPC_DIR if HEXAGON_RPC_DIR is None: for path in libinfo.find_lib_path(): - rpc_dir = os.path.join(os.path.dirname(path), "hexagon_rpc") + rpc_dir = os.path.join(os.path.dirname(path), "hexagon_api_output") if os.path.isdir(rpc_dir): HEXAGON_RPC_DIR = rpc_dir break diff --git a/tests/python/contrib/test_hexagon/proxy_rpc/__init__.py b/tests/python/contrib/test_hexagon/proxy_rpc/__init__.py deleted file mode 100644 index 5261dc9cf052..000000000000 --- a/tests/python/contrib/test_hexagon/proxy_rpc/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Testing infrastructure for Hexagon Proxy RPC """ diff --git a/tests/python/contrib/test_hexagon/proxy_rpc/test_matmul.py b/tests/python/contrib/test_hexagon/proxy_rpc/test_matmul.py deleted file mode 100644 index 839fdc9bc29d..000000000000 --- a/tests/python/contrib/test_hexagon/proxy_rpc/test_matmul.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os - -import tvm -import tvm.testing -from tvm import te -import tvm.contrib.hexagon.hexagon as hexagon -from tvm.contrib import utils -import numpy as np - -from ..conftest import requires_hexagon_toolchain, requires_rpc_tracker_and_android_key - - -@requires_rpc_tracker_and_android_key -@requires_hexagon_toolchain -class TestMatMul: - M = tvm.testing.parameter(32) - N = tvm.testing.parameter(32) - K = tvm.testing.parameter(32) - - def test_matmul(self, M, N, K, rpc_sess, remote_path): - X = te.placeholder((M, K), dtype="float32") - Y = te.placeholder((K, N), dtype="float32") - k1 = te.reduce_axis((0, K), name="k1") - Z = te.compute((M, N), lambda i, j: te.sum(X[i, k1] * Y[k1, j], axis=[k1])) - schedule = te.create_schedule(Z.op) - - target_hexagon = tvm.target.hexagon("v68", link_params=True) - mod = tvm.build(schedule, [X, Y, Z], target=target_hexagon, target_host=target_hexagon) - - temp = utils.tempdir() - dso_binary_path = temp.relpath(os.path.basename(remote_path)) - mod.save(dso_binary_path) - - rpc_sess.upload(dso_binary_path, target=remote_path) - - mod = rpc_sess.load_module(remote_path) - - x = np.random.uniform(size=[i.value for i in X.shape]).astype(X.dtype) - y = np.random.uniform(size=[i.value for i in Y.shape]).astype(Y.dtype) - z = np.zeros([i.value for i in Z.shape], dtype=Z.dtype) - - dev = rpc_sess.hexagon(0) - xt = tvm.nd.array(x, device=dev) - yt = tvm.nd.array(y, device=dev) - zt = tvm.nd.array(z, device=dev) - mod(xt, yt, zt) - - target_llvm = tvm.target.Target("llvm") - mod = tvm.build(schedule, [X, Y, Z], target=target_llvm, target_host=target_llvm) - device = tvm.cpu(0) - xtcpu = tvm.nd.array(x, device) - ytcpu = tvm.nd.array(y, device) - ztcpu = tvm.nd.array(z, device) - mod(xtcpu, ytcpu, ztcpu) - - tvm.testing.assert_allclose(zt.asnumpy(), ztcpu.asnumpy(), rtol=1e-4) diff --git a/tests/python/contrib/test_hexagon/rpc/test_launcher.md b/tests/python/contrib/test_hexagon/rpc/test_launcher.md index 463b88e3f374..bcf255e478f1 100644 --- a/tests/python/contrib/test_hexagon/rpc/test_launcher.md +++ b/tests/python/contrib/test_hexagon/rpc/test_launcher.md @@ -26,29 +26,50 @@ Here are the steps that are taken to prepare a runtime on a Hexagon device to te - Build TVM library with Hexagon support for host machine. - Build TVMRuntime library and C++ RPC server for host machine. -To build these pieces, you can use a cmake command as follow. +Note: First, ensure to export Clang libraries to `LD_LIBRARY_PATH` and Hexagon toolchain to `HEXAGON_TOOLCHAIN`: ```bash -cmake -DUSE_HEXAGON_RPC=ON \ - -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"path to `llvm-clang/lib` sub-directory" + +export HEXAGON_TOOLCHAIN="Path to Hexagon toolchain. It can be the Hexagon toolchain included in the SDK, for example `HEXAGON_SDK_PATH/tools/HEXAGON_Tools/x.y.z/Tools`. The `x.y.z` in the path is the toolchain version number, which is specific to the version of the SDK." +``` + +To build these pieces, first build Hexagon API application under `apps/hexagon_api`. + +```bash +cd apps/hexagon_api +mkdir build +cd build +cmake -DUSE_ANDROID_TOOLCHAIN="path to `android-ndk/build/cmake/android.toolchain.cmake` file" \ -DANDROID_PLATFORM=android-28 \ -DANDROID_ABI=arm64-v8a \ -DUSE_HEXAGON_ARCH=v65|v66|v68 \ - -DUSE_HEXAGON_SDK=/path/to/Hexagon/SDK \ - -DUSE_HEXAGON_TOOLCHAIN=/path/to/Hexagon/toolchain/ \ - -DUSE_LLVM=/path/to/llvm/bin/llvm-config \ + -DUSE_HEXAGON_SDK="path to Hexagon SDK" \ + -DUSE_HEXAGON_TOOLCHAIN="path to Hexagon toolchain `Tools` sub-directory which explained above" \ + -DUSE_OUTPUT_BINARY_DIR="path to `build/hexagon_api_output` which is a sub-directory of `tvm`" .. +``` + +This command generates `tvm_rpc_android` and `libtvm_runtime.so` to run on Android. Also, it generates `libtvm_runtime.a` and `libhexagon_rpc_skel.so` to run on Hexagon device. Now we have TVM artifacts which are used to run on the remote device. + +Next, we need to build TVM on host with RPC and Hexagon dependencies. To do that follow these commands. + +```bash +cd tvm +mkdir build +cd build +cmake -DUSE_LLVM="path to `llvm/bin/llvm-config`" \ -DUSE_CPP_RPC=ON \ - -DCMAKE_CXX_COMPILER=/path/to/clang++ \ - -DCMAKE_CXX_FLAGS='-stdlib=libc++' .. + -DCMAKE_CXX_COMPILER="path to `clang++` executable" \ + -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ + -DUSE_HEXAGON_SDK="path to Hexagon SDK" \ + -DUSE_HEXAGON_ARCH="choose from v65|v66|v68" \ + -DUSE_HEXAGON_DEVICE=sim .. ``` ## Testing Using HexagonLauncher -Before starting a test you need to run an RPC tracker on your local machine and export HOST and PORT as environment variables. Also, you need to export Clang libraries to `LD_LIBRARY_PATH` and Hexagon toolchain to `HEXAGON_TOOLCHAIN`. +Before starting a test you need to run an RPC tracker on your local machine and export HOST and PORT as environment variables. Also, you need to export Clang libraries to `LD_LIBRARY_PATH` and Hexagon toolchain to `HEXAGON_TOOLCHAIN` as explained above. ```bash -export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/path/to/clang++/lib" -export HEXAGON_TOOLCHAIN="/path/to/Hexagon/toolchain/" - export TVM_TRACKER_HOST="0.0.0.0" export TVM_TRACKER_PORT=9192 python -m tvm.exec.rpc_tracker --host $TVM_TRACKER_HOST --port $TVM_TRACKER_PORT From 01d70330671106d3b74037cc433502e72955cbed Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 24 Jan 2022 15:50:40 -0800 Subject: [PATCH 37/49] [MetatSchedule] testcase for TensorRT builder/runner --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py new file mode 100644 index 000000000000..ca38a7d118a7 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc_trt.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ + + +import sys +import pytest +import itertools +import tvm +from tvm import relay +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +import numpy as np +from typing import List, Tuple + +# from tvm import script +# from tvm._ffi import register_func +# from tvm.runtime import Module +from tvm._ffi import register_func +from tvm.relay.testing.init import Initializer +from tvm.target import Target +from tvm.runtime import Module +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) + +from tvm.tir import FloatImm +from tvm.meta_schedule.testing import get_network + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + + # Build + if use_trt: + + def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + return tvm.relay.build_module._build_module_no_factory( + mod, "cuda", "llvm", params + ) + + builder = LocalBuilder(f_build=relay_build_with_tensorrt) + else: + + def relay_build_without_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + # @Sung: Weird. Cannot pass keyword arg + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + + builder = LocalBuilder(f_build=relay_build_without_tensorrt) + + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + + (builder_result,) = builder.build([builder_input]) + assert builder_result.error_msg is None + assert builder_result.artifact_path is not None + + # Run + evaluator_config = EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + + runner_input = RunnerInput( + builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] + ) + + def eval_func(rt_mod, device, evaluator_config, repeated_args): + rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) + + eval = rt_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = eval(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + runner = LocalRunner( + evaluator_config=evaluator_config, + f_run_evaluator=eval_func, + ) + + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.run_secs is not None + assert runner_result.error_msg is None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + + mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) + verify_meta_schedule_with_tensorrt( + mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" + ) + + +# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From e8e0b5c812c3f3d68016bd3596bf6cf7004b58be Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 12:57:30 -0800 Subject: [PATCH 38/49] add pytest condition to pass CI. rename test name to be consistent. --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ------------------ 1 file changed, 244 deletions(-) delete mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py deleted file mode 100644 index ca38a7d118a7..000000000000 --- a/tests/python/unittest/test_meta_schedule_byoc_trt.py +++ /dev/null @@ -1,244 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" Test Meta Schedule Builder """ - - -import sys -import pytest -import itertools -import tvm -from tvm import relay -from tvm.relay import testing -from tvm.relay.op.contrib import tensorrt -import numpy as np -from typing import List, Tuple - -# from tvm import script -# from tvm._ffi import register_func -# from tvm.runtime import Module -from tvm._ffi import register_func -from tvm.relay.testing.init import Initializer -from tvm.target import Target -from tvm.runtime import Module -from tvm.meta_schedule.arg_info import TensorInfo -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult -from tvm.meta_schedule.runner import ( - EvaluatorConfig, - LocalRunner, - RunnerInput, -) - -from tvm.tir import FloatImm -from tvm.meta_schedule.testing import get_network - -has_tensorrt_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" -) -has_tensorrt_runtime = pytest.mark.skipif( - not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" -) - -# conv2d+relu network -def get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, -): - - data = relay.var("data", relay.TensorType(data_shape, dtype)) - weight = relay.var("weight") - - net = relay.nn.conv2d( - data=data, - weight=weight, # conv kernel - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - channels=out_channels, - kernel_size=kernel_size, - data_layout=data_layout, - kernel_layout=kernel_layout, - ) - net = relay.add(net, net) - net = relay.nn.relu(net) - - inputs = relay.analysis.free_vars(net) - return relay.Function(inputs, net) - - -def verify_meta_schedule_with_tensorrt( - mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" -): - if use_meta_sched: - # With meta_schedule - dev = "nvidia/geforce-rtx-2080" - - # Build - if use_trt: - - def relay_build_with_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt - - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - return tvm.relay.build_module._build_module_no_factory( - mod, "cuda", "llvm", params - ) - - builder = LocalBuilder(f_build=relay_build_with_tensorrt) - else: - - def relay_build_without_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - # @Sung: Weird. Cannot pass keyword arg - return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) - - builder = LocalBuilder(f_build=relay_build_without_tensorrt) - - builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) - - (builder_result,) = builder.build([builder_input]) - assert builder_result.error_msg is None - assert builder_result.artifact_path is not None - - # Run - evaluator_config = EvaluatorConfig( - number=5, - repeat=2, - min_repeat_ms=0, - enable_cpu_cache_flush=False, - ) - - runner_input = RunnerInput( - builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] - ) - - def eval_func(rt_mod, device, evaluator_config, repeated_args): - rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) - - eval = rt_mod.module.time_evaluator( - func_name="run", - dev=device, - number=evaluator_config.number, - repeat=evaluator_config.repeat, - min_repeat_ms=evaluator_config.min_repeat_ms, - f_preproc="cache_flush_cpu_non_first_arg" - if evaluator_config.enable_cpu_cache_flush - else "", - ) - repeated_costs: List[List[float]] = [] - for args in repeated_args: - profile_result = eval(*args) - repeated_costs.append(profile_result.results) - - costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] - return costs - - runner = LocalRunner( - evaluator_config=evaluator_config, - f_run_evaluator=eval_func, - ) - - # Run the module - (runner_future,) = runner.run([runner_input]) - runner_result = runner_future.result() - assert runner_result is not None - assert runner_result.run_secs is not None - assert runner_result.error_msg is None - - for result in runner_result.run_secs: - if isinstance(result, FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - - else: - # Without meta_schedule - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params - ).evaluate() - - -def test_conv2d_relu(): - data_shape = (1, 1280, 14, 14) - out_channels = 256 - kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 - data_layout, kernel_layout = "NCHW", "OIHW" - dtype = "float32" - - f = get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, - ) - - mod, params = testing.create_workload(f) - verify_meta_schedule_with_tensorrt(mod, params, data_shape) - - -@pytest.mark.parametrize( - "model_name", - ["resnet-50", "mobilenet"], -) -@pytest.mark.parametrize("batch_size", [1, 8, 16]) -@pytest.mark.parametrize("use_meta_sched", [True]) -@pytest.mark.parametrize("use_trt", [True, False]) -def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): - - mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) - verify_meta_schedule_with_tensorrt( - mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" - ) - - -# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) -if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) From 4bad1534bf0cc04352a9933cac4a90700de0d4bf Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 16:54:20 -0800 Subject: [PATCH 39/49] Rebase to pass CI and reflect suggestions --- python/tvm/meta_schedule/testing/__init__.py | 1 + python/tvm/meta_schedule/testing/byoc_trt.py | 17 +++++++++++ .../test_meta_schedule_byoc_tensorrt.py | 29 +++---------------- 3 files changed, 22 insertions(+), 25 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/byoc_trt.py diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 7e516a510f66..a5291f7468ff 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -17,3 +17,4 @@ """Testing utilities in meta schedule""" from .local_rpc import LocalRPC from .relay_workload import get_network +from .byoc_trt import relay_build_with_tensorrt diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py new file mode 100644 index 000000000000..bcd021aa2528 --- /dev/null +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -0,0 +1,17 @@ +import tvm +from tvm.runtime import Module +from tvm.target import Target +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from typing import List + + +def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, +) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 24b6094af97c..8fff8f0f95bf 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -25,13 +25,8 @@ from tvm.relay import testing from tvm.relay.op.contrib import tensorrt import numpy as np -from typing import List, Tuple - -# from tvm import script -# from tvm._ffi import register_func -# from tvm.runtime import Module +from typing import List from tvm._ffi import register_func -from tvm.relay.testing.init import Initializer from tvm.target import Target from tvm.runtime import Module from tvm.meta_schedule.arg_info import TensorInfo @@ -94,25 +89,11 @@ def verify_meta_schedule_with_tensorrt( ): if use_meta_sched: # With meta_schedule - dev = "nvidia/geforce-rtx-2080" + dev = "cuda" # Build if use_trt: - - def relay_build_with_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt - - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - return tvm.relay.build_module._build_module_no_factory( - mod, "cuda", "llvm", params - ) + from tvm.meta_schedule.testing import relay_build_with_tensorrt builder = LocalBuilder(f_build=relay_build_with_tensorrt) else: @@ -122,7 +103,6 @@ def relay_build_without_tensorrt( target: Target, params: dict, ) -> List[BuilderResult]: - # @Sung: Weird. Cannot pass keyword arg return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) builder = LocalBuilder(f_build=relay_build_without_tensorrt) @@ -235,7 +215,7 @@ def test_conv2d_relu(): "model_name", ["resnet-50", "mobilenet"], ) -@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("use_meta_sched", [True]) @pytest.mark.parametrize("use_trt", [True, False]) def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): @@ -246,6 +226,5 @@ def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use ) -# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From e38b24dfa1be32793adc528905f2b70b541e5ddb Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 17:00:34 -0800 Subject: [PATCH 40/49] Add ASF header for the new file --- python/tvm/meta_schedule/testing/byoc_trt.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py index bcd021aa2528..7bceeec312a2 100644 --- a/python/tvm/meta_schedule/testing/byoc_trt.py +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import tvm from tvm.runtime import Module from tvm.target import Target From a269c8c875ee152d75b37ad616d9cc07cc3850e6 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 24 Jan 2022 15:50:40 -0800 Subject: [PATCH 41/49] [MetatSchedule] testcase for TensorRT builder/runner --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py new file mode 100644 index 000000000000..ca38a7d118a7 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc_trt.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ + + +import sys +import pytest +import itertools +import tvm +from tvm import relay +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +import numpy as np +from typing import List, Tuple + +# from tvm import script +# from tvm._ffi import register_func +# from tvm.runtime import Module +from tvm._ffi import register_func +from tvm.relay.testing.init import Initializer +from tvm.target import Target +from tvm.runtime import Module +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) + +from tvm.tir import FloatImm +from tvm.meta_schedule.testing import get_network + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + + # Build + if use_trt: + + def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + return tvm.relay.build_module._build_module_no_factory( + mod, "cuda", "llvm", params + ) + + builder = LocalBuilder(f_build=relay_build_with_tensorrt) + else: + + def relay_build_without_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + # @Sung: Weird. Cannot pass keyword arg + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + + builder = LocalBuilder(f_build=relay_build_without_tensorrt) + + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + + (builder_result,) = builder.build([builder_input]) + assert builder_result.error_msg is None + assert builder_result.artifact_path is not None + + # Run + evaluator_config = EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + + runner_input = RunnerInput( + builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] + ) + + def eval_func(rt_mod, device, evaluator_config, repeated_args): + rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) + + eval = rt_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = eval(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + runner = LocalRunner( + evaluator_config=evaluator_config, + f_run_evaluator=eval_func, + ) + + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.run_secs is not None + assert runner_result.error_msg is None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + + mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) + verify_meta_schedule_with_tensorrt( + mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" + ) + + +# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 8b87320e1b529db0eb5e7585367364a75b9899e7 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 12:57:30 -0800 Subject: [PATCH 42/49] add pytest condition to pass CI. rename test name to be consistent. --- ...hedule_byoc_trt.py => test_meta_schedule_byoc_tensorrt.py} | 4 ++++ 1 file changed, 4 insertions(+) rename tests/python/unittest/{test_meta_schedule_byoc_trt.py => test_meta_schedule_byoc_tensorrt.py} (99%) diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py similarity index 99% rename from tests/python/unittest/test_meta_schedule_byoc_trt.py rename to tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index ca38a7d118a7..9204e151344a 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_trt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -52,6 +52,7 @@ not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" ) + # conv2d+relu network def get_conv2d_relu( data_shape, @@ -224,6 +225,9 @@ def test_conv2d_relu(): verify_meta_schedule_with_tensorrt(mod, params, data_shape) +@tvm.testing.requires_cuda +@has_tensorrt_codegen +@has_tensorrt_runtime @pytest.mark.parametrize( "model_name", ["resnet-50", "mobilenet"], From 46325d0e02619b8133d3e393b3e4bf5218703867 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 19:25:28 -0800 Subject: [PATCH 43/49] add pyteset decorator to pass CI --- tests/python/unittest/test_meta_schedule_byoc_tensorrt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 9204e151344a..24b6094af97c 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -201,6 +201,9 @@ def eval_func(rt_mod, device, evaluator_config, repeated_args): ).evaluate() +@tvm.testing.requires_cuda +@has_tensorrt_codegen +@has_tensorrt_runtime def test_conv2d_relu(): data_shape = (1, 1280, 14, 14) out_channels = 256 From ca5229298691ce67b19dd2714fdc692a21658f88 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Mon, 24 Jan 2022 15:50:40 -0800 Subject: [PATCH 44/49] [MetatSchedule] testcase for TensorRT builder/runner --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py new file mode 100644 index 000000000000..ca38a7d118a7 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc_trt.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ + + +import sys +import pytest +import itertools +import tvm +from tvm import relay +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +import numpy as np +from typing import List, Tuple + +# from tvm import script +# from tvm._ffi import register_func +# from tvm.runtime import Module +from tvm._ffi import register_func +from tvm.relay.testing.init import Initializer +from tvm.target import Target +from tvm.runtime import Module +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) + +from tvm.tir import FloatImm +from tvm.meta_schedule.testing import get_network + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + + # Build + if use_trt: + + def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + return tvm.relay.build_module._build_module_no_factory( + mod, "cuda", "llvm", params + ) + + builder = LocalBuilder(f_build=relay_build_with_tensorrt) + else: + + def relay_build_without_tensorrt( + mod: Module, + target: Target, + params: dict, + ) -> List[BuilderResult]: + # @Sung: Weird. Cannot pass keyword arg + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + + builder = LocalBuilder(f_build=relay_build_without_tensorrt) + + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + + (builder_result,) = builder.build([builder_input]) + assert builder_result.error_msg is None + assert builder_result.artifact_path is not None + + # Run + evaluator_config = EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + + runner_input = RunnerInput( + builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] + ) + + def eval_func(rt_mod, device, evaluator_config, repeated_args): + rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) + + eval = rt_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = eval(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + runner = LocalRunner( + evaluator_config=evaluator_config, + f_run_evaluator=eval_func, + ) + + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.run_secs is not None + assert runner_result.error_msg is None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + + mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) + verify_meta_schedule_with_tensorrt( + mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" + ) + + +# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 419d756f694688256f6e8cf5d5689716b84cf9d3 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Wed, 26 Jan 2022 12:57:30 -0800 Subject: [PATCH 45/49] add pytest condition to pass CI. rename test name to be consistent. --- .../unittest/test_meta_schedule_byoc_trt.py | 244 ------------------ 1 file changed, 244 deletions(-) delete mode 100644 tests/python/unittest/test_meta_schedule_byoc_trt.py diff --git a/tests/python/unittest/test_meta_schedule_byoc_trt.py b/tests/python/unittest/test_meta_schedule_byoc_trt.py deleted file mode 100644 index ca38a7d118a7..000000000000 --- a/tests/python/unittest/test_meta_schedule_byoc_trt.py +++ /dev/null @@ -1,244 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" Test Meta Schedule Builder """ - - -import sys -import pytest -import itertools -import tvm -from tvm import relay -from tvm.relay import testing -from tvm.relay.op.contrib import tensorrt -import numpy as np -from typing import List, Tuple - -# from tvm import script -# from tvm._ffi import register_func -# from tvm.runtime import Module -from tvm._ffi import register_func -from tvm.relay.testing.init import Initializer -from tvm.target import Target -from tvm.runtime import Module -from tvm.meta_schedule.arg_info import TensorInfo -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult -from tvm.meta_schedule.runner import ( - EvaluatorConfig, - LocalRunner, - RunnerInput, -) - -from tvm.tir import FloatImm -from tvm.meta_schedule.testing import get_network - -has_tensorrt_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" -) -has_tensorrt_runtime = pytest.mark.skipif( - not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" -) - -# conv2d+relu network -def get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, -): - - data = relay.var("data", relay.TensorType(data_shape, dtype)) - weight = relay.var("weight") - - net = relay.nn.conv2d( - data=data, - weight=weight, # conv kernel - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - channels=out_channels, - kernel_size=kernel_size, - data_layout=data_layout, - kernel_layout=kernel_layout, - ) - net = relay.add(net, net) - net = relay.nn.relu(net) - - inputs = relay.analysis.free_vars(net) - return relay.Function(inputs, net) - - -def verify_meta_schedule_with_tensorrt( - mod, params, data_shape, use_meta_sched: bool = True, use_trt: bool = True, mode: str = "vm" -): - if use_meta_sched: - # With meta_schedule - dev = "nvidia/geforce-rtx-2080" - - # Build - if use_trt: - - def relay_build_with_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt - - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - return tvm.relay.build_module._build_module_no_factory( - mod, "cuda", "llvm", params - ) - - builder = LocalBuilder(f_build=relay_build_with_tensorrt) - else: - - def relay_build_without_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - # @Sung: Weird. Cannot pass keyword arg - return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) - - builder = LocalBuilder(f_build=relay_build_without_tensorrt) - - builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) - - (builder_result,) = builder.build([builder_input]) - assert builder_result.error_msg is None - assert builder_result.artifact_path is not None - - # Run - evaluator_config = EvaluatorConfig( - number=5, - repeat=2, - min_repeat_ms=0, - enable_cpu_cache_flush=False, - ) - - runner_input = RunnerInput( - builder_result.artifact_path, "cuda", [TensorInfo("float32", data_shape)] - ) - - def eval_func(rt_mod, device, evaluator_config, repeated_args): - rt_mod = tvm.contrib.graph_executor.GraphModule(rt_mod["default"](device)) - - eval = rt_mod.module.time_evaluator( - func_name="run", - dev=device, - number=evaluator_config.number, - repeat=evaluator_config.repeat, - min_repeat_ms=evaluator_config.min_repeat_ms, - f_preproc="cache_flush_cpu_non_first_arg" - if evaluator_config.enable_cpu_cache_flush - else "", - ) - repeated_costs: List[List[float]] = [] - for args in repeated_args: - profile_result = eval(*args) - repeated_costs.append(profile_result.results) - - costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] - return costs - - runner = LocalRunner( - evaluator_config=evaluator_config, - f_run_evaluator=eval_func, - ) - - # Run the module - (runner_future,) = runner.run([runner_input]) - runner_result = runner_future.result() - assert runner_result is not None - assert runner_result.run_secs is not None - assert runner_result.error_msg is None - - for result in runner_result.run_secs: - if isinstance(result, FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - - else: - # Without meta_schedule - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params - ).evaluate() - - -def test_conv2d_relu(): - data_shape = (1, 1280, 14, 14) - out_channels = 256 - kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 - data_layout, kernel_layout = "NCHW", "OIHW" - dtype = "float32" - - f = get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, - ) - - mod, params = testing.create_workload(f) - verify_meta_schedule_with_tensorrt(mod, params, data_shape) - - -@pytest.mark.parametrize( - "model_name", - ["resnet-50", "mobilenet"], -) -@pytest.mark.parametrize("batch_size", [1, 8, 16]) -@pytest.mark.parametrize("use_meta_sched", [True]) -@pytest.mark.parametrize("use_trt", [True, False]) -def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): - - mod, params, input_shape, output_shape = get_network(name=model_name, batch_size=batch_size) - verify_meta_schedule_with_tensorrt( - mod, params, input_shape, use_meta_sched=use_meta_sched, use_trt=use_trt, mode="vm" - ) - - -# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) -if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) From c45d16afe160534c1f843c29c6174b2594e31daf Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 16:54:20 -0800 Subject: [PATCH 46/49] Rebase to pass CI and reflect suggestions --- python/tvm/meta_schedule/testing/__init__.py | 1 + python/tvm/meta_schedule/testing/byoc_trt.py | 17 +++++++++++ .../test_meta_schedule_byoc_tensorrt.py | 29 +++---------------- 3 files changed, 22 insertions(+), 25 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/byoc_trt.py diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 7e516a510f66..a5291f7468ff 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -17,3 +17,4 @@ """Testing utilities in meta schedule""" from .local_rpc import LocalRPC from .relay_workload import get_network +from .byoc_trt import relay_build_with_tensorrt diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py new file mode 100644 index 000000000000..bcd021aa2528 --- /dev/null +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -0,0 +1,17 @@ +import tvm +from tvm.runtime import Module +from tvm.target import Target +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from typing import List + + +def relay_build_with_tensorrt( + mod: Module, + target: Target, + params: dict, +) -> List[BuilderResult]: + from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + + mod, config = partition_for_tensorrt(mod, params) + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): + return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 24b6094af97c..8fff8f0f95bf 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -25,13 +25,8 @@ from tvm.relay import testing from tvm.relay.op.contrib import tensorrt import numpy as np -from typing import List, Tuple - -# from tvm import script -# from tvm._ffi import register_func -# from tvm.runtime import Module +from typing import List from tvm._ffi import register_func -from tvm.relay.testing.init import Initializer from tvm.target import Target from tvm.runtime import Module from tvm.meta_schedule.arg_info import TensorInfo @@ -94,25 +89,11 @@ def verify_meta_schedule_with_tensorrt( ): if use_meta_sched: # With meta_schedule - dev = "nvidia/geforce-rtx-2080" + dev = "cuda" # Build if use_trt: - - def relay_build_with_tensorrt( - mod: Module, - target: Target, - params: dict, - ) -> List[BuilderResult]: - from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt - - mod, config = partition_for_tensorrt(mod, params) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - return tvm.relay.build_module._build_module_no_factory( - mod, "cuda", "llvm", params - ) + from tvm.meta_schedule.testing import relay_build_with_tensorrt builder = LocalBuilder(f_build=relay_build_with_tensorrt) else: @@ -122,7 +103,6 @@ def relay_build_without_tensorrt( target: Target, params: dict, ) -> List[BuilderResult]: - # @Sung: Weird. Cannot pass keyword arg return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) builder = LocalBuilder(f_build=relay_build_without_tensorrt) @@ -235,7 +215,7 @@ def test_conv2d_relu(): "model_name", ["resnet-50", "mobilenet"], ) -@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("batch_size", [1]) @pytest.mark.parametrize("use_meta_sched", [True]) @pytest.mark.parametrize("use_trt", [True, False]) def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): @@ -246,6 +226,5 @@ def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use ) -# @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True) if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 68fd6950de2bc31ccfb302f03393e7066ce3ee3a Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 17:00:34 -0800 Subject: [PATCH 47/49] Add ASF header for the new file --- python/tvm/meta_schedule/testing/byoc_trt.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py index bcd021aa2528..7bceeec312a2 100644 --- a/python/tvm/meta_schedule/testing/byoc_trt.py +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import tvm from tvm.runtime import Module from tvm.target import Target From 2390740093b30612e351aaedae82212f1872757a Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 17:30:40 -0800 Subject: [PATCH 48/49] add pylint, docstring --- python/tvm/meta_schedule/testing/byoc_trt.py | 24 +++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py index 7bceeec312a2..6ed9f6414b97 100644 --- a/python/tvm/meta_schedule/testing/byoc_trt.py +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""TensorRT-MetaSchedule integration""" +# pylint: disable=import-outside-toplevel import tvm from tvm.runtime import Module -from tvm.target import Target -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder, BuilderResult +from tvm.meta_schedule.builder import BuilderResult from typing import List +from tvm.target import Target def relay_build_with_tensorrt( @@ -27,8 +29,24 @@ def relay_build_with_tensorrt( target: Target, params: dict, ) -> List[BuilderResult]: + """Build a Relay IRModule with TensorRT BYOC + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + target : Target + The target to build the module for. + params : Dict[str, NDArray] + The parameter dict to build the module with. + Returns + ------- + mod : runtime.Module + The built module. + """ from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt mod, config = partition_for_tensorrt(mod, params) with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): - return tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) + assert isinstance(result, Module) + return result From 04cb4ca179a14310ff582b855071aa7a4d9de5d7 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 28 Jan 2022 23:00:52 -0800 Subject: [PATCH 49/49] fix lint and wish for the best --- python/tvm/meta_schedule/testing/byoc_trt.py | 3 ++- tests/python/unittest/test_meta_schedule_byoc_tensorrt.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py index 6ed9f6414b97..d459518cdb23 100644 --- a/python/tvm/meta_schedule/testing/byoc_trt.py +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -17,10 +17,10 @@ """TensorRT-MetaSchedule integration""" # pylint: disable=import-outside-toplevel +from typing import List import tvm from tvm.runtime import Module from tvm.meta_schedule.builder import BuilderResult -from typing import List from tvm.target import Target @@ -45,6 +45,7 @@ def relay_build_with_tensorrt( """ from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt + assert isinstance(target, Target) mod, config = partition_for_tensorrt(mod, params) with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): result = tvm.relay.build_module._build_module_no_factory(mod, "cuda", "llvm", params) diff --git a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py index 8fff8f0f95bf..3b4164c40644 100644 --- a/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py +++ b/tests/python/unittest/test_meta_schedule_byoc_tensorrt.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. """ Test Meta Schedule Builder """ - - import sys import pytest import itertools