Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TensorizeInfo(Object):


def get_tensorize_loop_mapping(
sch: Schedule, block: BlockRV, desc_func: PrimFunc
sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False
) -> Optional[TensorizeInfo]:
"""Establish a mapping between loops in a target block and an intrinsic description

Expand All @@ -80,13 +80,14 @@ def get_tensorize_loop_mapping(
The target block to match against
desc_func : PrimFunc
The prim func describing the computation to be tensorized

allow_padding : bool
Whether to allow padding the block iters to match the intrinsic description

Returns
-------
tensorize_info : Optional[TensorizeInfo]
TensorizeInfo structure if a valid mapping is found, None otherwise
"""
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding) # type: ignore


@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/schedule/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from . import _ffi_api


def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]:
def tile_with_tensor_intrin(
sch: Schedule, block: BlockRV, intrin_name: str, allow_padding: bool = False
) -> Optional[LoopRV]:
"""Tile a subset of loops in the block according to the given tensor intrinsic.

Parameters
Expand All @@ -32,11 +34,13 @@ def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) ->
The block whose subset of loops will be tiled
intrin_name : str
The name of a tensor intrinsic, must be registered via TensorIntrin.register(...) beforehand
allow_padding : bool
Whether to allow padding when tiling

Returns
-------
tiled_loop_rv : Optional[LoopRV]
LoopRV corresponding to the outermost loop of a block tiled according to the given intrin
None if no valid loop mapping is found
"""
return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name) # type: ignore
return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name, allow_padding) # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
state->sch->TransformBlockLayout(state->block_rv, index_map);

return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
/*allow_padding=*/true);
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
Expand Down
8 changes: 7 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,15 @@ class TensorizeInfoNode : public Object {
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps loops in an intrinsic description to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;
/*! \brief Optional padded extents of the block iters when padding is needed to match the
* intrinsic description
*/
Optional<Array<Integer>> block_iter_paddings;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
v->Visit("block_iter_paddings", &block_iter_paddings);
}

static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
Expand All @@ -751,11 +756,12 @@ class TensorizeInfo : public ObjectRef {
* \param self The schedule state to be tensorized
* \param block_sref The target block to match against
* \param desc_func The prim func describing the computation to be tensorized
* \param allow_padding Whether to allow padding the block iters to match the intrinsic description
* \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
*/
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);
const tir::PrimFunc& desc_func, bool allow_padding);

/*!\brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
Expand Down
53 changes: 44 additions & 9 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,8 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
const tir::PrimFunc& desc_func,
bool allow_padding) {
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
Expand Down Expand Up @@ -1732,6 +1733,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const int n_desc_vars = desc_block->iter_values.size();
const int offset = n_block_vars - n_desc_vars;

std::unordered_map<int, int> block_index_to_padding; // padding of each block iter if necessary

if (offset < 0) {
return NullOpt;
}
Expand Down Expand Up @@ -1782,10 +1785,11 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
PrimExpr block_bind;
for (int i = next_block_ind; i >= 0; --i) {
if (iter_types_block[i] == iter_type_desc) {
next_block_ind = i - 1;
block_bind = block->iter_values[i];
int current_block_ind = next_block_ind;
for (; current_block_ind >= 0; --current_block_ind) {
if (iter_types_block[current_block_ind] == iter_type_desc) {
next_block_ind = current_block_ind - 1;
block_bind = block->iter_values[current_block_ind];
break;
}
}
Expand All @@ -1802,15 +1806,30 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (UsesVar(residual,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
continue;
}
// padding is allowed only when the block has trivial bindings
if (allow_padding && !is_zero(residual)) {
allow_padding = false;
}

const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

// Check divisibility
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
if (!int_block_extent) {
return NullOpt;
}
int64_t remainder = int_block_extent->value % int_desc_extent->value;
if (remainder != 0) {
if (allow_padding) {
// If the block loop is not divisible by the desc loop, we pad the block loop to make it
// divisible if padding is allowed.
block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
} else {
return NullOpt;
}
}

ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
break;
Expand All @@ -1820,13 +1839,29 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
if (!block_index_to_padding.empty()) {
if (!allow_padding) {
return NullOpt;
}
Array<Integer> paddings;
for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
const IterVar& iter_var = block->block->iter_vars[i];
if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
} else {
paddings.push_back(IntImm(iter_var->var.dtype(), 0));
}
}
ret->block_iter_paddings = std::move(paddings);
}

return TensorizeInfo(ret);
}

TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
});

/******** Auto Tensorization ********/
Expand Down
10 changes: 7 additions & 3 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,15 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
}

Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name) {
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
const String& intrin_name, bool allow_padding) {
Optional<tir::TensorizeInfo> opt_tensorize_info =
GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin_name)->desc, allow_padding);
if (!opt_tensorize_info) return NullOpt;
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
if (info->block_iter_paddings.defined()) {
sch->PadEinsum(block_rv, info->block_iter_paddings.value());
}
// Construct a mapping from tir loops back to LoopRVs
Map<tir::StmtSRef, LoopRV> loop2rv;
{
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
* \param block_rv The block whose subset of loops will be tiled
* \param intrin_name The name of a tensor intrinsic, must be registerd via
* TensorIntrin.register(...) beforehand
* \param allow_padding Whether to allow padding when tiling
* \return LoopRV corresponding to the outermost loop of a
* block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
*/
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);
const String& intrin_name, bool allow_padding = false);

/******** Block mutation ********/

Expand Down
Loading