From 85626a3dedda820f41898f866614bf6772e59001 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 10 Oct 2022 17:48:00 -0700 Subject: [PATCH 1/9] [TIR] Unify index data type when creating prim func --- include/tvm/tir/data_type_rewriter.h | 153 ++++++++ include/tvm/tir/stmt.h | 1 + include/tvm/tir/stmt_functor.h | 50 --- python/tvm/te/operation.py | 8 +- src/relay/backend/utils.cc | 2 +- src/te/operation/create_primfunc.cc | 25 +- src/te/operation/create_primfunc.h | 8 +- src/tir/ir/data_type_rewriter.cc | 361 +++++++++++++++++- src/tir/ir/stmt_functor.cc | 1 + src/tir/transforms/lower_match_buffer.cc | 29 +- src/tir/transforms/narrow_datatype.cc | 141 ++----- tests/cpp/data_type_rewriter_test.cc | 2 +- .../unittest/test_te_create_primfunc.py | 25 +- .../test_tir_transform_narrow_datatype.py | 33 +- 14 files changed, 647 insertions(+), 192 deletions(-) create mode 100644 include/tvm/tir/data_type_rewriter.h diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h new file mode 100644 index 000000000000..409732507287 --- /dev/null +++ b/include/tvm/tir/data_type_rewriter.h @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file data_type_rewriter.h + * \brief Rewrite the data type of expressions. + */ +#ifndef TVM_TIR_DATA_TYPE_REWRITER_H_ +#define TVM_TIR_DATA_TYPE_REWRITER_H_ + +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Legalize the data types of expressions to make sure they are consistent with other + * parts of the program. + * + * It enforces the following rules: + * - The data type of the index variable in a loop must be consistent with the data type of the loop + * bounds. + * - The data type of the binary and ternary expressions must be consistent with the data types of + * each of their operands. + * - The data type of the bounds and binding values of block iter vars must be consistent with the + * data type of the block iter vars. + * + * Usually we enforce the consistency of data types when constructing the IR nodes. However, such + * inconsistency may happen as a result of IR mutation in some passes. This class can be used as + * base class of such passes to ensure the consistency of data types. + */ +class DataTypeLegalizer : public StmtExprMutator { + protected: + Stmt VisitStmt_(const ForNode* op) override; + Stmt VisitStmt_(const AttrStmtNode* op) override; + Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const BlockNode* op) override; + PrimExpr VisitExpr_(const SelectNode* op) override; + PrimExpr VisitExpr_(const RampNode* op) override; + PrimExpr VisitExpr_(const AddNode* op) override; + PrimExpr VisitExpr_(const SubNode* op) override; + PrimExpr VisitExpr_(const MulNode* op) override; + PrimExpr VisitExpr_(const DivNode* op) override; + PrimExpr VisitExpr_(const ModNode* op) override; + PrimExpr VisitExpr_(const FloorDivNode* op) override; + PrimExpr VisitExpr_(const FloorModNode* op) override; + PrimExpr VisitExpr_(const MinNode* op) override; + PrimExpr VisitExpr_(const MaxNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + + using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; + + // a map from IterVar before rewrite to that after rewrite, + // ensures one old IterVar maps to exactly one new IterVar + std::unordered_map ivmap_; +}; + +/*! + * \brief Data type rewriter for buffer indices. + * + * Detect the components of buffer indices that should be considered for data type rewriting. + * This class doesn't perform actual rewriting of data types. During recursive visiting, the + * internal flags `is_enabled_` and `is_conditional_` are used to indicate whether the current + * expression is a buffer index or a conditional expression, which can be used in the sub-classes to + * implement different rewriting rules. + */ +class IndexDataTypeRewriter : public DataTypeLegalizer { + using Parent = DataTypeLegalizer; + + protected: + Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const BlockNode* op) override; + Stmt VisitStmt_(const BufferStoreNode* op) override; + PrimExpr VisitExpr_(const BufferLoadNode* op) override; + Array VisitIndices(Array indices); + Stmt VisitStmt_(const IfThenElseNode* op) override; + Stmt VisitStmt_(const DeclBufferNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + PrimExpr VisitExpr_(const EQNode* op) override; + PrimExpr VisitExpr_(const NENode* op) override; + PrimExpr VisitExpr_(const LTNode* op) override; + PrimExpr VisitExpr_(const LENode* op) override; + PrimExpr VisitExpr_(const GTNode* op) override; + PrimExpr VisitExpr_(const GENode* op) override; + PrimExpr VisitExpr_(const CallNode* op) override; + Stmt VisitStmt_(const ForNode* op) override; + + using DataTypeLegalizer::VisitExpr_; + using DataTypeLegalizer::VisitStmt_; + + Buffer VisitBuffer(const Buffer& buffer); + Buffer GetRemappedBuffer(const Buffer& buffer); + Map VisitBlockAnnotations(const Map& annotations); + BufferRegion VisitBufferRegion(const BufferRegion& region); + IterVar VisitIterVar(const IterVar& iter_var); + // indicator of index expr to rewrite + bool is_enabled_{false}; + // indicator of condition + bool is_condition_{false}; + + Map var_remap_; + Map buffer_remap_; +}; + +/*! + * \brief Normalize the data types of buffer shapes and indices to the same data type. + * + * This pass rewrites the data types of buffer shapes and indices to the specified data type. It + * assumes the specified data type is large enough to hold the original ranges of buffer shapes and + * indices. + */ +class IndexDataTypeNormalizer : public IndexDataTypeRewriter { + public: + explicit IndexDataTypeNormalizer(DataType target_data_type); + PrimFunc Rewrite(PrimFunc func); + + private: + PrimExpr VisitExpr_(const IntImmNode* op) final; + PrimExpr VisitExpr_(const VarNode* op) final; + PrimExpr VisitExpr_(const SizeVarNode* op) final; + + DataType target_data_type_ = DataType::Int(64); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_DATA_TYPE_REWRITER_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index e0e191b282e5..6865326b8849 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -858,6 +858,7 @@ class IfThenElse : public Stmt { Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode); }; /*! diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 8057108803db..9f4b4b40e4cd 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -485,56 +485,6 @@ bool ContainsNode(const Stmt& stmt) { return visitor.contains_node; } -/*! - * \brief Legalize the data types of expressions to make sure they are consistent with other - * parts of the program. - * - * It enforces the following rules: - * - The data type of the index variable in a loop must be consistent with the data type of the loop - * bounds. - * - The data type of the binary and ternary expressions must be consistent with the data types of - * each of their operands. - * - The data type of the bounds and binding values of block iter vars must be consistent with the - * data type of the block iter vars. - * - * Usually we enforce the consistency of data types when constructing the IR nodes. However, such - * inconsistency may happen as a result of IR mutation in some passes. This class can be used as - * base class of such passes to ensure the consistency of data types. - */ -class DataTypeLegalizer : public StmtExprMutator { - protected: - Stmt VisitStmt_(const ForNode* op) override; - - Stmt VisitStmt_(const AttrStmtNode* op) override; - Stmt VisitStmt_(const BlockRealizeNode* op) override; - Stmt VisitStmt_(const BlockNode* op) override; - PrimExpr VisitExpr_(const SelectNode* op) override; - PrimExpr VisitExpr_(const RampNode* op) override; - PrimExpr VisitExpr_(const AddNode* op) override; - PrimExpr VisitExpr_(const SubNode* op) override; - PrimExpr VisitExpr_(const MulNode* op) override; - PrimExpr VisitExpr_(const DivNode* op) override; - PrimExpr VisitExpr_(const ModNode* op) override; - PrimExpr VisitExpr_(const FloorDivNode* op) override; - PrimExpr VisitExpr_(const FloorModNode* op) override; - PrimExpr VisitExpr_(const MinNode* op) override; - PrimExpr VisitExpr_(const MaxNode* op) override; - PrimExpr VisitExpr_(const EQNode* op) override; - PrimExpr VisitExpr_(const NENode* op) override; - PrimExpr VisitExpr_(const LTNode* op) override; - PrimExpr VisitExpr_(const LENode* op) override; - PrimExpr VisitExpr_(const GTNode* op) override; - PrimExpr VisitExpr_(const GENode* op) override; - PrimExpr VisitExpr_(const CallNode* op) override; - - using StmtExprMutator::VisitExpr_; - using StmtExprMutator::VisitStmt_; - - // a map from IterVar before rewrite to that after rewrite, - // ensures one old IterVar maps to exactly one new IterVar - std::unordered_map ivmap_; -}; - } // namespace tir } // namespace tvm diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 5279c46aebc2..a9843005b243 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -19,7 +19,7 @@ # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List +from typing import List, Optional import tvm._ffi import tvm.arith._ffi_api @@ -566,7 +566,9 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): return tvm.tir.IterVar(dom, name, 2, thread_tag, span) -def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: +def create_prim_func( + ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None +) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression Parameters @@ -618,4 +620,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops) + return _ffi_api.CreatePrimFunc(ops, index_dtype_override) diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 51bcab527d1b..183a3094e473 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -416,7 +416,7 @@ Optional DefaultTIRConverterImpl(const Array& args, return NullOpt; } } - PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64)); bool dynamic_loop_extent = false; tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 80da5a727926..912be7935a3b 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -486,7 +487,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, } PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants) { + const Array& constants, + std::optional index_dtype_override) { // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. @@ -508,14 +510,27 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, // Step 4. Create func and complete prim func. auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); func = tir::BindParams(func, constants); - return LayoutFreePlaceholdersNormalizer().Process(std::move(func)); + if (index_dtype_override.has_value()) { + func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); + } + auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func)); + return result; } -PrimFunc CreatePrimFunc(const Array& arg_list) { - return CreatePrimFuncWithConstants(arg_list, {}); +PrimFunc CreatePrimFunc(const Array& arg_list, + std::optional index_dtype_override) { + return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { + Array arg_list = args[0]; + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[1].type_code() != kTVMNullptr) { + index_dtype_override = args[1].operator DataType(); + } + *ret = CreatePrimFunc(arg_list, index_dtype_override); +}); } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index b68d30a2fb82..4246347a16f3 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -24,11 +24,14 @@ #include #include +#include + namespace tvm { namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ -PrimFunc CreatePrimFunc(const Array& arg_list); +PrimFunc CreatePrimFunc(const Array& arg_list, + std::optional index_dtype_override = std::nullopt); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the * constants array is N, the last N tensors in arg_list will be treated as constant tensors. @@ -36,7 +39,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list); * will be embedded in the body as AllocateConstNode. */ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, - const Array& constants); + const Array& constants, + std::optional index_dtype_override = std::nullopt); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 102989acf6e0..627a243d14c6 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include "./functor_common.h" @@ -191,5 +191,364 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { return e; } +Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { + bool is_enabled = is_enabled_; + is_enabled_ = true; + auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); + auto new_cond = VisitExpr(op->condition); + is_enabled_ = is_enabled; + auto new_body = this->VisitStmt(op->body); + if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) || + !new_body.same_as(op->body)) { + Allocate new_allocate = GetRef(op); + auto* n = new_allocate.CopyOnWrite(); + n->extents = std::move(new_extents); + n->condition = std::move(new_cond); + n->body = std::move(new_body); + return std::move(new_allocate); + } else { + return GetRef(op); + } +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) { + Buffer new_buffer = VisitBuffer(op->buffer); + DeclBuffer decl_buffer = Downcast(StmtExprMutator::VisitStmt_(op)); + if (!new_buffer.same_as(op->buffer)) { + decl_buffer.CopyOnWrite()->buffer = new_buffer; + } + return std::move(decl_buffer); +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) { + bool is_condition = is_condition_; + is_condition_ = true; + auto new_predicate = VisitExpr(op->predicate); + is_condition_ = is_condition; + + bool is_enabled = is_enabled_; + is_enabled_ = true; + auto new_iter_values = + op->iter_values.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); + is_enabled_ = is_enabled; + Block new_body = Downcast(this->VisitStmt(op->block)); + if (!new_predicate.same_as(op->predicate) || !new_iter_values.same_as(op->iter_values) || + !new_body.same_as(op->block)) { + BlockRealize new_block_realize = GetRef(op); + auto* n = new_block_realize.CopyOnWrite(); + n->predicate = std::move(new_predicate); + n->iter_values = std::move(new_iter_values); + n->block = std::move(new_body); + return std::move(new_block_realize); + } else { + return GetRef(op); + } +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { + Array new_alloc_buffers = + op->alloc_buffers.Map([this](const Buffer& buffer) { return this->VisitBuffer(buffer); }); + Array new_match_buffers = + op->match_buffers.Map([this](const MatchBufferRegion& match_buffer_region) { + Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer); + BufferRegion new_buffer_region = this->VisitBufferRegion(match_buffer_region->source); + if (!new_buffer.same_as(match_buffer_region->buffer) || + !new_buffer_region.same_as(match_buffer_region->source)) { + return MatchBufferRegion(new_buffer, new_buffer_region); + } else { + return match_buffer_region; + } + }); + Array new_reads = op->reads.Map( + [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); + Array new_writes = op->writes.Map( + [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); + Array new_iter_vars = + op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); }); + Optional new_init = NullOpt; + if (op->init.defined()) { + new_init = this->VisitStmt(op->init.value()); + } + Stmt new_body = this->VisitStmt(op->body); + + if (!new_init.same_as(op->init) || !new_body.same_as(op->body) || + !new_alloc_buffers.same_as(op->alloc_buffers) || + !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || + !new_writes.same_as(op->writes) | new_iter_vars.same_as(op->iter_vars)) { + Block new_block = GetRef(op); + BlockNode* n = new_block.CopyOnWrite(); + n->alloc_buffers = std::move(new_alloc_buffers); + n->match_buffers = std::move(new_match_buffers); + n->reads = std::move(new_reads); + n->writes = std::move(new_writes); + n->iter_vars = std::move(new_iter_vars); + n->init = std::move(new_init); + n->body = std::move(new_body); + return std::move(new_block); + } + return GetRef(op); +} + +Map IndexDataTypeRewriter::VisitBlockAnnotations( + const Map& annotations) { + auto new_annotations = annotations; + + std::function f_mutate_obj = + [this, &f_mutate_obj](const ObjectRef& obj) -> ObjectRef { + if (!obj.defined()) { + return obj; + } + if (obj->IsInstance()) { + Buffer buffer = Downcast(obj); + if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) { + return new_buffer; + } + } else if (obj->IsInstance()) { + return Downcast>(obj).Map(f_mutate_obj); + } + return obj; + }; + for (const auto& [key, value] : annotations) { + auto new_value = f_mutate_obj(value); + if (!new_value.same_as(value)) { + new_annotations.Set(key, new_value); + } + } + return new_annotations; +} + +Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) { + if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) { + return (*it).second; + } + return buffer; +} + +IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) { + bool is_enabled = is_enabled_; + is_enabled_ = true; + Var new_var = Downcast(VisitExpr(iter_var->var)); + PrimExpr min = VisitExpr(iter_var->dom->min); + PrimExpr extent = VisitExpr(iter_var->dom->extent); + is_enabled_ = is_enabled; + if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) || + !extent.same_as(iter_var->dom->extent)) { + IterVar new_iter_var = iter_var; + IterVarNode* n = new_iter_var.CopyOnWrite(); + n->var = std::move(new_var); + n->dom = Range(min, extent); + return new_iter_var; + } + return iter_var; +} + +Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) { + bool is_enabled = is_enabled_; + + is_enabled_ = true; + Array new_shape = + buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); + Array new_strides = + buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); }); + auto new_elem_offset = VisitExpr(buffer->elem_offset); + is_enabled_ = is_enabled; + + if (!buffer->shape.same_as(new_shape) || !buffer->strides.same_as(new_strides) || + !buffer->elem_offset.same_as(new_elem_offset)) { + Buffer new_buffer = buffer; + BufferNode* new_buffer_node = new_buffer.CopyOnWrite(); + new_buffer_node->shape = std::move(new_shape); + new_buffer_node->strides = std::move(new_strides); + new_buffer_node->elem_offset = std::move(new_elem_offset); + buffer_remap_.Set(buffer, new_buffer); + return new_buffer; + } else { + return buffer; + } +} + +BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) { + Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer); + + bool is_enabled = is_enabled_; + is_enabled_ = true; + auto new_region = buffer_region->region.Map([&](const Range& range) { + return Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent)); + }); + is_enabled_ = is_enabled; + + if (!remapped_buffer.same_as(buffer_region->buffer) || + !new_region.same_as(buffer_region->region)) { + return BufferRegion(remapped_buffer, new_region); + } else { + return buffer_region; + } +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) { + BufferStore store = GetRef(op); + + Buffer new_buffer = GetRemappedBuffer(op->buffer); + auto value = this->VisitExpr(op->value); + auto indices = VisitIndices(op->indices); + + if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) || + !indices.same_as(op->indices)) { + auto writer = store.CopyOnWrite(); + writer->buffer = new_buffer; + writer->value = value; + writer->indices = indices; + } + + return std::move(store); +} + +PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) { + BufferLoad load = GetRef(op); + + Buffer new_buffer = GetRemappedBuffer(op->buffer); + auto indices = VisitIndices(op->indices); + + if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) { + auto writer = load.CopyOnWrite(); + writer->indices = indices; + writer->buffer = new_buffer; + } + + return std::move(load); +} + +Array IndexDataTypeRewriter::VisitIndices(Array indices) { + bool is_enabled = is_enabled_; + is_enabled_ = true; + + auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; + indices.MutateByApply(fmutate); + + is_enabled_ = is_enabled; + + return indices; +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { + bool is_condition = is_condition_; + is_condition_ = true; + PrimExpr cond = VisitExpr(op->condition); + is_condition_ = is_condition; + + Stmt then_case = VisitStmt(op->then_case); + Optional else_case = + op->else_case.defined() ? Optional{VisitStmt(op->else_case.value())} : NullOpt; + if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) || + !else_case.same_as(op->else_case)) { + IfThenElse new_stmt = GetRef(op); + auto* n = new_stmt.CopyOnWrite(); + n->condition = std::move(cond); + n->then_case = std::move(then_case); + n->else_case = std::move(else_case); + return std::move(new_stmt); + } + return GetRef(op); +} + +Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { + bool is_enabled = is_enabled_; + is_enabled_ = true; + Var new_loop_var = Downcast(VisitExpr(op->loop_var)); + PrimExpr min = VisitExpr(op->min); + PrimExpr extent = VisitExpr(op->extent); + is_enabled_ = is_enabled; + + Stmt new_body = VisitStmt(op->body); + + if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) || + !new_body.same_as(op->body)) { + For new_for = GetRef(op); + auto* n = new_for.CopyOnWrite(); + n->loop_var = new_loop_var; + n->min = min; + n->extent = extent; + n->body = new_body; + return std::move(new_for); + } else { + return GetRef(op); + } +} + +#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \ + bool is_enabled = is_enabled_; \ + is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \ + auto result = Parent::VisitExpr_(op); \ + is_enabled_ = is_enabled; \ + return std::move(result); \ + } + +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); + +PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) { + // handle if_then_else condition + if (op->op.same_as(builtin::if_then_else())) { + bool is_condition = is_condition_; + is_condition_ = true; + PrimExpr cond = VisitExpr(op->args[0]); + is_condition_ = is_condition; + return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2])); + } + return Parent::VisitExpr_(op); +} + +#undef DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH + +IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type) + : target_data_type_(std::move(target_data_type)) {} +PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) { + Map new_buffer_map = func->buffer_map; + for (const auto& [var, buffer] : func->buffer_map) { + new_buffer_map.Set(var, VisitBuffer(buffer)); + } + PrimFuncNode* new_func = func.CopyOnWrite(); + new_func->buffer_map = std::move(new_buffer_map); + new_func->body = VisitStmt(std::move(new_func->body)); + return func; +} + +PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { + if (is_enabled_) { + ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); + return cast(target_data_type_, GetRef(op)); + } + return GetRef(op); +} + +PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { + if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { + return (*it).second; + } + if (is_enabled_) { + Var new_var = GetRef(op).copy_with_dtype(target_data_type_); + var_remap_.Set(GetRef(op), new_var); + return std::move(new_var); + } + return GetRef(op); +} + +PrimExpr IndexDataTypeNormalizer::VisitExpr_(const SizeVarNode* op) { + if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { + return (*it).second; + } + if (is_enabled_) { + ICHECK_LE(op->dtype.bits(), target_data_type_.bits()); + Var new_var = GetRef(op).copy_with_dtype(target_data_type_); + var_remap_.Set(GetRef(op), new_var); + return std::move(new_var); + } + return GetRef(op); +} } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e445432e5b6f..daa8fe703a08 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -21,6 +21,7 @@ */ #include #include +#include #include #include diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 9b915da6290b..b0c5c7e3e002 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -23,17 +23,19 @@ */ #include +#include #include #include #include #include +#include "../../printer/text_printer.h" #include "../ir/functor_common.h" #include "ir_utils.h" namespace tvm { namespace tir { -class MatchBufferLower : public StmtExprMutator { +class MatchBufferLower : public DataTypeLegalizer { public: explicit MatchBufferLower(const PrimFunc& func) { for (const Var& param : func->params) { @@ -188,14 +190,14 @@ class MatchBufferLower : public StmtExprMutator { Array buffer_start_indices = source_buffer->ElemOffset(indices); if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); - CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + CHECK(analyzer_.CanProve(truncmod(buffer_start_indices[0], buffer->offset_factor) == 0)) << "The source elem_offset " << buffer_start_indices[0] << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { // Non-zero elem_offset is ill-defined for non-flat memory. // If needed in the future, will require `Array // elem_offsets`, with one offset for each flattened index. - Bind(buffer->elem_offset, 0); + Bind(buffer->elem_offset, make_zero(buffer->elem_offset.dtype())); } } @@ -229,7 +231,7 @@ class MatchBufferLower : public StmtExprMutator { } void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { - CHECK_EQ(arg.dtype(), value.dtype()) + CHECK_EQ(arg.dtype().code(), value.dtype().code()) << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; // Handle recursive case value = Substitute(std::move(value), var_map_); @@ -238,7 +240,7 @@ class MatchBufferLower : public StmtExprMutator { auto it = var_map_.find(v); if (it == var_map_.end()) { var_map_.Set(v, value); - analyzer_.Bind(v, value); + // analyzer_.Bind(v, value); } else { AssertBinding((*it).second, value, arg_name); } @@ -247,10 +249,21 @@ class MatchBufferLower : public StmtExprMutator { } } + PrimExpr LookUpArgBind(const PrimExpr& arg) { + if (arg->IsInstance()) { + Var v = Downcast(arg); + if (auto it = var_map_.find(v); it != var_map_.end()) { + return (*it).second; + } + } + return arg; + } + void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, const std::string& arg_name = "argument") { - CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name - << " unmet: " << lhs << "==" << rhs << "."; + CHECK(analyzer_.CanProve(LookUpArgBind(lhs) == rhs)) + << "The buffer match constraint for " << arg_name << " unmet: " << lhs << "==" << rhs + << "."; } private: @@ -264,7 +277,9 @@ class MatchBufferLower : public StmtExprMutator { PrimFunc LowerMatchBuffer(PrimFunc func) { auto fptr = func.CopyOnWrite(); + // LOG(INFO) << "BeforeLMB:\n" << tir::AsTVMScript(func); fptr->body = MatchBufferLower(func)(std::move(fptr->body)); + // LOG(INFO) << "AfterLMB:\n" << tir::AsTVMScript(func); return func; } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 2d287deec44c..9c04a98cb644 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -24,11 +24,13 @@ #include #include +#include #include #include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" +#include "../../printer/text_printer.h" namespace tvm { namespace tir { @@ -102,6 +104,14 @@ class DataTypeVisitor final : public StmtExprVisitor { return StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const BlockNode* op) { + for (const IterVar& iter : op->iter_vars) { + analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); + vextent_[iter->var.as()] = iter->dom->extent.dtype(); + } + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); @@ -187,11 +197,10 @@ class DataTypeVisitor final : public StmtExprVisitor { arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; -class DataTypeRewriter : public DataTypeLegalizer { - using Parent = DataTypeLegalizer; - +class NarrowDataTypeRewriter : public IndexDataTypeRewriter { public: - explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {} + using Parent = IndexDataTypeRewriter; + explicit NarrowDataTypeRewriter(int target_bits) : visitor_(target_bits) {} Stmt operator()(Stmt s) { visitor_(s); @@ -225,78 +234,30 @@ class DataTypeRewriter : public DataTypeLegalizer { return PrimExpr(); } - Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = GetRef(op); - - auto value = this->VisitExpr(op->value); - auto indices = VisitIndices(op->indices); - - if (!value.same_as(op->value) || !indices.same_as(op->indices)) { - auto writer = store.CopyOnWrite(); - writer->value = value; - writer->indices = indices; - } - - return std::move(store); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad load = GetRef(op); - - auto indices = VisitIndices(op->indices); - - if (!indices.same_as(op->indices)) { - auto writer = load.CopyOnWrite(); - writer->indices = indices; - } - - return std::move(load); - } - - Array VisitIndices(Array indices) { - is_index_ = true; - - auto fmutate = [this](const PrimExpr& index) { return this->VisitExpr(index); }; - indices.MutateByApply(fmutate); - - is_index_ = false; - - return indices; - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { - IfThenElse updated = Downcast(Parent::VisitStmt_(op)); - is_condition_ = true; - PrimExpr cond = VisitExpr(op->condition); - is_condition_ = false; - if (!cond.same_as(op->condition)) { - return std::move(IfThenElse(cond, updated->then_case, updated->else_case)); - } - return std::move(updated); - } - PrimExpr VisitExpr_(const VarNode* op) final { - if (visitor_.vmap.find(op) != visitor_.vmap.end()) { - if (vmap_.find(op) == vmap_.end()) { - vmap_[op] = Var(op->name_hint, visitor_.vmap[op]); - } - return vmap_[op]; + if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { + return (*it).second; + } else if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + Var v = Var(op->name_hint, visitor_.vmap[op]); + var_remap_.Set(GetRef(op), v); + return v; } return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const SizeVarNode* op) final { - if (visitor_.vmap.find(op) != visitor_.vmap.end()) { - if (vmap_.find(op) == vmap_.end()) { - vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]); - } - return vmap_[op]; + if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { + return (*it).second; + } else if (visitor_.vmap.find(op) != visitor_.vmap.end()) { + SizeVar v = SizeVar(op->name_hint, visitor_.vmap[op]); + var_remap_.Set(GetRef(op), v); + return v; } return Parent::VisitExpr_(op); } PrimExpr VisitExpr_(const IntImmNode* op) final { - if (is_index_) { + if (is_enabled_) { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { return IntImm(visitor_.vmap[op], op->value); } @@ -305,7 +266,7 @@ class DataTypeRewriter : public DataTypeLegalizer { } PrimExpr VisitExpr_(const CastNode* op) final { - if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { + if (is_enabled_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = Parent::VisitExpr_(op); const CastNode* new_op = e.as(); ICHECK(new_op != nullptr) << "Expected type to be CastNode" @@ -315,65 +276,25 @@ class DataTypeRewriter : public DataTypeLegalizer { return Parent::VisitExpr_(op); } - PrimExpr VisitExpr_(const EQNode* op) final; - PrimExpr VisitExpr_(const NENode* op) final; - PrimExpr VisitExpr_(const LTNode* op) final; - PrimExpr VisitExpr_(const LENode* op) final; - PrimExpr VisitExpr_(const GTNode* op) final; - PrimExpr VisitExpr_(const GENode* op) final; - PrimExpr VisitExpr_(const CallNode* op) final; - private: // the internal visitor to deduce the narrowed dtype DataTypeVisitor visitor_; // a map from Var before rewrite to that after rewrite, // ensures one old Var maps to exactly one new Var std::unordered_map vmap_; - // indicator of index expr to rewrite - bool is_index_{false}; - // indicator of condition - bool is_condition_{false}; }; -#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ - bool is_index = is_index_; \ - bool rewrite = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \ - if (rewrite) { \ - is_index_ = true; \ - } \ - auto result = Parent::VisitExpr_(op); \ - is_index_ = is_index; \ - return std::move(result); \ - } - -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); - -PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { - // handle if_then_else condition - if (op->op.same_as(builtin::if_then_else())) { - bool is_condition = is_condition_; - is_condition_ = true; - PrimExpr cond = VisitExpr(op->args[0]); - is_condition_ = is_condition; - return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2])); - } - return Parent::VisitExpr_(op); +Stmt NarrowDataType(Stmt stmt, int target_bits) { + return NarrowDataTypeRewriter(target_bits)(stmt); } -Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } - namespace transform { Pass NarrowDataType(int target_bits) { auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - n->body = DataTypeRewriter(target_bits)(std::move(n->body)); + n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body)); + // LOG(INFO) << "AfterNarrow: " << tir::AsTVMScript(f); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); diff --git a/tests/cpp/data_type_rewriter_test.cc b/tests/cpp/data_type_rewriter_test.cc index d1ac9d782ce5..c5e6d4f75843 100644 --- a/tests/cpp/data_type_rewriter_test.cc +++ b/tests/cpp/data_type_rewriter_test.cc @@ -19,8 +19,8 @@ #include #include +#include #include -#include using namespace tvm; using namespace tvm::tir; diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index b59880758e5d..b7691464f4f1 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -44,8 +44,10 @@ def test_unique_name_reduction_block(): assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) -def _check_workload(te_workload, tir_workload): - func = te.create_prim_func(te_workload()) +def _check_workload(te_workload, tir_workload, index_dtype_override=None): + func = te.create_prim_func(te_workload(), index_dtype_override) + print(func.script()) + print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload)) tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func s = tir.Schedule(func, debug_mask="all") @@ -75,10 +77,29 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[i, j] += A[i, k] * B[j, k] +@T.prim_func +def tir_matmul_int64( + A: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + B: T.Buffer[(T.int64(128), T.int64(128)), "float32"], + C: T.Buffer[(T.int64(128), T.int64(128)), "float32"], +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, j0, k0 in T.grid(T.int64(128), T.int64(128), T.int64(128)): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + def test_matmul(): _check_workload(te_matmul, tir_matmul) +def test_matmul_int64(): + _check_workload(te_matmul, tir_matmul_int64, index_dtype_override="int64") + + def te_element_wise(): A = te.placeholder((128, 128), name="A") B = te.compute((128, 128), lambda x, y: A[x, y] * 2, name="B") diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 20818a5b326a..c9c513378595 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -19,6 +19,7 @@ from tvm.driver.build_module import schedule_to_module from tvm.script import tir as T from tvm.tir import const +import tvm.testing def lower_stmt(params, stmt, target_bits): @@ -324,14 +325,26 @@ def expected_after(A: T.Buffer[128, "float32"], B: T.Buffer[130, "float32"]): tvm.ir.assert_structural_equal(after, expected_after) +def test_block(): + @T.prim_func + def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + for i in T.serial(0, T.int64(16)): + for j in T.serial(0, T.int64(8)): + with T.block(): + vi = T.axis.spatial(T.int64(128), i * T.int64(8) + j) + B[vi] = A[vi] + T.float32(1) + + @T.prim_func + def expected_after(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]): + for i in T.serial(0, T.int32(16)): + for j in T.serial(0, T.int32(8)): + with T.block(): + vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) + B[vi] = A[vi] + T.float32(1) + + after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))["main"] + tvm.ir.assert_structural_equal(after, expected_after) + + if __name__ == "__main__": - test_basic() - test_thread_axis() - test_thread_axis_2() - test_multilanes() - test_reduce() - test_slice() - test_relay_basic() - test_relay_take() - test_ramp_dtype_consistency() - test_condition() + tvm.testing.main() From 16fe93ca46ff6ce82be0b4b0610c14d125275a84 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 9 Nov 2022 17:06:48 -0800 Subject: [PATCH 2/9] fix --- .../schedule/primitive/blockize_tensorize.cc | 15 +++++++++ src/tir/transforms/lower_match_buffer.cc | 31 +++++-------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 98e30117e172..80a653c544b0 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include "../ir_comparator.h" @@ -523,6 +525,19 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } PrimFunc intrin_desc = intrin->desc; PrimFunc intrin_impl = DeepCopy(intrin->impl); + + int index_dtype_bits = -1; + auto f_update_max_dtype_bits_from_region = [&](const Array& buffer_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + for (const auto& range : buffer_region->region) { + index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits()); + } + } + }; + f_update_max_dtype_bits_from_region(block_realize->block->reads); + f_update_max_dtype_bits_from_region(block_realize->block->writes); + ICHECK(index_dtype_bits > 0); + intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl); // Step 2: Structural pattern matching TensorizeComparator comparator(self->mod, /*assert_mode=*/true); comparator.VisitStmt(block_realize, intrin_desc->body); diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index b0c5c7e3e002..445447d0dea6 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -23,19 +23,17 @@ */ #include -#include #include #include #include #include -#include "../../printer/text_printer.h" #include "../ir/functor_common.h" #include "ir_utils.h" namespace tvm { namespace tir { -class MatchBufferLower : public DataTypeLegalizer { +class MatchBufferLower : public StmtExprMutator { public: explicit MatchBufferLower(const PrimFunc& func) { for (const Var& param : func->params) { @@ -190,14 +188,14 @@ class MatchBufferLower : public DataTypeLegalizer { Array buffer_start_indices = source_buffer->ElemOffset(indices); if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); - CHECK(analyzer_.CanProve(truncmod(buffer_start_indices[0], buffer->offset_factor) == 0)) + CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) << "The source elem_offset " << buffer_start_indices[0] << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { // Non-zero elem_offset is ill-defined for non-flat memory. // If needed in the future, will require `Array // elem_offsets`, with one offset for each flattened index. - Bind(buffer->elem_offset, make_zero(buffer->elem_offset.dtype())); + Bind(buffer->elem_offset, 0); } } @@ -231,7 +229,7 @@ class MatchBufferLower : public DataTypeLegalizer { } void Bind(const PrimExpr& arg, PrimExpr value, const std::string& arg_name = "argument") { - CHECK_EQ(arg.dtype().code(), value.dtype().code()) + CHECK_EQ(arg.dtype(), value.dtype()) << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; // Handle recursive case value = Substitute(std::move(value), var_map_); @@ -240,7 +238,7 @@ class MatchBufferLower : public DataTypeLegalizer { auto it = var_map_.find(v); if (it == var_map_.end()) { var_map_.Set(v, value); - // analyzer_.Bind(v, value); + analyzer_.Bind(v, value); } else { AssertBinding((*it).second, value, arg_name); } @@ -249,21 +247,10 @@ class MatchBufferLower : public DataTypeLegalizer { } } - PrimExpr LookUpArgBind(const PrimExpr& arg) { - if (arg->IsInstance()) { - Var v = Downcast(arg); - if (auto it = var_map_.find(v); it != var_map_.end()) { - return (*it).second; - } - } - return arg; - } - void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, const std::string& arg_name = "argument") { - CHECK(analyzer_.CanProve(LookUpArgBind(lhs) == rhs)) - << "The buffer match constraint for " << arg_name << " unmet: " << lhs << "==" << rhs - << "."; + CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name + << " unmet: " << lhs << "==" << rhs << "."; } private: @@ -277,9 +264,7 @@ class MatchBufferLower : public DataTypeLegalizer { PrimFunc LowerMatchBuffer(PrimFunc func) { auto fptr = func.CopyOnWrite(); - // LOG(INFO) << "BeforeLMB:\n" << tir::AsTVMScript(func); fptr->body = MatchBufferLower(func)(std::move(fptr->body)); - // LOG(INFO) << "AfterLMB:\n" << tir::AsTVMScript(func); return func; } @@ -297,4 +282,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB } // namespace transform } // namespace tir -} // namespace tvm +} // namespace tvm \ No newline at end of file From 85b7a10e61693394fdf527c830ae9cc0b1a3a603 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 10 Nov 2022 14:41:07 -0800 Subject: [PATCH 3/9] update tensorize --- src/tir/transforms/lower_match_buffer.cc | 6 +++--- .../unittest/test_tir_schedule_tensorize.py | 20 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 445447d0dea6..2aa6d18b4d11 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -195,7 +195,7 @@ class MatchBufferLower : public StmtExprMutator { // Non-zero elem_offset is ill-defined for non-flat memory. // If needed in the future, will require `Array // elem_offsets`, with one offset for each flattened index. - Bind(buffer->elem_offset, 0); + Bind(buffer->elem_offset, make_const(buffer->elem_offset.dtype(), 0)); } } @@ -206,7 +206,7 @@ class MatchBufferLower : public StmtExprMutator { if (!buffer->strides.empty()) { ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); if (source_buffer->strides.empty()) { - PrimExpr stride = make_const(DataType::Int(32), 1); + PrimExpr stride = make_const(buffer->strides.back().dtype(), 1); for (size_t i = buffer->shape.size(); i > 0; --i) { const PrimExpr& shape = source_buffer->shape[i - 1 + offset]; Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); @@ -282,4 +282,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB } // namespace transform } // namespace tir -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index f30e91b892c5..f890a43261f5 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -697,34 +697,34 @@ def tensorized_matmul_int64_shape( ] ) T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)]) - A_elem_offset = T.var("int32") - B_elem_offset = T.var("int32") - C_elem_offset = T.var("int32") + A_elem_offset = T.var("int64") + B_elem_offset = T.var("int64") + C_elem_offset = T.var("int64") A_sub = T.match_buffer( A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], - [16, 16], + [T.int64(16), T.int64(16)], elem_offset=A_elem_offset, ) B_sub = T.match_buffer( B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)], - [16, 16], + [T.int64(16), T.int64(16)], elem_offset=B_elem_offset, ) C_sub = T.match_buffer( C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)], - [16, 16], + [T.int64(16), T.int64(16)], elem_offset=C_elem_offset, ) T.evaluate( T.tvm_mma_sync( C_sub.data, - T.floordiv(C_sub.elem_offset, 256), + T.floordiv(C_sub.elem_offset, T.int64(256)), A_sub.data, - T.floordiv(A_sub.elem_offset, 256), + T.floordiv(A_sub.elem_offset, T.int64(256)), B_sub.data, - T.floordiv(B_sub.elem_offset, 256), + T.floordiv(B_sub.elem_offset, T.int64(256)), C_sub.data, - T.floordiv(C_sub.elem_offset, 256), + T.floordiv(C_sub.elem_offset, T.int64(256)), dtype="handle", ) ) From 9c35fe5efd1ddc22295ddeabf7a4358724e326eb Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 14 Nov 2022 12:31:22 -0800 Subject: [PATCH 4/9] fix --- src/tir/ir/data_type_rewriter.cc | 4 +-- .../test_meta_schedule_relay_integration.py | 36 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 627a243d14c6..b022ce252396 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -466,8 +466,8 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { For new_for = GetRef(op); auto* n = new_for.CopyOnWrite(); n->loop_var = new_loop_var; - n->min = min; - n->extent = extent; + n->min = cast(new_loop_var.dtype(), min); + n->extent = cast(new_loop_var.dtype(), extent); n->body = new_body; return std::move(new_for); } else { diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index bf302cd0e5bf..021db0f86ad2 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -391,21 +391,21 @@ def test_meta_schedule_te2primfunc_argument_order_and_lowering(): class _fused_layout_transform: @T.prim_func def main( # type: ignore - placeholder: T.Buffer[(1, 3, 16, 16), "float32"], # type: ignore - T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"], # type: ignore + placeholder: T.Buffer[(T.int64(1), T.int64(3), T.int64(16), T.int64(16)), "float32"], # type: ignore + T_layout_trans: T.Buffer[(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)), "float32"], # type: ignore ) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3): + for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3]) + T.reads(placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3]) T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4]) T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else( - ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16, # type: ignore - placeholder[ax0, ax1 * 3 + ax4, ax2, ax3], + ax0 < T.int64(1) and ax1 * T.int64(3) + ax4 < T.int64(3) and ax2 < T.int64(16) and ax3 < T.int64(16), # type: ignore + placeholder[ax0, ax1 * T.int64(3) + ax4, ax2, ax3], T.float32(0), dtype="float32", ) @@ -413,41 +413,41 @@ def main( # type: ignore @tvm.script.ir_module class _fused_layout_transform_1: @T.prim_func - def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None: # type: ignore + def main(placeholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4)), "float32"], T_layout_trans: T.Buffer[(T.int64(1), T.int64(8), T.int64(16), T.int64(16)), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - for i0, i1, i2, i3 in T.grid(1, 8, 16, 16): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(8), T.int64(16), T.int64(16)): with T.block("T_layout_trans"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4]) # type: ignore + T.reads(placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)]) # type: ignore T.writes(T_layout_trans[ax0, ax1, ax2, ax3]) - T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32") # type: ignore + T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < T.int64(1) and ax1 < T.int64(8) and ax2 < T.int64(16) and ax3 < T.int64(16), placeholder[ax0, ax1 // T.int64(4), ax2, ax3, ax1 % T.int64(4)], T.float32(0), dtype="float32") # type: ignore @tvm.script.ir_module class _fused_nn_contrib_conv2d_NCHWc: @T.prim_func - def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None: # type: ignore + def main(placeholder: T.Buffer[(T.int64(1), T.int64(1), T.int64(16), T.int64(16), T.int64(3)), "float32"], placeholder_1: T.Buffer[(T.int64(2), T.int64(1), T.int64(5), T.int64(5), T.int64(3), T.int64(4)), "float32"], conv2d_NCHWc: T.Buffer[(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4)), "float32"]) -> None: # type: ignore # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32") - for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3): + data_pad = T.alloc_buffer([T.int64(1), T.int64(1), T.int64(20), T.int64(20), T.int64(3)], dtype="float32") + for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(1), T.int64(20), T.int64(20), T.int64(3)): with T.block("data_pad"): i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1]) + T.reads(placeholder[i0_1, i1_1, i2_1 - T.int64(2), i3_1 - T.int64(2), i4_1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 - for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5): + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(T.int64(2) <= i2_1 and i2_1 < T.int64(18) and T.int64(2) <= i3_1 and i3_1 < T.int64(18), placeholder[i0_1, i1_1, i2_1 - T.int64(2), i3_1 - T.int64(2), i4_1], T.float32(0), dtype="float32") # type: ignore # pylint: disable=R1716 + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(2), T.int64(16), T.int64(16), T.int64(4), T.int64(3), T.int64(5), T.int64(5)): with T.block("conv2d_NCHWc"): n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) - T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]) # type: ignore + T.reads(data_pad[n, ic // T.int64(3), oh + kh, ow + kw, ic % T.int64(3)], placeholder_1[oc_chunk, ic // T.int64(3), kh, kw, ic % T.int64(3), oc_block]) # type: ignore T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block]) with T.init(): conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0) - conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block] # type: ignore + conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // T.int64(3), oh + kh, ow + kw, ic % T.int64(3)] * placeholder_1[oc_chunk, ic // T.int64(3), kh, kw, ic % T.int64(3), oc_block] # type: ignore # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument From 5ac8b3835d5bfb67696609e92c58aeb6aaf79f04 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 14 Nov 2022 18:53:47 -0800 Subject: [PATCH 5/9] fix --- include/tvm/tir/data_type_rewriter.h | 12 +++++++----- src/tir/ir/data_type_rewriter.cc | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 409732507287..9cacebf7b87a 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -90,9 +90,11 @@ class DataTypeLegalizer : public StmtExprMutator { * implement different rewriting rules. */ class IndexDataTypeRewriter : public DataTypeLegalizer { + protected: using Parent = DataTypeLegalizer; + using Parent::VisitExpr_; + using Parent::VisitStmt_; - protected: Stmt VisitStmt_(const BlockRealizeNode* op) override; Stmt VisitStmt_(const BlockNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; @@ -110,9 +112,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { PrimExpr VisitExpr_(const CallNode* op) override; Stmt VisitStmt_(const ForNode* op) override; - using DataTypeLegalizer::VisitExpr_; - using DataTypeLegalizer::VisitStmt_; - Buffer VisitBuffer(const Buffer& buffer); Buffer GetRemappedBuffer(const Buffer& buffer); Map VisitBlockAnnotations(const Map& annotations); @@ -139,7 +138,10 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter { explicit IndexDataTypeNormalizer(DataType target_data_type); PrimFunc Rewrite(PrimFunc func); - private: + protected: + using Parent = IndexDataTypeRewriter; + using Parent::VisitExpr_; + using Parent::VisitStmt_; PrimExpr VisitExpr_(const IntImmNode* op) final; PrimExpr VisitExpr_(const VarNode* op) final; PrimExpr VisitExpr_(const SizeVarNode* op) final; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index b022ce252396..048a0ec5f833 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -274,7 +274,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { if (!new_init.same_as(op->init) || !new_body.same_as(op->body) || !new_alloc_buffers.same_as(op->alloc_buffers) || !new_match_buffers.same_as(op->match_buffers) || !new_reads.same_as(op->reads) || - !new_writes.same_as(op->writes) | new_iter_vars.same_as(op->iter_vars)) { + !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars)) { Block new_block = GetRef(op); BlockNode* n = new_block.CopyOnWrite(); n->alloc_buffers = std::move(new_alloc_buffers); From 6bc766358e63c3736c6e7be1c56382681cec54cf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Nov 2022 10:26:04 -0800 Subject: [PATCH 6/9] fix windows --- src/tir/ir/data_type_rewriter.cc | 12 ------------ src/tir/transforms/narrow_datatype.cc | 11 ----------- 2 files changed, 23 deletions(-) diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 048a0ec5f833..816ad2245e53 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -538,17 +538,5 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { return GetRef(op); } -PrimExpr IndexDataTypeNormalizer::VisitExpr_(const SizeVarNode* op) { - if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { - return (*it).second; - } - if (is_enabled_) { - ICHECK_LE(op->dtype.bits(), target_data_type_.bits()); - Var new_var = GetRef(op).copy_with_dtype(target_data_type_); - var_remap_.Set(GetRef(op), new_var); - return std::move(new_var); - } - return GetRef(op); -} } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 9c04a98cb644..9d11b1267089 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -245,17 +245,6 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { return Parent::VisitExpr_(op); } - PrimExpr VisitExpr_(const SizeVarNode* op) final { - if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { - return (*it).second; - } else if (visitor_.vmap.find(op) != visitor_.vmap.end()) { - SizeVar v = SizeVar(op->name_hint, visitor_.vmap[op]); - var_remap_.Set(GetRef(op), v); - return v; - } - return Parent::VisitExpr_(op); - } - PrimExpr VisitExpr_(const IntImmNode* op) final { if (is_enabled_) { if (visitor_.vmap.find(op) != visitor_.vmap.end()) { From 1425f88bc80e74955dae73dc32eaed18f937f78a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Nov 2022 10:39:49 -0800 Subject: [PATCH 7/9] fix --- include/tvm/tir/data_type_rewriter.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 9cacebf7b87a..535ffe5d3673 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -144,7 +144,6 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter { using Parent::VisitStmt_; PrimExpr VisitExpr_(const IntImmNode* op) final; PrimExpr VisitExpr_(const VarNode* op) final; - PrimExpr VisitExpr_(const SizeVarNode* op) final; DataType target_data_type_ = DataType::Int(64); }; From 7dd291734811e3b3ed331e877e538ba1a4cbfee2 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Nov 2022 10:53:07 -0800 Subject: [PATCH 8/9] address comments --- src/tir/ir/data_type_rewriter.cc | 52 +++++++++++++-------------- src/tir/transforms/narrow_datatype.cc | 1 - 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 816ad2245e53..0b9be474242e 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -138,7 +138,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { } } -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ +#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr DataTypeLegalizer::VisitExpr_(const OP* op) { \ PrimExpr a = this->VisitExpr(op->a); \ PrimExpr b = this->VisitExpr(op->b); \ @@ -149,23 +149,23 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { } \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); - -#undef DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); + +#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); @@ -475,7 +475,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { } } -#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ +#define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \ bool is_enabled = is_enabled_; \ is_enabled_ = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \ @@ -484,12 +484,12 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { return std::move(result); \ } -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) -DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) +TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) { // handle if_then_else condition @@ -503,7 +503,7 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) { return Parent::VisitExpr_(op); } -#undef DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH +#undef TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type) : target_data_type_(std::move(target_data_type)) {} diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 9d11b1267089..fba813870bb1 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -283,7 +283,6 @@ Pass NarrowDataType(int target_bits) { auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body)); - // LOG(INFO) << "AfterNarrow: " << tir::AsTVMScript(f); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); From c6db9f5d0c244c9bed976a8a6f900952f6982e47 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Nov 2022 15:45:06 -0800 Subject: [PATCH 9/9] windows workaround --- include/tvm/tir/data_type_rewriter.h | 1 + src/tir/ir/data_type_rewriter.cc | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 535ffe5d3673..378addaba528 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -71,6 +71,7 @@ class DataTypeLegalizer : public StmtExprMutator { PrimExpr VisitExpr_(const GTNode* op) override; PrimExpr VisitExpr_(const GENode* op) override; PrimExpr VisitExpr_(const CallNode* op) override; + PrimExpr VisitExpr_(const CastNode* op) override; using StmtExprMutator::VisitExpr_; using StmtExprMutator::VisitStmt_; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 0b9be474242e..fecb8e5fb70c 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -138,6 +138,10 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { } } +PrimExpr DataTypeLegalizer::VisitExpr_(const CastNode* op) { + return StmtExprMutator::VisitExpr_(op); +} + #define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ PrimExpr DataTypeLegalizer::VisitExpr_(const OP* op) { \ PrimExpr a = this->VisitExpr(op->a); \