Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions include/tvm/tir/data_type_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file data_type_rewriter.h
* \brief Rewrite the data type of expressions.
*/
#ifndef TVM_TIR_DATA_TYPE_REWRITER_H_
#define TVM_TIR_DATA_TYPE_REWRITER_H_

#include <tvm/tir/stmt_functor.h>

#include <unordered_map>

namespace tvm {
namespace tir {

/*!
* \brief Legalize the data types of expressions to make sure they are consistent with other
* parts of the program.
*
* It enforces the following rules:
* - The data type of the index variable in a loop must be consistent with the data type of the loop
* bounds.
* - The data type of the binary and ternary expressions must be consistent with the data types of
* each of their operands.
* - The data type of the bounds and binding values of block iter vars must be consistent with the
* data type of the block iter vars.
*
* Usually we enforce the consistency of data types when constructing the IR nodes. However, such
* inconsistency may happen as a result of IR mutation in some passes. This class can be used as
* base class of such passes to ensure the consistency of data types.
*/
class DataTypeLegalizer : public StmtExprMutator {
// Mutator base class that overrides the statement and expression visitors listed below.
// Only declarations appear here; the legalization logic itself is defined out of line.
protected:
// Statement overrides: loops, attribute statements, and blocks / block realizations.
// NOTE(review): presumably ForNode handles the loop-variable rule and the two block
// overrides handle the block-iter-var rule — confirm against the implementation.
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
// Expression overrides for Select and Ramp.
PrimExpr VisitExpr_(const SelectNode* op) override;
PrimExpr VisitExpr_(const RampNode* op) override;
// Expression overrides for the binary arithmetic nodes (including Min / Max).
PrimExpr VisitExpr_(const AddNode* op) override;
PrimExpr VisitExpr_(const SubNode* op) override;
PrimExpr VisitExpr_(const MulNode* op) override;
PrimExpr VisitExpr_(const DivNode* op) override;
PrimExpr VisitExpr_(const ModNode* op) override;
PrimExpr VisitExpr_(const FloorDivNode* op) override;
PrimExpr VisitExpr_(const FloorModNode* op) override;
PrimExpr VisitExpr_(const MinNode* op) override;
PrimExpr VisitExpr_(const MaxNode* op) override;
// Expression overrides for the binary comparison nodes.
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
PrimExpr VisitExpr_(const LENode* op) override;
PrimExpr VisitExpr_(const GTNode* op) override;
PrimExpr VisitExpr_(const GENode* op) override;
// Expression overrides for calls and casts; their rules are not visible in this header.
PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const CastNode* op) override;

// Declaring the overrides above hides every other base-class overload of the same name;
// these using-declarations bring the remaining StmtExprMutator overloads back into scope.
using StmtExprMutator::VisitExpr_;
using StmtExprMutator::VisitStmt_;

// Map from each IterVar before rewriting to its replacement after rewriting. It is keyed by
// the node pointer of the old IterVar, which ensures that one old IterVar maps to exactly
// one new IterVar.
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
};

/*!
* \brief Data type rewriter for buffer indices.
*
* Detect the components of buffer indices that should be considered for data type rewriting.
* This class doesn't perform actual rewriting of data types. During recursive visiting, the
* internal flags `is_enabled_` and `is_condition_` are used to indicate whether the current
* expression is a buffer index or a conditional expression, which can be used in the sub-classes to
* implement different rewriting rules.
*/
class IndexDataTypeRewriter : public DataTypeLegalizer {
// Extends DataTypeLegalizer with overrides for buffer accesses, buffer declarations,
// conditions and loops, plus helper routines and state shared with sub-classes.
// Only declarations appear here; the definitions live out of line.
protected:
// Alias for the direct base class, used by the using-declarations below.
using Parent = DataTypeLegalizer;
// Keep the base-class overloads that are not overridden here visible in this class.
using Parent::VisitExpr_;
using Parent::VisitStmt_;

// Overrides for blocks and for buffer stores / loads.
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
// Non-virtual helper taking an array of index expressions and returning an array.
// NOTE(review): presumably visits each index as a buffer index — confirm in the .cc file.
Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
// Overrides for if-then-else and for buffer declaration / allocation statements.
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
// Overrides for the binary comparison nodes.
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
PrimExpr VisitExpr_(const LENode* op) override;
PrimExpr VisitExpr_(const GTNode* op) override;
PrimExpr VisitExpr_(const GENode* op) override;
// Overrides for calls and for loops.
PrimExpr VisitExpr_(const CallNode* op) override;
Stmt VisitStmt_(const ForNode* op) override;

// Non-virtual helpers for the IR pieces that are not reached through VisitStmt_ / VisitExpr_
// overloads: buffers, block annotations, buffer regions and iter vars.
// NOTE(review): GetRemappedBuffer presumably consults buffer_remap_ — confirm in the .cc file.
Buffer VisitBuffer(const Buffer& buffer);
Buffer GetRemappedBuffer(const Buffer& buffer);
Map<String, ObjectRef> VisitBlockAnnotations(const Map<String, ObjectRef>& annotations);
BufferRegion VisitBufferRegion(const BufferRegion& region);
IterVar VisitIterVar(const IterVar& iter_var);
// Flag indicating that the expression currently being visited is an index expression to
// rewrite. Initially false.
bool is_enabled_{false};
// Flag indicating that the expression currently being visited is a condition.
// Initially false.
bool is_condition_{false};

// Var-to-Var and Buffer-to-Buffer tables.
// NOTE(review): presumably original -> rewritten replacements; verify against the users.
Map<Var, Var> var_remap_;
Map<Buffer, Buffer> buffer_remap_;
};

/*!
* \brief Normalize the data types of buffer shapes and indices to the same data type.
*
* This pass rewrites the data types of buffer shapes and indices to the specified data type. It
* assumes the specified data type is large enough to hold the original ranges of buffer shapes and
* indices.
*/
class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
public:
/*!
* \brief Construct a normalizer for the given target data type.
* \param target_data_type The data type to use for the rewrite (presumably stored in
* `target_data_type_` — the constructor is defined out of line).
*/
explicit IndexDataTypeNormalizer(DataType target_data_type);
/*!
* \brief Entry point: takes a PrimFunc by value and returns a PrimFunc.
* \param func The function to process.
* \return The resulting function.
*/
PrimFunc Rewrite(PrimFunc func);

protected:
// Alias for the direct base class, used by the using-declarations below.
using Parent = IndexDataTypeRewriter;
// Keep the base-class overloads that are not overridden here visible in this class.
using Parent::VisitExpr_;
using Parent::VisitStmt_;
// Overrides for integer immediates and variables. Both are marked `final`, so further
// sub-classes cannot override them again.
PrimExpr VisitExpr_(const IntImmNode* op) final;
PrimExpr VisitExpr_(const VarNode* op) final;

// Target data type of the normalization; the in-class default is 64-bit signed integer.
DataType target_data_type_ = DataType::Int(64);
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_DATA_TYPE_REWRITER_H_
1 change: 1 addition & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ class IfThenElse : public Stmt {
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
};

/*!
Expand Down
50 changes: 0 additions & 50 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,56 +485,6 @@ bool ContainsNode(const Stmt& stmt) {
return visitor.contains_node;
}

/*!
* \brief Legalize the data types of expressions to make sure they are consistent with other
* parts of the program.
*
* It enforces the following rules:
* - The data type of the index variable in a loop must be consistent with the data type of the loop
* bounds.
* - The data type of the binary and ternary expressions must be consistent with the data types of
* each of their operands.
* - The data type of the bounds and binding values of block iter vars must be consistent with the
* data type of the block iter vars.
*
* Usually we enforce the consistency of data types when constructing the IR nodes. However, such
* inconsistency may happen as a result of IR mutation in some passes. This class can be used as
* base class of such passes to ensure the consistency of data types.
*/
class DataTypeLegalizer : public StmtExprMutator {
protected:
Stmt VisitStmt_(const ForNode* op) override;

Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
PrimExpr VisitExpr_(const SelectNode* op) override;
PrimExpr VisitExpr_(const RampNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override;
PrimExpr VisitExpr_(const SubNode* op) override;
PrimExpr VisitExpr_(const MulNode* op) override;
PrimExpr VisitExpr_(const DivNode* op) override;
PrimExpr VisitExpr_(const ModNode* op) override;
PrimExpr VisitExpr_(const FloorDivNode* op) override;
PrimExpr VisitExpr_(const FloorModNode* op) override;
PrimExpr VisitExpr_(const MinNode* op) override;
PrimExpr VisitExpr_(const MaxNode* op) override;
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
PrimExpr VisitExpr_(const LENode* op) override;
PrimExpr VisitExpr_(const GTNode* op) override;
PrimExpr VisitExpr_(const GENode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;

using StmtExprMutator::VisitExpr_;
using StmtExprMutator::VisitStmt_;

// a map from IterVar before rewrite to that after rewrite,
// ensures one old IterVar maps to exactly one new IterVar
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
};

} // namespace tir
} // namespace tvm

Expand Down
8 changes: 5 additions & 3 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# pylint: disable=invalid-name
from numbers import Integral as _Integral
from typing import List
from typing import List, Optional

import tvm._ffi
import tvm.arith._ffi_api
Expand Down Expand Up @@ -566,7 +566,9 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
return tvm.tir.IterVar(dom, name, 2, thread_tag, span)


def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
def create_prim_func(
ops: List[_tensor.Tensor], index_dtype_override: Optional[str] = None
) -> tvm.tir.PrimFunc:
"""Create a TensorIR PrimFunc from tensor expression

Parameters
Expand Down Expand Up @@ -618,4 +620,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
"""
if not isinstance(ops, (list, tuple, Array)):
ops = [ops]
return _ffi_api.CreatePrimFunc(ops)
return _ffi_api.CreatePrimFunc(ops, index_dtype_override)
2 changes: 1 addition & 1 deletion src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,
return NullOpt;
}
}
PrimFunc func = te::CreatePrimFuncWithConstants(args, constants);
PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64));
bool dynamic_loop_extent = false;
tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
if (const auto* loop = obj.as<tir::ForNode>()) {
Expand Down
25 changes: 20 additions & 5 deletions src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

Expand Down Expand Up @@ -486,7 +487,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
}

PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants) {
const Array<runtime::NDArray>& constants,
std::optional<DataType> index_dtype_override) {
// Information used in CreatePrimFunc and its sub-functions.
CreateFuncInfo info(arg_list);
// Root body stmts.
Expand All @@ -508,14 +510,27 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
// Step 4. Create func and complete prim func.
auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
func = tir::BindParams(func, constants);
return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
if (index_dtype_override.has_value()) {
func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func));
}
auto result = LayoutFreePlaceholdersNormalizer().Process(std::move(func));
return result;
}

PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
return CreatePrimFuncWithConstants(arg_list, {});
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
std::optional<DataType> index_dtype_override) {
return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override);
}

TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
// FFI entry point for CreatePrimFunc.
//   args[0]: the list of te::Tensor to build the PrimFunc from.
//   args[1]: (optional) the index data type override; may be omitted or passed as null.
// The result of CreatePrimFunc is written to `ret`.
TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
  Array<te::Tensor> arg_list = args[0];
  std::optional<DataType> index_dtype_override{std::nullopt};
  // Add conversion to make std::optional compatible with FFI: a null second argument means
  // "no override". The second argument may also be absent altogether when the caller only
  // passes the tensor list, so check the argument count before touching args[1].
  if (args.size() > 1 && args[1].type_code() != kTVMNullptr) {
    index_dtype_override = args[1].operator DataType();
  }
  *ret = CreatePrimFunc(arg_list, index_dtype_override);
});

} // namespace tir
} // namespace tvm
8 changes: 6 additions & 2 deletions src/te/operation/create_primfunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,23 @@
#include <tvm/te/tensor.h>
#include <tvm/tir/function.h>

#include <optional>

namespace tvm {
namespace tir {

/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
std::optional<DataType> index_dtype_override = std::nullopt);

/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the
* constants array is N, the last N tensors in arg_list will be treated as constant tensors.
* Constant tensors will not be part of the parameters of the created PrimFunc, instead constants
* will be embedded in the body as AllocateConstNode.
*/
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants);
const Array<runtime::NDArray>& constants,
std::optional<DataType> index_dtype_override = std::nullopt);

} // namespace tir
} // namespace tvm
Expand Down
Loading