From a3cf22d8d7c1d1521a7765902bec8744e7df0d2b Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Fri, 24 Jun 2022 13:28:28 +0400 Subject: [PATCH 01/15] [Adreno] Add static texture markup relay pass Co-authored-by: Chris Sullivan --- include/tvm/relay/transform.h | 5 + python/tvm/topi/adreno/conv2d_nchw.py | 19 +- python/tvm/topi/adreno/conv2d_nhwc.py | 16 +- src/relay/backend/build_module.cc | 13 + .../transforms/annotate_texture_storage.cc | 518 ++++++++++++++ .../python/relay/test_conv2d_nchw_texture.py | 640 +++++++++++++++++- .../python/relay/test_conv2d_nhwc_texture.py | 7 +- .../test_depthwise_conv2d_nchw_texture.py | 4 +- .../test_depthwise_conv2d_nhwc_texture.py | 2 +- tests/python/relay/utils/adreno_utils.py | 15 + 10 files changed, 1220 insertions(+), 19 deletions(-) create mode 100644 src/relay/transforms/annotate_texture_storage.cc diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 042ad1ef02da..4288f2ec48ee 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -580,6 +580,11 @@ TVM_DLL Pass AnnotateUsedMemory(); */ TVM_DLL Pass CapturePostDfsIndexInSpans(); + /*! + * \brief Calls device dependent memory scope analysis pass, collects mapping of desirable + * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope + */ +TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config); } // namespace transform /*! diff --git a/python/tvm/topi/adreno/conv2d_nchw.py b/python/tvm/topi/adreno/conv2d_nchw.py index 2a8f6028b755..16ecaa84d040 100644 --- a/python/tvm/topi/adreno/conv2d_nchw.py +++ b/python/tvm/topi/adreno/conv2d_nchw.py @@ -29,6 +29,7 @@ add_pad, bind_data_copy, get_default_conv2d_config, + get_texture_storage, ) @@ -214,8 +215,11 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output): 5d tensors 4. pad should be scheduled separately to create independent opencl kernel. If pad is inlined into convolution, this gives 1.5x performance drop - 5. We are using cache_read to produce texture and guarantee the best performance - on the next stage. + 5. We are using cache_read for intermediate tensors to produce texture and guarantee + the best performance on the next stage. + The weights are managed through static texture planning mechanism and guarantied come + in texture memory scope. + Thus way we are calling cache_read only for data tensor 6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize for textures For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion @@ -288,10 +292,15 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output): s[output].compute_inline() # create cache stage - AT = s.cache_read(pad_data, "global.texture", [conv]) + AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) bind_data_copy(s[AT]) - WT = s.cache_read(kernel, "global.texture-weight", [conv]) - bind_data_copy(s[WT]) + if ( + autotvm.GLOBAL_SCOPE.in_tuning + or isinstance(kernel.op, tvm.te.ComputeOp) + and "filter_pack" in kernel.op.tag + ): + WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv]) + bind_data_copy(s[WT]) # tile and bind spatial axes n, fc, y, x, fb = s[latest_blocked].op.axis diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py b/python/tvm/topi/adreno/conv2d_nhwc.py index 388f606ecb54..ce7bf0ccc956 100644 --- a/python/tvm/topi/adreno/conv2d_nhwc.py +++ b/python/tvm/topi/adreno/conv2d_nhwc.py @@ -210,8 +210,11 @@ def schedule_conv2d_NHWC(cfg, s, output): 5d tensors 4. pad should be scheduled separately to create independent opencl kernel. If pad is inlined into convolution, this gives 1.5x performance drop - 5. We are using cache_read to produce texture and guarantee the best performance - on the next stage. + 5. We are using cache_read for intermediate tensors to produce texture and guarantee + the best performance on the next stage. + The weights are managed through static texture planning mechanism and guarantied come + in texture memory scope. + Thus way we are calling cache_read only for data tensor 6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize for textures For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion @@ -287,8 +290,13 @@ def schedule_conv2d_NHWC(cfg, s, output): # create cache stage AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) bind_data_copy(s[AT]) - WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv]) - bind_data_copy(s[WT]) + if ( + autotvm.GLOBAL_SCOPE.in_tuning + or isinstance(kernel.op, tvm.te.ComputeOp) + and "filter_pack" in kernel.op.tag + ): + WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv]) + bind_data_copy(s[WT]) # tile and bind spatial axes n, y, x, fc, fb = s[latest_blocked].op.axis diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 39f2e7761a42..edd401b1c22e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -397,6 +397,19 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module = transform::InferType()(relay_module); relay_module = transform::LabelOps()(relay_module); + relay_module = transform::AnnotateMemoryScope(config_)(relay_module); + pass_seqs = GetPassPrefix( + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); + pass_seqs.push_back(transform::PlanDevices(config_)); + // Create a sequential pass and perform optimizations. + seq = transform::Sequential(pass_seqs); + if (config_->optional_homogeneous_target.defined()) { + With tctx(config_->optional_homogeneous_target); + relay_module = seq(relay_module); + } else { + relay_module = seq(relay_module); + } + relay_module = transform::InferType()(relay_module); ICHECK(relay_module.defined()); return relay_module; diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc new file mode 100644 index 000000000000..d9d3599234b8 --- /dev/null +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file annotate_texture_storage.cc + * \brief Collection of target specific relay passes which + * storage scope related information. + * + * - CollectStorageInfo returns a mapping from relay expr + * to a list of output storage scopes for each output. + * These scopes are used during memory planning as well + * as downstream when doing codegen and in the graph runtime when doing runtime dataspace + * allocations. + * + * - AnnotateMemoryScope calls *target.CollectStorageInfo for all target been represented + * in the graph and rewrites graph modifying or inserting of VirtualDevice with required + * memory_scop collected from the CollectStorageInfo + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "../transforms/device_aware_visitors.h" + +namespace tvm { +namespace relay { +namespace { + +/** + * @brief Analyzes the graph and returns mapping of expressions vs desired memory scope + */ +class StorageInfo : private transform::DeviceAwareExprVisitor { + public: + StorageInfo() : transform::DeviceAwareExprVisitor(Optional()) {} + + static Map> GetStorageMap(const Expr& expr) { + StorageInfo storage_info; + storage_info.VisitExpr(expr); + storage_info.LegalizeProducerStorage(); + Map> storage_map; + for (auto& kv : storage_info.storage_scope_) { + std::vector storage_scopes; + std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes)); + storage_map.Set(GetRef(kv.first), Array{storage_scopes}); + } + + // Filling the input arguments by "global" scope to handle PlanDevice algo which propagates + // virtual devices from outputs to inputs. At the same time outputs must be unconstrained + // to avoid useless device_copy + for (const auto& cs: storage_info.consumer_storage_scopes_) { + // we have record in consumers that mean that potentially consumer + // dealt with textures anyhow, it's safe to mark this expr as global scope + // even without verification of the consumer's outputs scope + if (storage_info.CanConsumeTextures(cs.second) && + storage_map.find(GetRef(cs.first)) == storage_map.end()) { + storage_map.Set(GetRef(cs.first), Array{"global"}); + } + } + + // initial algo assumes mapping of outputs of the expr that is not enough, need to update + // VirtualDevice for function variables to get proper codegen. Adding vars to storage_map + for (const auto& a : storage_info.args_to_vars_) { + if (storage_map.count(a.first)) { + for (const auto& v : a.second) { + storage_map.Set(v, storage_map[a.first]); + } + } + } + return storage_map; + } + + private: + void Visit(const Expr& expr) { + // Pre-order traversal to enable upward propagation + // of consumer storage scopes to producers when desirable. + if (const auto* fn = expr.as()) { + this->VisitExpr(fn->body); + for (const auto& param : fn->params) { + this->VisitExpr(param); + } + } else { + this->VisitExpr(expr); + } + } + + void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); } + + void VisitExpr_(const ConstantNode* cn) final { + ApplyConsumerScopeToInputs(cn); + } + + void DeviceAwareVisitExpr_(const CallNode* call) final { + // Check the contents of this primitive function + if (DeviceSupportsTextureStorage(GetRef(call))) { + if (const auto* fn = call->op.as()) { + if (fn->HasNonzeroAttr(attr::kPrimitive)) { + primitive_supports_texture_ = false; + Visit(call->op); + if (primitive_supports_texture_) { + if (call->checked_type().as()) { + std::string scope = "global.texture"; + if (const auto* ttype = call->checked_type().as()) { + if (ttype->shape.size() == 5) { + scope = Scope(ttype->shape); + } + } + storage_scope_[call].push_back(scope); + } else { + const auto* tuple_type = call->type_as(); + ICHECK(tuple_type); + // TODO(csullivan): Add support for mixed output storage scope. + // In current adreno storage planner all outputs of a + // primitive function are assumed to be of the same storage + // type. This should be easy to extend in the future. + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + storage_scope_[call].push_back("global.texture"); + } + } + for (size_t i = 0; i < fn->params.size(); i++) { + args_to_vars_[call->args[i]].push_back(fn->params[i]); + } + } + // Add consumer storage scope information for call arguments + for (auto& arg : call->args) { + if (storage_scope_.count(call)) { + ICHECK(!HasMixedStorageOutputs(call)) + << "Mixed output storage scopes are not currently supported"; + consumer_storage_scopes_[arg.operator->()].push_back("global.texture"); + } else { + consumer_storage_scopes_[arg.operator->()].push_back("global"); + } + } + } + } + } + + primitive_supports_texture_ = SupportsTextureStorage(call); + + for (auto& arg : call->args) { + Visit(arg); + } + // We have all callees filled into storage_scope_ if they support textures + // We need to verify if this call expects texture and if it does not, remove from + // storage_scope_ since initially storage_scope_ is filled only based on knowledge + // that function able to work with textures, but not necessary that this texture is + // expected by function callee + for (auto& arg : call->args) { + if (consumer_storage_scopes_.count(arg.operator->()) && + GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") { + storage_scope_.erase(arg.operator->()); + if (const auto* cn = arg.as() ) { + if (const auto* fn = cn->op.as()) { + storage_scope_.erase(fn->body.operator->()); + } + } + } + } + } + + std::string Scope(Array shape) { + std::map diffs; + int limit = 16384; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d3l = a0 * a1 * a2; + int d3r = a3; + int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l; + if (d3l < limit && d3r < limit) diffs[diff3] = ""; + + int d2l = a0 * a1; + int d2r = a2 * a3; + int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l; + if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc"; + + int d1l = a0; + int d1r = a1 * a2 * a3; + int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l; + if (d1l < limit && d1r < limit) diffs[diff1] = "weight"; + if (!diffs.empty()) { + std::string scope = "global.texture"; + if (!diffs.begin()->second.empty()) { + scope += ("-" + diffs.begin()->second); + } + return scope; + } else { + return "global.texture"; + } + } + + void ApplyConsumerScopeToInputs(const ExprNode* expr) { + std::string scope; + auto consumer_scopes_it = consumer_storage_scopes_.find(expr); + if (consumer_scopes_it != consumer_storage_scopes_.end()) { + std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second); + ICHECK(!storage_scope_.count(expr)) + << "Already propagated consumer scopes to input: " << GetRef(expr); + + bool expr_is_rgba_vectorizable = false; + if (const auto* ttype = expr->checked_type().as()) { + if (ttype->shape.size() == 5) { + scope = Scope(ttype->shape); + if (scope != "global") { + auto inner_dim = ttype->shape.back().as(); + if (inner_dim && inner_dim->value == 4) { + expr_is_rgba_vectorizable = true; + } + } + } + } + + // Only propagate texture scope from consumers to input expr if + // the input shape of the input expr is rgba vectorizable. + if (consumer_scope.find("global.texture") != std::string::npos) { + if (expr_is_rgba_vectorizable) { + storage_scope_[expr].push_back(scope); + } + } else { + storage_scope_[expr].push_back(consumer_scope); + } + } + } + + void LegalizeProducerStorage() { + for (auto& kv : consumer_storage_scopes_) { + const ExprNode* producer = kv.first; + std::string legal_scope = GetConsumerScope(kv.second); + if (storage_scope_.count(producer)) { + ICHECK(!HasMixedStorageOutputs(producer)) + << "Mixed output storage scopes are not currently supported"; + if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) { + for (size_t i = 0; i < storage_scope_[producer].size(); i++) { + // Only support uniform storage scope across all outputs for now + storage_scope_[producer][i] = legal_scope; + } + } + } + } + } + + bool DeviceSupportsTextureStorage(const Expr& expr) { + auto vd = GetVirtualDevice(expr); + if (vd != VirtualDevice::FullyUnconstrained()) { + if (Optional t_device = vd->target->GetAttr("device")) { + if (vd->target->kind->device_type == kDLOpenCL && t_device.defined()) { + if (t_device.value() == "adreno") { + return true; + } + } + } + } + return false; + } + + std::string GetConsumerScope(const std::vector& consumer_scopes) const { + if (!consumer_scopes.size()) { + return "global"; + } + std::string texture_tag = "global.texture"; + for (auto& consumer_scope : consumer_scopes) { + if (consumer_scope.find(texture_tag) == std::string::npos) { + return "global"; + } + } + return texture_tag; + } + + bool CanConsumeTextures(const std::vector& consumer_scopes) const { + if (!consumer_scopes.size()) { + return false; + } + std::string texture_tag = "global.texture"; + for (auto& consumer_scope : consumer_scopes) { + if (consumer_scope.find(texture_tag) == 0) { + return true; + } + } + return false; + } + + bool HasMixedStorageOutputs(const ExprNode* expr) { + if (storage_scope_.count(expr)) { + std::string ref_scope = storage_scope_[expr][0]; + for (std::string& scope : storage_scope_[expr]) { + if (scope != ref_scope) { + return true; + } + } + } + return false; + } + + bool SupportsTextureStorage(const CallNode* call) const { + bool supports_texture_storage = false; + if (auto attrs = call->attrs.as()) { + if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "OIHW4o") { + supports_texture_storage = true; + } else if (attrs->data_layout == "NHWC4c" && + (attrs->kernel_layout == "HWOI4o" || attrs->kernel_layout == "HWIO4o" || + attrs->kernel_layout == "OIHW4o")) { + supports_texture_storage = true; + } + } else if (auto attrs = call->attrs.as()) { + if (attrs->layout == "NCHW4c") { + supports_texture_storage = true; + } + } else if (auto attrs = call->attrs.as()) { + if (attrs->layout == "NCHW4c") { + supports_texture_storage = true; + } + } else if (auto attrs = call->attrs.as()) { + if (attrs->layout == "NCHW4c") { + supports_texture_storage = true; + } + } + + return supports_texture_storage; + } + + /*! \brief Temporary state for marking whether a visited function + * primitive supports texture storage scope */ + bool primitive_supports_texture_ = false; + /*! \brief expr storage scope mapping for each output */ + std::unordered_map> storage_scope_; + /*! \brief output storage scopes used by consumers of expr key */ + std::unordered_map> consumer_storage_scopes_; + /*! \brief mapping of arguments to call to function variables*/ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> args_to_vars_; +}; + +} // namespace + +/** + * @brief rewrite of virtual devices, memory_scope part for expressions defined + * by the StorageInfo analysis pass + * + * Currently this workflow supports analysis and rewriting of VirtualDevice for + * Constants and function Variables + */ +class VDRewriter : public transform::DeviceAwareExprMutator { + using VarMap = std::unordered_map; + + public: + explicit VDRewriter(const Map>& storage_scope) + : transform::DeviceAwareExprMutator(Optional()), storage_scope_(storage_scope) {} + + Function Rewrite(const Expr& expr) { return Downcast(Mutate(expr)); } + + Expr VisitExpr_(const VarNode* vn) final { + if (storage_scope_.find(GetRef(vn)) != storage_scope_.end() && + storage_scope_[GetRef(vn)][0] != "global") { + Var c = Var(vn->vid, vn->type_annotation, vn->span); + auto virtual_device = GetVirtualDevice(GetRef(vn)); + c->virtual_device_ = + VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id, + virtual_device->target, storage_scope_[GetRef(vn)][0]); + return c; + } + return GetRef(vn); + } + + Expr VisitExpr_(const ConstantNode* vn) final { + if (storage_scope_.find(GetRef(vn)) != storage_scope_.end()) { + Expr c = Constant(vn->data, vn->span); + auto virtual_device = GetVirtualDevice(GetRef(vn)); + c = OnDevice(c, + VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id, + virtual_device->target, storage_scope_[GetRef(vn)][0]), + true); + return c; + } + return GetRef(vn); + } + + Expr DeviceAwareVisitExpr_(const CallNode* call_node) final { + auto new_call = transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node); + auto virtual_device = GetVirtualDevice(GetRef(call_node)); + std::string memory_scope = ""; + if (storage_scope_.find(GetRef(call_node)) != storage_scope_.end()) { + memory_scope = storage_scope_[GetRef(call_node)][0]; + } else if (virtual_device->memory_scope != "") { + memory_scope = virtual_device->memory_scope; + } else if (!call_node->op.as()) { + memory_scope = ""; + } + if (!memory_scope.empty()) { + new_call = OnDevice(new_call, + VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id, + virtual_device->target, memory_scope), + true); + } + return new_call; + } + + private: + Map> storage_scope_; + VarMap new_vars_; + Array current_function_scope_; +}; + +Map> CollectTextureStorage(const Expr& expr) { + return StorageInfo::GetStorageMap(expr); +} + +/** + * @brief Collects all target devices participated in graph + */ +class CollectVirtualDevices : public transform::DeviceAwareExprVisitor { + public: + CollectVirtualDevices() : transform::DeviceAwareExprVisitor(Optional()) {} + /** + * @brief Get all unique device elements from target of each VirtualDevice + * + * @param expr - IR + * @return set of devices + */ + std::set GetDevices(const Expr& expr) { + this->Run(expr); + return std::move(devices_); + } + + void Visit(const Expr& expr) { + // Pre-order traversal to enable upward propagation + // of consumer storage scopes to producers when desirable. + if (const auto* fn = expr.as()) { + this->VisitExpr(fn->body); + for (const auto& param : fn->params) { + this->VisitExpr(param); + } + } else { + this->VisitExpr(expr); + } + } + + void DeviceAwareVisitExpr_(const CallNode* call) final { + auto vd = GetVirtualDevice(GetRef(call)); + if (vd != VirtualDevice::FullyUnconstrained()) { + if (Optional t_device = vd->target->GetAttr("device")) { + devices_.insert(vd->target->kind->name + "." + t_device.value()); + } + } + for (auto& arg : call->args) { + Visit(arg); + } + } + + void Run(const Expr& expr) { VisitExpr(expr); } + using transform::DeviceAwareExprVisitor::VisitExpr_; + std::set devices_; +}; + +/*! + * \brief Collect the target specific tensor storage info for each expression's output. + * \param expr The expression. + * \return The device based storage mapping. + */ +Map> CollectStorageInfo(const Expr& expr) { + std::set device_types = CollectVirtualDevices().GetDevices(expr); + // TODO(amalyshe): current approach collects all targets withing graph and call the only + // function corresponding to all these targets in alphabetic order + // this will work reliable only for case of only one device and should be redesigned + // to handle common case + std::string ftarget_prefix = "relay.backend"; + for (auto& dev_id : device_types) { + ftarget_prefix += (std::string(".") + dev_id); + } + + Map> storage_info = {}; + if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) { + storage_info = (*f)(expr); + } + return storage_info; +} + +Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod, CompilationConfig config) { + auto storage_scope = CollectStorageInfo(expr); + return VDRewriter(storage_scope).Rewrite(expr); +} + +namespace transform { +tvm::transform::Pass AnnotateMemoryScope(CompilationConfig config) { + runtime::TypedPackedFunc pass_func = + [config = std::move(config)](Function f, IRModule m, PassContext pc) { + return Downcast(AnnotateMemoryScopeExpr(f, m, config)); + }; + return CreateFunctionPass(pass_func, 2, "AnnotateMemoryScope", {}); +} +} // namespace transform + +TVM_REGISTER_GLOBAL("relay.backend.opencl.adreno._CollectStorageInfo") + .set_body_typed(CollectTextureStorage); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index 2dd88f6118b4..7feb42890c17 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -63,7 +63,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -105,7 +105,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -147,7 +147,7 @@ def test_conv2d_inceptionv3_35_35_strides(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -493,3 +493,637 @@ def test_conv2d_winograd_conv(): ) matches = re.findall("winograd", graph) assert len(matches) > 0 + + +@tvm.testing.requires_opencl +def test_2conv2d(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape1 = (96, 32, 2, 2) + filter_shape2 = (32, 96, 2, 2) + bias_shape1 = (1, 96, 1, 1) + bias_shape2 = (1, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype) + + # C = relay.nn.relu(A) + conv1 = relay.nn.conv2d( + A, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=96, + kernel_size=(2, 2), + ) + D = relay.op.add(conv1, B1) + D = relay.op.nn.relu(D) + + conv2 = relay.nn.conv2d( + D, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + D = relay.op.add(conv2, B2) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, W1, B1, W2, B2], D) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data1 = np.zeros(filter_shape1).astype(dtype) + bias_data1 = np.zeros(bias_shape1).astype(dtype) + initializer("weight", filter_data1) + initializer("bias", bias_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + bias_data2 = np.zeros(bias_shape2).astype(dtype) + initializer("weight", filter_data2) + initializer("bias", bias_data2) + params1 = { + "weight1": tvm.nd.array(filter_data1), + "bias1": tvm.nd.array(bias_data1), + "weight2": tvm.nd.array(filter_data2), + "bias2": tvm.nd.array(bias_data2), + } + + static_memory_scope = [ + "global", + "global", + "global.texture-weight", + "global.texture-weight", + "global.texture-nhwc", + "global.texture-weight", + "global.texture-weight", + "", + "", + ] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + +@tvm.testing.requires_opencl +def test_residual_block(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape1 = (32, 32, 2, 2) + filter_shape2 = (32, 32, 1, 1) + filter_shape3 = (32, 32, 2, 2) + bias_shape1 = (1, 32, 1, 1) + # bias_shape2 = (1, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + # B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype) + W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype) + + conv1 = relay.nn.conv2d( + A, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + D = relay.op.add(conv1, B1) + D = relay.op.nn.relu(D) + + conv2 = relay.nn.conv2d( + D, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + D = relay.op.add(conv2, D) + D = D * relay.const(0.15, "float16") + D = relay.op.nn.relu(D) + + conv3 = relay.nn.conv2d( + D, + W3, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + D = relay.op.nn.relu(conv3) + + mod = relay.Function([A, W1, B1, W2, W3], D) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data1 = np.zeros(filter_shape1).astype(dtype) + bias_data1 = np.zeros(bias_shape1).astype(dtype) + initializer("weight", filter_data1) + initializer("bias", bias_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + # bias_data2 = np.zeros(bias_shape2).astype(dtype) + initializer("weight", filter_data2) + # initializer("bias", bias_data2) + filter_data3 = np.zeros(filter_shape3).astype(dtype) + initializer("weight", filter_data3) + params1 = { + "weight1": tvm.nd.array(filter_data1), + "bias1": tvm.nd.array(bias_data1), + "weight2": tvm.nd.array(filter_data2), + # "bias2": tvm.nd.array(bias_data2), + "weight3": tvm.nd.array(filter_data3), + } + + static_memory_scope = [ + "global", + "global", + "global.texture-weight", + "global.texture-weight", + "global.texture", + "global.texture-weight", + 'global', + "global.texture", + "global.texture-weight", + "", + "" + ] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + + +@tvm.testing.requires_opencl +def test_plan_device_issue1(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape1 = (32, 32, 2, 2) + filter_shape2 = (32, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + + conv1 = relay.nn.conv2d( + A, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv2 = relay.nn.conv2d( + conv1, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + mod = relay.Function([A, W1, W2], conv2) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data1 = np.zeros(filter_shape1).astype(dtype) + initializer("weight", filter_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + initializer("weight", filter_data2) + params1 = { + "weight1": tvm.nd.array(filter_data1), + "weight2": tvm.nd.array(filter_data2), + } + + # static_memory_scope = [ + # "global", + # "global", + # "global.texture-weight", + # "global.texture-weight", + # "global.texture-nhwc", + # "global.texture-weight", + # "global.texture-weight", + # "global", + # "global", + # ] + + static_memory_scope = [] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + +@tvm.testing.requires_opencl +def test_branch_textures(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape1 = (96, 32, 2, 2) + filter_shape2 = (32, 96, 2, 2) + filter_shape3 = (5, 96, 2, 2) + bias_shape1 = (1, 96, 1, 1) + bias_shape2 = (1, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype) + B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype) + + # C = relay.nn.relu(A) + conv1 = relay.nn.conv2d( + A, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=96, + kernel_size=(2, 2), + ) + D = relay.op.add(conv1, B1) + D = relay.op.nn.relu(D) + + conv2 = relay.nn.conv2d( + D, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv2 = relay.op.add(conv2, B2) + conv2 = relay.op.nn.relu(conv2) + + conv3 = relay.nn.conv2d( + D, + W3, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[2, 2], + out_dtype=dtype, + channels=5, + kernel_size=(2, 2), + ) + + t = relay.Tuple([conv2, conv3]) + c = relay.op.concatenate(t, axis=1) + + + mod = relay.Function([A, W1, B1, W2, B2, W3], c) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data1 = np.zeros(filter_shape1).astype(dtype) + bias_data1 = np.zeros(bias_shape1).astype(dtype) + initializer("weight", filter_data1) + initializer("bias", bias_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + bias_data2 = np.zeros(bias_shape2).astype(dtype) + initializer("weight", filter_data2) + initializer("bias", bias_data2) + filter_data3 = np.zeros(filter_shape3).astype(dtype) + initializer("weight", filter_data3) + params1 = { + "weight1": tvm.nd.array(filter_data1), + "bias1": tvm.nd.array(bias_data1), + "weight2": tvm.nd.array(filter_data2), + "bias2": tvm.nd.array(bias_data2), + "weight3": tvm.nd.array(filter_data3), + } + + # static_memory_scope = [ + # "global", + # "global", + # "global.texture-weight", + # "global.texture-weight", + # "global.texture-nhwc", + # "global.texture-weight", + # "global.texture-weight", + # "global", + # "global", + # ] + + static_memory_scope = [] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + + +@tvm.testing.requires_opencl +def test_branch1_texture_params(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape0 = (32, 32, 1, 1) + filter_shape1 = (32, 32, 2, 2) + filter_shape2 = (32, 32, 1, 1) + filter_shape3 = (32, 32, 2, 2) + bias_shape1 = (1, 32, 1, 1) + # bias_shape2 = (1, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W0 = relay.var("weight0", shape=filter_shape0, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype) + + conv0 = relay.nn.conv2d( + A, + W0, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + pool = relay.nn.avg_pool2d(conv0, pool_size=(2, 2), strides=(2, 2)) + conv1 = relay.nn.conv2d( + pool, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 1, 1], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv1 = relay.op.add(conv1, B1) + conv1 = relay.op.nn.relu(conv1) + + conv2 = relay.nn.conv2d( + pool, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + conv3 = relay.nn.conv2d( + pool, + W3, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 1, 1, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv3 = relay.op.nn.relu(conv3) + res = relay.op.add(conv1, conv2) + res = relay.op.add(res, conv3) + + mod = relay.Function([A, W0, W1, B1, W2, W3], res) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data0 = np.zeros(filter_shape0).astype(dtype) + filter_data1 = np.zeros(filter_shape1).astype(dtype) + bias_data1 = np.zeros(bias_shape1).astype(dtype) + initializer("weight", filter_data1) + initializer("bias", bias_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + initializer("weight", filter_data2) + filter_data3 = np.zeros(filter_shape3).astype(dtype) + initializer("weight", filter_data3) + params1 = { + "weight0": tvm.nd.array(filter_data0), + "weight1": tvm.nd.array(filter_data1), + "bias1": tvm.nd.array(bias_data1), + "weight2": tvm.nd.array(filter_data2), + "weight3": tvm.nd.array(filter_data3), + } + + static_memory_scope = [ + # "global", + # "global", + # "global.texture-weight", + # "global.texture-weight", + # "global.texture", + # "global.texture-weight", + # 'global', + # "global.texture", + # "global.texture-weight", + # "", + # "" + ] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + + +# conv2d <- to get textures +# / \ \ <- here should be textures and textures in params +# conv2d conv2d conv2d +# \ / +# add <- tail required to have the only one output +# \ / +# add +@tvm.testing.requires_opencl +def test_branch2_texture_params(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape0 = (32, 32, 1, 1) + filter_shape1 = (32, 32, 2, 2) + filter_shape2 = (32, 32, 1, 1) + filter_shape3 = (32, 32, 2, 2) + bias_shape1 = (1, 32, 1, 1) + # bias_shape2 = (1, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W0 = relay.var("weight0", shape=filter_shape0, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) + W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) + W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype) + + conv0 = relay.nn.conv2d( + A, + W0, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + conv1 = relay.nn.conv2d( + conv0, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 1, 1], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv1 = relay.op.add(conv1, B1) + conv1 = relay.op.nn.relu(conv1) + + conv2 = relay.nn.conv2d( + conv0, + W2, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + conv3 = relay.nn.conv2d( + conv0, + W3, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 1, 1, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(2, 2), + ) + conv3 = relay.op.nn.relu(conv3) + res = relay.op.add(conv1, conv2) + res = relay.op.add(res, conv3) + + mod = relay.Function([A, W0, W1, B1, W2, W3], res) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data0 = np.zeros(filter_shape0).astype(dtype) + filter_data1 = np.zeros(filter_shape1).astype(dtype) + bias_data1 = np.zeros(bias_shape1).astype(dtype) + initializer("weight", filter_data1) + initializer("bias", bias_data1) + filter_data2 = np.zeros(filter_shape2).astype(dtype) + initializer("weight", filter_data2) + filter_data3 = np.zeros(filter_shape3).astype(dtype) + initializer("weight", filter_data3) + params1 = { + "weight0": tvm.nd.array(filter_data0), + "weight1": tvm.nd.array(filter_data1), + "bias1": tvm.nd.array(bias_data1), + "weight2": tvm.nd.array(filter_data2), + "weight3": tvm.nd.array(filter_data3), + } + + static_memory_scope = [ + # "global", + # "global", + # "global.texture-weight", + # "global.texture-weight", + # "global.texture", + # "global.texture-weight", + # 'global', + # "global.texture", + # "global.texture-weight", + # "", + # "" + ] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + +# function repeat, params scope are different in reused functions +@tvm.testing.requires_opencl +def test_conv2d_different_param_scope(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 32, 40, 40) + filter_shape1 = (32, 32, 1, 1) + A = relay.var("data", shape=input_shape, dtype=dtype) + W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) + + conv1 = relay.nn.conv2d( + A, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + conv2 = relay.nn.conv2d( + conv1, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + conv3 = relay.nn.conv2d( + conv2, + W1, + data_layout="NCHW", + kernel_layout="OIHW", + padding=[0, 0, 0, 0], + strides=[1, 1], + out_dtype=dtype, + channels=32, + kernel_size=(1, 1), + ) + + mod = relay.Function([A, W1], conv3) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data1 = np.zeros(filter_shape1).astype(dtype) + params1 = { + "weight1": tvm.nd.array(filter_data1), + } + + static_memory_scope = [ + # "global", + # "global", + # "global.texture-weight", + # "global.texture-weight", + # "global.texture-nhwc", + # "global.texture-weight", + # "global.texture-weight", + # "global", + # "global", + ] + + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) \ No newline at end of file diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py index 96227ca551cf..c63d0864f814 100644 --- a/tests/python/relay/test_conv2d_nhwc_texture.py +++ b/tests/python/relay/test_conv2d_nhwc_texture.py @@ -224,7 +224,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -266,7 +266,7 @@ def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -308,7 +308,7 @@ def test_conv2d_inceptionv3_35_35_strides(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -492,7 +492,6 @@ def test_conv2d_4x4x4_16c16pad(): B = relay.var("weight", shape=filter_shape, dtype=dtype) bias = relay.var("bias", shape=bias_shape, dtype=dtype) - # C = relay.nn.relu(A) conv = relay.nn.conv2d( A, B, diff --git a/tests/python/relay/test_depthwise_conv2d_nchw_texture.py b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py index 71cf62c5d85c..c94d085b5115 100644 --- a/tests/python/relay/test_depthwise_conv2d_nchw_texture.py +++ b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py @@ -64,7 +64,7 @@ def test_depthwise_conv2d_bias_nchwc(): "bias": tvm.nd.array(bias_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl @@ -103,7 +103,7 @@ def test_depthwise_conv2d_nchwc(): "weight": tvm.nd.array(filter_data), } - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess) + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, [], gpu_preprocess) @tvm.testing.requires_opencl diff --git a/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py index 16d26c77ca8e..16f9b8749909 100644 --- a/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py +++ b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py @@ -20,7 +20,7 @@ import numpy as np from tvm import relay from tvm.relay import testing -from utils.adreno_utils import gpu_preprocess, build_run_compare +from utils.adreno_utils import build_run_compare @tvm.testing.requires_opencl diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py index 6e353b22cdb4..27768c3d0cec 100644 --- a/tests/python/relay/utils/adreno_utils.py +++ b/tests/python/relay/utils/adreno_utils.py @@ -24,6 +24,7 @@ from tvm.relay import testing from tvm.relay.transform import recast from tvm.contrib import graph_runtime +import json def get_cpu_reference(mod, params1, input_shape, inputs): @@ -51,6 +52,7 @@ def build_run_compare( input_shape, dtype="float32", target="llvm", + static_mem_scopes=[], gpu_preprocess=None, stat_file=None, ): @@ -82,6 +84,19 @@ def build_run_compare( tvm_mod_nchwc, target_host=target_host, target=target, params=params1 ) + # verification that storage_scope has expected textures scopes + graph_json = json.loads(graph) + if "storage_scope" in graph_json["attrs"]: + assert ( + len(static_mem_scopes) == len(graph_json["attrs"]["storage_scope"][1]) + or len(static_mem_scopes) == 0 + ) + else: + assert len(static_mem_scopes) == 0 + + for i in range(0, len(static_mem_scopes)): + assert static_mem_scopes[i] == graph_json["attrs"]["storage_scope"][1][i] + if run_on_host: ctx = tvm.opencl() m = graph_runtime.create(graph, lib, ctx) From c58549d1b910b30a0ea148f3b8173c929442bd79 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Fri, 24 Jun 2022 14:32:30 +0400 Subject: [PATCH 02/15] lint check --- src/relay/backend/build_module.cc | 2 +- .../transforms/annotate_texture_storage.cc | 19 +++++++++---------- .../python/relay/test_conv2d_nchw_texture.py | 10 ++++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index edd401b1c22e..a372229c7e55 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -399,7 +399,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module = transform::AnnotateMemoryScope(config_)(relay_module); pass_seqs = GetPassPrefix( - /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); + /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); pass_seqs.push_back(transform::PlanDevices(config_)); // Create a sequential pass and perform optimizations. seq = transform::Sequential(pass_seqs); diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index d9d3599234b8..66ce379563a2 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -69,12 +69,12 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { // Filling the input arguments by "global" scope to handle PlanDevice algo which propagates // virtual devices from outputs to inputs. At the same time outputs must be unconstrained // to avoid useless device_copy - for (const auto& cs: storage_info.consumer_storage_scopes_) { + for (const auto& cs : storage_info.consumer_storage_scopes_) { // we have record in consumers that mean that potentially consumer // dealt with textures anyhow, it's safe to mark this expr as global scope // even without verification of the consumer's outputs scope if (storage_info.CanConsumeTextures(cs.second) && - storage_map.find(GetRef(cs.first)) == storage_map.end()) { + storage_map.find(GetRef(cs.first)) == storage_map.end()) { storage_map.Set(GetRef(cs.first), Array{"global"}); } } @@ -107,9 +107,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); } - void VisitExpr_(const ConstantNode* cn) final { - ApplyConsumerScopeToInputs(cn); - } + void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); } void DeviceAwareVisitExpr_(const CallNode* call) final { // Check the contents of this primitive function @@ -170,7 +168,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { if (consumer_storage_scopes_.count(arg.operator->()) && GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") { storage_scope_.erase(arg.operator->()); - if (const auto* cn = arg.as() ) { + if (const auto* cn = arg.as()) { if (const auto* fn = cn->op.as()) { storage_scope_.erase(fn->body.operator->()); } @@ -408,10 +406,11 @@ class VDRewriter : public transform::DeviceAwareExprMutator { memory_scope = ""; } if (!memory_scope.empty()) { - new_call = OnDevice(new_call, - VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id, - virtual_device->target, memory_scope), - true); + new_call = + OnDevice(new_call, + VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id, + virtual_device->target, memory_scope), + true); } return new_call; } diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index 7feb42890c17..40d171831538 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -572,6 +572,7 @@ def test_2conv2d(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + @tvm.testing.requires_opencl def test_residual_block(): target = "opencl --device=adreno" @@ -660,11 +661,11 @@ def test_residual_block(): "global.texture-weight", "global.texture", "global.texture-weight", - 'global', + "global", "global.texture", "global.texture-weight", "", - "" + "", ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) @@ -733,6 +734,7 @@ def test_plan_device_issue1(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + @tvm.testing.requires_opencl def test_branch_textures(): target = "opencl --device=adreno" @@ -795,7 +797,6 @@ def test_branch_textures(): t = relay.Tuple([conv2, conv3]) c = relay.op.concatenate(t, axis=1) - mod = relay.Function([A, W1, B1, W2, B2, W3], c) np.random.seed(0) initializer = relay.testing.init.Xavier() @@ -1059,6 +1060,7 @@ def test_branch2_texture_params(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) + # function repeat, params scope are different in reused functions @tvm.testing.requires_opencl def test_conv2d_different_param_scope(): @@ -1126,4 +1128,4 @@ def test_conv2d_different_param_scope(): # "global", ] - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) \ No newline at end of file + build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) From eb413120b9cdea7c72e33613b03c349c82b09a50 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Fri, 24 Jun 2022 15:18:54 +0400 Subject: [PATCH 03/15] Remove hardcoded texture limit, check through target options --- .../transforms/annotate_texture_storage.cc | 64 ++++++++++--------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index 66ce379563a2..a7898441388b 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -121,7 +121,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { std::string scope = "global.texture"; if (const auto* ttype = call->checked_type().as()) { if (ttype->shape.size() == 5) { - scope = Scope(ttype->shape); + scope = Scope(ttype->shape, GetVirtualDevice(GetRef(call))); } } storage_scope_[call].push_back(scope); @@ -177,37 +177,39 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { } } - std::string Scope(Array shape) { - std::map diffs; - int limit = 16384; - int a0 = shape[0].as()->value; - int a1 = shape[1].as()->value; - int a2 = shape[2].as()->value; - int a3 = shape[3].as()->value; - - int d3l = a0 * a1 * a2; - int d3r = a3; - int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l; - if (d3l < limit && d3r < limit) diffs[diff3] = ""; - - int d2l = a0 * a1; - int d2r = a2 * a3; - int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l; - if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc"; - - int d1l = a0; - int d1r = a1 * a2 * a3; - int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l; - if (d1l < limit && d1r < limit) diffs[diff1] = "weight"; - if (!diffs.empty()) { - std::string scope = "global.texture"; - if (!diffs.begin()->second.empty()) { - scope += ("-" + diffs.begin()->second); + std::string Scope(Array shape, const VirtualDevice& vd) { + if (vd != VirtualDevice::FullyUnconstrained()) { + std::map diffs; + int limit = + vd->target->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d3l = a0 * a1 * a2; + int d3r = a3; + int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l; + if (d3l < limit && d3r < limit) diffs[diff3] = ""; + + int d2l = a0 * a1; + int d2r = a2 * a3; + int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l; + if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc"; + + int d1l = a0; + int d1r = a1 * a2 * a3; + int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l; + if (d1l < limit && d1r < limit) diffs[diff1] = "weight"; + if (!diffs.empty()) { + std::string scope = "global.texture"; + if (!diffs.begin()->second.empty()) { + scope += ("-" + diffs.begin()->second); + } + return scope; } - return scope; - } else { - return "global.texture"; } + return "global"; } void ApplyConsumerScopeToInputs(const ExprNode* expr) { @@ -221,7 +223,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { bool expr_is_rgba_vectorizable = false; if (const auto* ttype = expr->checked_type().as()) { if (ttype->shape.size() == 5) { - scope = Scope(ttype->shape); + scope = Scope(ttype->shape, GetVirtualDevice(GetRef(expr))); if (scope != "global") { auto inner_dim = ttype->shape.back().as(); if (inner_dim && inner_dim->value == 4) { From 6b28ec637ead769e713a749293b1340164dd8f73 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Thu, 14 Jul 2022 14:29:25 +0300 Subject: [PATCH 04/15] fix cpplint --- include/tvm/relay/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4288f2ec48ee..f60912fb012e 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -580,7 +580,7 @@ TVM_DLL Pass AnnotateUsedMemory(); */ TVM_DLL Pass CapturePostDfsIndexInSpans(); - /*! +/*! * \brief Calls device dependent memory scope analysis pass, collects mapping of desirable * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope */ From 7b07585da787dec1b2ded7940d030cd17ae7bc1d Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Thu, 14 Jul 2022 14:33:04 +0300 Subject: [PATCH 05/15] Add winograd into annotation pass --- src/relay/transforms/annotate_texture_storage.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index a7898441388b..a58a603d1639 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -324,6 +324,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { attrs->kernel_layout == "OIHW4o")) { supports_texture_storage = true; } + } else if (auto attrs = call->attrs.as()) { + if ((attrs->data_layout == "NCHW4c" || attrs->data_layout == "NHWC4c") && + (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) { + supports_texture_storage = true; + } } else if (auto attrs = call->attrs.as()) { if (attrs->layout == "NCHW4c") { supports_texture_storage = true; From 94ad477c0f3189aca09c6a19e67cb3705074a7f0 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Mon, 18 Jul 2022 13:12:17 +0300 Subject: [PATCH 06/15] fix clang --- src/relay/transforms/annotate_texture_storage.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index a58a603d1639..a7eca1ebe70c 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -326,7 +326,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { } } else if (auto attrs = call->attrs.as()) { if ((attrs->data_layout == "NCHW4c" || attrs->data_layout == "NHWC4c") && - (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) { + (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) { supports_texture_storage = true; } } else if (auto attrs = call->attrs.as()) { From 43ef0dbf4b229c50cc5deca93501b0ca0dff8447 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 19 Jul 2022 19:52:47 +0300 Subject: [PATCH 07/15] Remove extra call of PlanDevice in OptimizeImpl --- src/relay/backend/build_module.cc | 15 ++------------- src/relay/transforms/annotate_texture_storage.cc | 6 +++++- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index a372229c7e55..917ca8a4afa0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -345,6 +345,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. pass_seqs.push_back(transform::FuseOps()); + pass_seqs.push_back(transform::PlanDevices(config_)); // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); @@ -396,20 +397,8 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module = transform::Inline()(relay_module); relay_module = transform::InferType()(relay_module); relay_module = transform::LabelOps()(relay_module); - relay_module = transform::AnnotateMemoryScope(config_)(relay_module); - pass_seqs = GetPassPrefix( - /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); - pass_seqs.push_back(transform::PlanDevices(config_)); - // Create a sequential pass and perform optimizations. - seq = transform::Sequential(pass_seqs); - if (config_->optional_homogeneous_target.defined()) { - With tctx(config_->optional_homogeneous_target); - relay_module = seq(relay_module); - } else { - relay_module = seq(relay_module); - } - relay_module = transform::InferType()(relay_module); + ICHECK(relay_module.defined()); return relay_module; diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index a7eca1ebe70c..66b4523f84e2 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -504,7 +504,11 @@ Map> CollectStorageInfo(const Expr& expr) { Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod, CompilationConfig config) { auto storage_scope = CollectStorageInfo(expr); - return VDRewriter(storage_scope).Rewrite(expr); + if (storage_scope.size()) { + return VDRewriter(storage_scope).Rewrite(expr); + } else { + return expr; + } } namespace transform { From d51c1e43cd8784ada4182436e7abba932258c2a1 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 19 Jul 2022 19:54:23 +0300 Subject: [PATCH 08/15] Remove one more extra call of PlanDevice in OptimizeImpl --- src/relay/backend/build_module.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 917ca8a4afa0..7b39cb444360 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -345,7 +345,6 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. pass_seqs.push_back(transform::FuseOps()); - pass_seqs.push_back(transform::PlanDevices(config_)); // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); From dcc39f63480195d73c08b8e9007965fe2f9db0ef Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Fri, 22 Jul 2022 14:42:04 +0300 Subject: [PATCH 09/15] Fix/add scopes for static texture planning tests --- .../python/relay/test_conv2d_nchw_texture.py | 114 +++++++++--------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index 40d171831538..0a71b44e9e37 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -559,7 +559,7 @@ def test_2conv2d(): } static_memory_scope = [ - "global", + "", "global", "global.texture-weight", "global.texture-weight", @@ -655,7 +655,7 @@ def test_residual_block(): } static_memory_scope = [ - "global", + "", "global", "global.texture-weight", "global.texture-weight", @@ -718,17 +718,15 @@ def test_plan_device_issue1(): "weight2": tvm.nd.array(filter_data2), } - # static_memory_scope = [ - # "global", - # "global", - # "global.texture-weight", - # "global.texture-weight", - # "global.texture-nhwc", - # "global.texture-weight", - # "global.texture-weight", - # "global", - # "global", - # ] + static_memory_scope = [ + "", + "global", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "", + "" + ] static_memory_scope = [] @@ -818,17 +816,20 @@ def test_branch_textures(): "weight3": tvm.nd.array(filter_data3), } - # static_memory_scope = [ - # "global", - # "global", - # "global.texture-weight", - # "global.texture-weight", - # "global.texture-nhwc", - # "global.texture-weight", - # "global.texture-weight", - # "global", - # "global", - # ] + static_memory_scope = [ + "", + "global", + "global.texture-weight", + "global.texture-weight", + "global", + "global.texture-weight", + "global.texture-weight", + "", + "", + "", + "", + "" + ] static_memory_scope = [] @@ -929,17 +930,19 @@ def test_branch1_texture_params(): } static_memory_scope = [ - # "global", - # "global", - # "global.texture-weight", - # "global.texture-weight", - # "global.texture", - # "global.texture-weight", - # 'global', - # "global.texture", - # "global.texture-weight", - # "", - # "" + "", + "global", + "global.texture-weight", + "global.texture", + "global.texture", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "global.texture", + "", + "" ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) @@ -1045,18 +1048,19 @@ def test_branch2_texture_params(): } static_memory_scope = [ - # "global", - # "global", - # "global.texture-weight", - # "global.texture-weight", - # "global.texture", - # "global.texture-weight", - # 'global', - # "global.texture", - # "global.texture-weight", - # "", - # "" - ] + "", + "global", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "global.texture", + "", + "" + ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) @@ -1117,15 +1121,13 @@ def test_conv2d_different_param_scope(): } static_memory_scope = [ - # "global", - # "global", - # "global.texture-weight", - # "global.texture-weight", - # "global.texture-nhwc", - # "global.texture-weight", - # "global.texture-weight", - # "global", - # "global", + "", + "global", + "global.texture-weight", + "global.texture", + "global.texture", + "", + "" ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) From 183a270084d163b562f977efdf8417d6451cc6aa Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 26 Jul 2022 15:57:08 +0300 Subject: [PATCH 10/15] Remove test_2conv2d as duplication of test_plan_device_issue --- .../python/relay/test_conv2d_nchw_texture.py | 78 ------------------- 1 file changed, 78 deletions(-) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index 0a71b44e9e37..9f50e4c6d9ad 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -495,84 +495,6 @@ def test_conv2d_winograd_conv(): assert len(matches) > 0 -@tvm.testing.requires_opencl -def test_2conv2d(): - target = "opencl --device=adreno" - dtype = "float16" - - input_shape = (1, 32, 40, 40) - filter_shape1 = (96, 32, 2, 2) - filter_shape2 = (32, 96, 2, 2) - bias_shape1 = (1, 96, 1, 1) - bias_shape2 = (1, 32, 1, 1) - A = relay.var("data", shape=input_shape, dtype=dtype) - W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) - B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) - W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) - B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype) - - # C = relay.nn.relu(A) - conv1 = relay.nn.conv2d( - A, - W1, - data_layout="NCHW", - kernel_layout="OIHW", - padding=[0, 0, 0, 0], - strides=[2, 2], - out_dtype=dtype, - channels=96, - kernel_size=(2, 2), - ) - D = relay.op.add(conv1, B1) - D = relay.op.nn.relu(D) - - conv2 = relay.nn.conv2d( - D, - W2, - data_layout="NCHW", - kernel_layout="OIHW", - padding=[0, 0, 0, 0], - strides=[2, 2], - out_dtype=dtype, - channels=32, - kernel_size=(2, 2), - ) - D = relay.op.add(conv2, B2) - D = relay.op.nn.relu(D) - - mod = relay.Function([A, W1, B1, W2, B2], D) - np.random.seed(0) - initializer = relay.testing.init.Xavier() - filter_data1 = np.zeros(filter_shape1).astype(dtype) - bias_data1 = np.zeros(bias_shape1).astype(dtype) - initializer("weight", filter_data1) - initializer("bias", bias_data1) - filter_data2 = np.zeros(filter_shape2).astype(dtype) - bias_data2 = np.zeros(bias_shape2).astype(dtype) - initializer("weight", filter_data2) - initializer("bias", bias_data2) - params1 = { - "weight1": tvm.nd.array(filter_data1), - "bias1": tvm.nd.array(bias_data1), - "weight2": tvm.nd.array(filter_data2), - "bias2": tvm.nd.array(bias_data2), - } - - static_memory_scope = [ - "", - "global", - "global.texture-weight", - "global.texture-weight", - "global.texture-nhwc", - "global.texture-weight", - "global.texture-weight", - "", - "", - ] - - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) - - @tvm.testing.requires_opencl def test_residual_block(): target = "opencl --device=adreno" From f105c650097bf80190f52c4b4f4c089c5866a7b2 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 26 Jul 2022 15:58:51 +0300 Subject: [PATCH 11/15] remove comments in test_residual_block --- tests/python/relay/test_conv2d_nchw_texture.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index 9f50e4c6d9ad..bc056e74906f 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -505,12 +505,10 @@ def test_residual_block(): filter_shape2 = (32, 32, 1, 1) filter_shape3 = (32, 32, 2, 2) bias_shape1 = (1, 32, 1, 1) - # bias_shape2 = (1, 32, 1, 1) A = relay.var("data", shape=input_shape, dtype=dtype) W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) B1 = relay.var("bias1", shape=bias_shape1, dtype=dtype) W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) - # B2 = relay.var("bias2", shape=bias_shape2, dtype=dtype) W3 = relay.var("weight3", shape=filter_shape3, dtype=dtype) conv1 = relay.nn.conv2d( @@ -563,16 +561,13 @@ def test_residual_block(): initializer("weight", filter_data1) initializer("bias", bias_data1) filter_data2 = np.zeros(filter_shape2).astype(dtype) - # bias_data2 = np.zeros(bias_shape2).astype(dtype) initializer("weight", filter_data2) - # initializer("bias", bias_data2) filter_data3 = np.zeros(filter_shape3).astype(dtype) initializer("weight", filter_data3) params1 = { "weight1": tvm.nd.array(filter_data1), "bias1": tvm.nd.array(bias_data1), "weight2": tvm.nd.array(filter_data2), - # "bias2": tvm.nd.array(bias_data2), "weight3": tvm.nd.array(filter_data3), } From d4cc2b45d6445e308be4a51b6a31ccc26f140549 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 26 Jul 2022 16:14:06 +0300 Subject: [PATCH 12/15] address review comments --- .../transforms/annotate_texture_storage.cc | 83 +++++++------------ 1 file changed, 32 insertions(+), 51 deletions(-) diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index 66b4523f84e2..3de2abe61287 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -111,44 +111,42 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { void DeviceAwareVisitExpr_(const CallNode* call) final { // Check the contents of this primitive function - if (DeviceSupportsTextureStorage(GetRef(call))) { - if (const auto* fn = call->op.as()) { - if (fn->HasNonzeroAttr(attr::kPrimitive)) { - primitive_supports_texture_ = false; - Visit(call->op); - if (primitive_supports_texture_) { - if (call->checked_type().as()) { - std::string scope = "global.texture"; - if (const auto* ttype = call->checked_type().as()) { - if (ttype->shape.size() == 5) { - scope = Scope(ttype->shape, GetVirtualDevice(GetRef(call))); - } - } - storage_scope_[call].push_back(scope); - } else { - const auto* tuple_type = call->type_as(); - ICHECK(tuple_type); - // TODO(csullivan): Add support for mixed output storage scope. - // In current adreno storage planner all outputs of a - // primitive function are assumed to be of the same storage - // type. This should be easy to extend in the future. - for (size_t i = 0; i < tuple_type->fields.size(); i++) { - storage_scope_[call].push_back("global.texture"); + if (const auto* fn = call->op.as()) { + if (fn->HasNonzeroAttr(attr::kPrimitive)) { + primitive_supports_texture_ = false; + Visit(call->op); + if (primitive_supports_texture_) { + if (call->checked_type().as()) { + std::string scope = "global.texture"; + if (const auto* ttype = call->checked_type().as()) { + if (ttype->shape.size() == 5) { + scope = Scope(ttype->shape, GetVirtualDevice(GetRef(call))); } } - for (size_t i = 0; i < fn->params.size(); i++) { - args_to_vars_[call->args[i]].push_back(fn->params[i]); + storage_scope_[call].push_back(scope); + } else { + const auto* tuple_type = call->type_as(); + ICHECK(tuple_type); + // TODO(csullivan): Add support for mixed output storage scope. + // In current adreno storage planner all outputs of a + // primitive function are assumed to be of the same storage + // type. This should be easy to extend in the future. + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + storage_scope_[call].push_back("global.texture"); } } - // Add consumer storage scope information for call arguments - for (auto& arg : call->args) { - if (storage_scope_.count(call)) { - ICHECK(!HasMixedStorageOutputs(call)) - << "Mixed output storage scopes are not currently supported"; - consumer_storage_scopes_[arg.operator->()].push_back("global.texture"); - } else { - consumer_storage_scopes_[arg.operator->()].push_back("global"); - } + for (size_t i = 0; i < fn->params.size(); i++) { + args_to_vars_[call->args[i]].push_back(fn->params[i]); + } + } + // Add consumer storage scope information for call arguments + for (auto& arg : call->args) { + if (storage_scope_.count(call)) { + ICHECK(!HasMixedStorageOutputs(call)) + << "Mixed output storage scopes are not currently supported"; + consumer_storage_scopes_[arg.operator->()].push_back("global.texture"); + } else { + consumer_storage_scopes_[arg.operator->()].push_back("global"); } } } @@ -262,20 +260,6 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { } } - bool DeviceSupportsTextureStorage(const Expr& expr) { - auto vd = GetVirtualDevice(expr); - if (vd != VirtualDevice::FullyUnconstrained()) { - if (Optional t_device = vd->target->GetAttr("device")) { - if (vd->target->kind->device_type == kDLOpenCL && t_device.defined()) { - if (t_device.value() == "adreno") { - return true; - } - } - } - } - return false; - } - std::string GetConsumerScope(const std::vector& consumer_scopes) const { if (!consumer_scopes.size()) { return "global"; @@ -290,9 +274,6 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { } bool CanConsumeTextures(const std::vector& consumer_scopes) const { - if (!consumer_scopes.size()) { - return false; - } std::string texture_tag = "global.texture"; for (auto& consumer_scope : consumer_scopes) { if (consumer_scope.find(texture_tag) == 0) { From 9d37fee7e73a9707ddf24116aed80f82732fdddb Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 26 Jul 2022 16:15:30 +0300 Subject: [PATCH 13/15] fix black hits --- .../python/relay/test_conv2d_nchw_texture.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index bc056e74906f..f1805d478707 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -636,13 +636,13 @@ def test_plan_device_issue1(): } static_memory_scope = [ - "", - "global", - "global.texture-weight", - "global.texture", - "global.texture-weight", - "", - "" + "", + "global", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "", + "", ] static_memory_scope = [] @@ -734,18 +734,18 @@ def test_branch_textures(): } static_memory_scope = [ - "", - "global", - "global.texture-weight", - "global.texture-weight", - "global", - "global.texture-weight", - "global.texture-weight", - "", - "", - "", - "", - "" + "", + "global", + "global.texture-weight", + "global.texture-weight", + "global", + "global.texture-weight", + "global.texture-weight", + "", + "", + "", + "", + "", ] static_memory_scope = [] @@ -859,7 +859,7 @@ def test_branch1_texture_params(): "global.texture-weight", "global.texture", "", - "" + "", ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) @@ -976,8 +976,8 @@ def test_branch2_texture_params(): "global.texture-weight", "global.texture", "", - "" - ] + "", + ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) @@ -1044,7 +1044,7 @@ def test_conv2d_different_param_scope(): "global.texture", "global.texture", "", - "" + "", ] build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) From 8a0fce27178c4a34dc9c5ec3929f221566e7aab9 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Thu, 28 Jul 2022 11:56:37 +0300 Subject: [PATCH 14/15] Add textures test descriptions --- .../python/relay/test_conv2d_nchw_texture.py | 153 +++++++++--------- 1 file changed, 81 insertions(+), 72 deletions(-) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index f1805d478707..a8d99f330383 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -497,6 +497,24 @@ def test_conv2d_winograd_conv(): @tvm.testing.requires_opencl def test_residual_block(): + """ + - some kind of residual block followed by convolution to have texture after residual block + - scalar data type verification which should be mapped to global memory scope + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) | + \ / + add <- add should be fused into conv2d (2) + multiply to scalar <- buffer to the input of multiply scalar value + relu + | <- texture in intermediate tensor + conv2d (3) + relu + | <- buffer + layout_transform (NCHW4c->NCHW) + """ target = "opencl --device=adreno" dtype = "float16" @@ -588,70 +606,20 @@ def test_residual_block(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) -@tvm.testing.requires_opencl -def test_plan_device_issue1(): - target = "opencl --device=adreno" - dtype = "float16" - - input_shape = (1, 32, 40, 40) - filter_shape1 = (32, 32, 2, 2) - filter_shape2 = (32, 32, 1, 1) - A = relay.var("data", shape=input_shape, dtype=dtype) - W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype) - W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype) - - conv1 = relay.nn.conv2d( - A, - W1, - data_layout="NCHW", - kernel_layout="OIHW", - padding=[0, 0, 0, 0], - strides=[2, 2], - out_dtype=dtype, - channels=32, - kernel_size=(2, 2), - ) - conv2 = relay.nn.conv2d( - conv1, - W2, - data_layout="NCHW", - kernel_layout="OIHW", - padding=[0, 0, 0, 0], - strides=[1, 1], - out_dtype=dtype, - channels=32, - kernel_size=(1, 1), - ) - - mod = relay.Function([A, W1, W2], conv2) - np.random.seed(0) - initializer = relay.testing.init.Xavier() - filter_data1 = np.zeros(filter_shape1).astype(dtype) - initializer("weight", filter_data1) - filter_data2 = np.zeros(filter_shape2).astype(dtype) - initializer("weight", filter_data2) - params1 = { - "weight1": tvm.nd.array(filter_data1), - "weight2": tvm.nd.array(filter_data2), - } - - static_memory_scope = [ - "", - "global", - "global.texture-weight", - "global.texture", - "global.texture-weight", - "", - "", - ] - - static_memory_scope = [] - - build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) - @tvm.testing.requires_opencl -def test_branch_textures(): +def test_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does not support textures, there we should have buffers + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ target = "opencl --device=adreno" dtype = "float16" @@ -754,7 +722,23 @@ def test_branch_textures(): @tvm.testing.requires_opencl -def test_branch1_texture_params(): +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ target = "opencl --device=adreno" dtype = "float16" @@ -865,15 +849,24 @@ def test_branch1_texture_params(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) -# conv2d <- to get textures -# / \ \ <- here should be textures and textures in params -# conv2d conv2d conv2d -# \ / -# add <- tail required to have the only one output -# \ / -# add @tvm.testing.requires_opencl -def test_branch2_texture_params(): +def test_branching_texture_params(): + """ + Verification of passing texture to several consumers markup of relay variables in + primary functions + on_device + + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (0) <- to get textures + / \ \ <- here should be textures and textures in params + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output + \ / + add <- to have the only one output + | <- buffer + layout_transform (NCHW4c->NCHW) + """ target = "opencl --device=adreno" dtype = "float16" @@ -984,7 +977,23 @@ def test_branch2_texture_params(): # function repeat, params scope are different in reused functions @tvm.testing.requires_opencl -def test_conv2d_different_param_scope(): +def test_conv2d_different_lowering_same_op(): + """ + Use case for verification of caching compiled functions + Three convolutions following by each other in this case should be + compiled in three different entities and lowered differently because + they are differ in input param memory scopes and in output memory scope + + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- buffer as input tensor and texture as output + | <- texture + conv2d (2) <- texture as input and texture as output + | <- texture + conv2d (3) <- texture as input and buffer as output + | <- buffer + layout_transform (NCHW4c->NCHW) + """ target = "opencl --device=adreno" dtype = "float16" From 1c8abb5d3d20fa4dd830b2d7cd0b67ffb35d2946 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Thu, 28 Jul 2022 12:30:57 +0300 Subject: [PATCH 15/15] Address PR comments --- .../transforms/annotate_texture_storage.cc | 44 ++++++++++++------- .../python/relay/test_conv2d_nchw_texture.py | 5 +-- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index 3de2abe61287..3dd918d962f5 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -30,7 +30,7 @@ * * - AnnotateMemoryScope calls *target.CollectStorageInfo for all target been represented * in the graph and rewrites graph modifying or inserting of VirtualDevice with required - * memory_scop collected from the CollectStorageInfo + * memory_scope collected from the CollectStorageInfo */ #include @@ -119,9 +119,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { if (call->checked_type().as()) { std::string scope = "global.texture"; if (const auto* ttype = call->checked_type().as()) { - if (ttype->shape.size() == 5) { - scope = Scope(ttype->shape, GetVirtualDevice(GetRef(call))); - } + scope = Scope(ttype->shape, GetVirtualDevice(GetRef(call))); } storage_scope_[call].push_back(scope); } else { @@ -175,8 +173,26 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { } } + /** + * Defines the name of the memory scope which can fit the tensor of required shape + * + * The scope stands for "global" if tensor does not satisfy current flattening rules for textures + * (texture currently has to be 5d tensors with value eq 4 in the last dimension) + * + * The packing layout inside the texture scope (the part after the dash) is defined + * during the shape itself. Hardware can have limitations on the texture spatial dimensions + * we must not exceed these sizes. In addition to the fitting of h/w limitation we want to + * get balanced packing where final spatial sizes of textures will not be too different + * @param shape shape to be analyzed + * @param vd VirtualDevice for the tensors determined of memory scope + * @return string representing memory scope either "global" or "global.texture-layout" + */ std::string Scope(Array shape, const VirtualDevice& vd) { - if (vd != VirtualDevice::FullyUnconstrained()) { + // currently we support only textures been made from 5d tensors + // 5d requirement is not limitation of textures in general, it is limitation how + // we are representing memory scopes/layout and flattening of textures in tir + if (vd != VirtualDevice::FullyUnconstrained() && shape.size() == 5 && + shape[4].as()->value == 4) { std::map diffs; int limit = vd->target->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; @@ -220,13 +236,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { bool expr_is_rgba_vectorizable = false; if (const auto* ttype = expr->checked_type().as()) { - if (ttype->shape.size() == 5) { - scope = Scope(ttype->shape, GetVirtualDevice(GetRef(expr))); - if (scope != "global") { - auto inner_dim = ttype->shape.back().as(); - if (inner_dim && inner_dim->value == 4) { - expr_is_rgba_vectorizable = true; - } + scope = Scope(ttype->shape, GetVirtualDevice(GetRef(expr))); + if (scope != "global") { + auto inner_dim = ttype->shape.back().as(); + if (inner_dim && inner_dim->value == 4) { + expr_is_rgba_vectorizable = true; } } } @@ -347,11 +361,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { * Currently this workflow supports analysis and rewriting of VirtualDevice for * Constants and function Variables */ -class VDRewriter : public transform::DeviceAwareExprMutator { +class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator { using VarMap = std::unordered_map; public: - explicit VDRewriter(const Map>& storage_scope) + explicit RewriteVDStorageScopes(const Map>& storage_scope) : transform::DeviceAwareExprMutator(Optional()), storage_scope_(storage_scope) {} Function Rewrite(const Expr& expr) { return Downcast(Mutate(expr)); } @@ -486,7 +500,7 @@ Map> CollectStorageInfo(const Expr& expr) { Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod, CompilationConfig config) { auto storage_scope = CollectStorageInfo(expr); if (storage_scope.size()) { - return VDRewriter(storage_scope).Rewrite(expr); + return RewriteVDStorageScopes(storage_scope).Rewrite(expr); } else { return expr; } diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index a8d99f330383..58590998fdd2 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -606,7 +606,6 @@ def test_residual_block(): build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope) - @tvm.testing.requires_opencl def test_concat(): """ @@ -737,7 +736,7 @@ def test_pooling_branching_texture_params(): \ / add <- to have the only one output, will be fused | <- buffer - layout_transform (NCHW4c->NCHW) + layout_transform (NCHW4c->NCHW) """ target = "opencl --device=adreno" dtype = "float16" @@ -865,7 +864,7 @@ def test_branching_texture_params(): \ / add <- to have the only one output | <- buffer - layout_transform (NCHW4c->NCHW) + layout_transform (NCHW4c->NCHW) """ target = "opencl --device=adreno" dtype = "float16"