Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
* \param predicate The predicate constraints on the input iterators
* \param check_level The iter mapping checking level.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
Expand All @@ -407,7 +409,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
IterMapLevel check_level, arith::Analyzer* analyzer);
IterMapLevel check_level, arith::Analyzer* analyzer,
bool simplify_trivial_iterators = true);

/*!
* \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr.
Expand Down
11 changes: 8 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -563,21 +563,26 @@ class ScheduleNode : public runtime::Object {
/*!
* \brief Convert the subtree rooted at a specific loop into a block.
* \param loop_rv the root of the subtree
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return the new block
*/
virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = true) = 0;
/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrin.
* \param loop_rv The loop to be tensorized
* \param intrin Name of the tensor intrinsic
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
*/
virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0;
virtual void Tensorize(const LoopRV& loop_rv, const String& intrin,
bool preserve_unit_iters = true) = 0;
/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrin.
* \param block_rv The block to be tensorized
* \param intrin Name of the tensor intrinsic
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
*/
virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0;
virtual void Tensorize(const BlockRV& block_rv, const String& intrin,
bool preserve_unit_iters = true) = 0;

/******** Schedule: Annotation ********/
/*!
Expand Down
15 changes: 13 additions & 2 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,12 @@ def normalize_iter_map_to_expr(expr):


def subspace_divide(
bindings, input_iters, sub_iters, predicate=True, check_level=IterMapLevel.Surjective
bindings,
input_iters,
sub_iters,
predicate=True,
check_level=IterMapLevel.Surjective,
simplify_trivial_iterators=True,
):
"""Detect if bindings can be written as
[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
Expand Down Expand Up @@ -206,6 +211,10 @@ def subspace_divide(
check_level : Union[str, IterMapLevel]
Checking level of iteration mapping

simplify_trivial_iterators: bool
If true, iterators with extent of 1 will be replaced with a
constant value.

Returns
-------
results : List[List[PrimExpr]]
Expand All @@ -218,7 +227,9 @@ def subspace_divide(
"""
if isinstance(check_level, str):
check_level = IterMapLevel.from_str(check_level)
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, check_level)
return _ffi_api.SubspaceDivide(
bindings, input_iters, sub_iters, predicate, check_level, simplify_trivial_iterators
)


def inverse_affine_iter_map(iter_map, outputs):
Expand Down
17 changes: 13 additions & 4 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2186,13 +2186,15 @@ def after_set_scope(
########## Schedule: Blockize & Tensorize ##########

@type_checked
def blockize(self, loop: LoopRV) -> BlockRV:
def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> BlockRV:
"""Convert the subtree rooted at a specific loop into a block.

Parameters
----------
loop : LoopRV
The root of the subtree.
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings

Returns
-------
Expand Down Expand Up @@ -2257,10 +2259,15 @@ def after_blockize(
block are divisible by the subspace represented by the loops starting at the given loop.
"""

return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint: disable=no-member
return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters) # type: ignore # pylint: disable=no-member

@type_checked
def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -> None:
def tensorize(
self,
block_or_loop: Union[BlockRV, LoopRV],
tensor_intrin: str,
preserve_unit_iters: bool = True,
) -> None:
"""Tensorize the computation enclosed by loop with the tensor intrinsic.

Parameters
Expand All @@ -2269,6 +2276,8 @@ def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin: str) -
The loop to be tensorized.
tensor_intrin : str
The tensor intrin or the name of the tensor intrin.
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings

Examples
--------
Expand Down Expand Up @@ -2402,7 +2411,7 @@ def after_tensorize(
)
"""
_ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
self, block_or_loop, tensor_intrin
self, block_or_loop, tensor_intrin, preserve_unit_iters
)

########## Schedule: Annotation ##########
Expand Down
31 changes: 21 additions & 10 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1812,18 +1812,26 @@ class SubspaceDivider {
// extent of inner
PrimExpr inner_extent;

// The kind of the division result.
enum class Kind {
kInner, // Indicates the division result is totally in inner subspace.
kOuter, // Indicates the division result is totally in outer subspace.
kMixed, // Indicates the division result is mixed in both subspace.
} kind;

DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
PrimExpr inner_extent)
PrimExpr inner_extent, Kind kind = Kind::kMixed)
: outer(std::move(outer)),
inner(std::move(inner)),
outer_extent(std::move(outer_extent)),
inner_extent(std::move(inner_extent)) {}
inner_extent(std::move(inner_extent)),
kind(kind) {}

// whether the division result is totally in outer subspace
bool IsOuter() const { return is_one(inner_extent); }
bool IsOuter() const { return kind == Kind::kOuter; }

// whether the division result is totally in inner subspace
bool IsInner() const { return is_one(outer_extent); }
bool IsInner() const { return kind == Kind::kInner; }

IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }

Expand All @@ -1832,13 +1840,13 @@ class SubspaceDivider {
static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
auto dtype = iter.dtype();
return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter,
extent);
extent, Kind::kInner);
}

static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
auto dtype = iter.dtype();
return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)),
make_const(dtype, 1));
make_const(dtype, 1), Kind::kOuter);
}

// Special value to indicate the division is not possible
Expand Down Expand Up @@ -2066,9 +2074,11 @@ class SubspaceDivider {
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
IterMapLevel check_level, arith::Analyzer* analyzer) {
IterMapLevel check_level, arith::Analyzer* analyzer,
bool simplify_trivial_iterators) {
if (!IterRangeSanityCheck(input_iters)) return Array<Array<IterMark>>();
auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer);
auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer,
simplify_trivial_iterators);
const Array<IterSumExpr>& maps = res->indices;
if (maps.empty()) return {};

Expand Down Expand Up @@ -2096,10 +2106,11 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,

TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
.set_body_typed([](const Array<PrimExpr>& bindings, const Map<Var, Range>& root_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate, int check_level) {
const Array<Var>& sub_iters, const PrimExpr& predicate, int check_level,
bool simplify_trivial_iterators) {
arith::Analyzer ana;
return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level),
&ana);
&ana, simplify_trivial_iterators);
});

class InverseAffineIterMapTransformer {
Expand Down
16 changes: 10 additions & 6 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -690,25 +690,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
}

/******** Schedule: Blockize & Tensorize ********/
BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Blockize(state_, this->GetSRef(loop_rv));
result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
return CreateRV<BlockRV>(result);
}

void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
bool preserve_unit_iters) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value());
tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(),
preserve_unit_iters);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
}

void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
bool preserve_unit_iters) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value());
tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(),
preserve_unit_iters);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
}
Expand Down
6 changes: 3 additions & 3 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ class ConcreteScheduleNode : public ScheduleNode {
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv) override;
void Tensorize(const BlockRV& block_rv, const String& intrin) override;
void Tensorize(const LoopRV& loop_rv, const String& intrin) override;
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) override;
/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
Expand Down
6 changes: 4 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,20 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, in
* \brief Convert the subtree rooted at a specific loop into a block.
* \param self The state of the schedule
* \param loop_sref The root of the subtree
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
* \return The new block
*/
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref);
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters);

/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
* \param self The state of the schedule
* \param block_or_loop_sref The block or loop to be tensorized.
* \param intrin The tensor intrinsic.
* \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
*/
TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
const TensorIntrin& intrin);
const TensorIntrin& intrin, bool preserve_unit_iters);

/******** Schedule: Annotation ********/
/*!
Expand Down
Loading