Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,20 @@ class ScheduleRule : public runtime::ObjectRef {
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
/*!
* \brief A rule that randomly selects a compute-at location for a free block
* \return The rule created
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule RandomComputeLocation();
/*!
* \brief Mark parallelize, vectorize and unroll to each block correspondingly
* \brief Mark parallelize, vectorize and unroll to the root block. The mark will be applied to
* each block in a follow-up post processor
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
* uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
* upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
* parallelism.
* \param max_vectorize_extent The maximum extent to be vectorized.
* It sets the uplimit of the CPU vectorization. Use -1 to disable vectorization.
* \param unroll_max_steps The maximum number of unroll steps to be done.
* It sets the upper limit of the hardware target vectorization. Use -1 to disable vectorization.
* \param unroll_max_steps The options of the maximum number of unroll steps to be done.
* Use an empty array to disable unroll.
* \param unroll_explicit Whether to explicitly unroll the loop, or just add a unroll pragma.
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,18 @@ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_str
constexpr const char* meta_schedule_random_compute_producer =
"meta_schedule.random_compute_producer";

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

/*! \brief Mark auto-vectorize setting on the block. */
constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

/*! \brief Mark auto-unroll setting on the block; used when explicit unrolling is requested. */
constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

/*!
 * \brief Mark auto-unroll setting on the block; used when explicit unrolling is not requested,
 * i.e. only an "unroll" pragma is wanted.
 */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .schedule_rule import PyScheduleRule, ScheduleRule
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Rule that marks the root block for parallelization, vectorization and unrolling. The marks are
applied to each block in a follow-up post processor."""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.ParallelizeVectorizeUnroll")
class ParallelizeVectorizeUnroll(ScheduleRule):
    """Schedule rule that marks the root block for parallelization, vectorization and unrolling.

    The rule only records the marks on the root block; they are applied to each block in a
    follow-up post processor.

    Parameters
    ----------
    max_jobs_per_core: int
        Upper bound on the number of jobs launched per CPU core, so that CPU parallelism is
        capped at `num_cores * max_jobs_per_core`.
        Use -1 to disable parallelism.
    max_vectorize_extent: int
        Upper bound on the extent that may be vectorized on the hardware target.
        Use -1 to disable vectorization.
    unroll_max_steps: Optional[List[int]]
        Candidate values for the maximum number of unroll steps.
        Use None to disable unroll.
    unroll_explicit: bool
        Whether to explicitly unroll the loop, or just add an "unroll" pragma.
    """

    def __init__(
        self,
        max_jobs_per_core: int = 16,
        max_vectorize_extent: int = 16,
        unroll_max_steps: Optional[List[int]] = None,
        unroll_explicit: bool = True,
    ) -> None:
        # The packed constructor is always handed a list: `None` becomes an empty list.
        steps = [] if unroll_max_steps is None else unroll_max_steps
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleParallelizeVectorizeUnroll,  # type: ignore # pylint: disable=no-member
            max_jobs_per_core,
            max_vectorize_extent,
            steps,
            unroll_explicit,
        )
28 changes: 28 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
AddRFactor,
AutoInline,
CrossThreadReduction,
ParallelizeVectorizeUnroll,
RandomComputeLocation,
ScheduleRule,
)
from tvm.target import Target
Expand Down Expand Up @@ -61,3 +63,29 @@ def cross_thread_reduction(target: Target) -> ScheduleRule:
if target.kind.name == "cuda":
return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
raise NotImplementedError(f"{target.kind.name} is not supported")


def random_compute_location(target: Target) -> ScheduleRule:
    """Return the default random-compute-location schedule rule for `target`.

    Only the "llvm" target kind is supported; any other kind raises NotImplementedError.
    """
    if target.kind.name != "llvm":
        raise NotImplementedError(f"{target.kind.name} is not supported")
    return RandomComputeLocation()


def parallel_vectorize_unroll(target: Target) -> ScheduleRule:
    """Return the default parallel-vectorize-unroll schedule rule for `target`.

    Supported target kinds are "llvm" and "cuda"; any other kind raises NotImplementedError.
    """
    # Per-target constructor arguments. For "cuda", parallelism and vectorization are
    # disabled by passing -1.
    defaults_by_kind = {
        "llvm": {
            "max_jobs_per_core": 16,
            "max_vectorize_extent": 32,
            "unroll_max_steps": [0, 16, 64, 512],
            "unroll_explicit": True,
        },
        "cuda": {
            "max_jobs_per_core": -1,
            "max_vectorize_extent": -1,
            "unroll_max_steps": [0, 16, 64, 512, 1024],
            "unroll_explicit": True,
        },
    }
    kind_name = target.kind.name
    if kind_name not in defaults_by_kind:
        raise NotImplementedError(f"{target.kind.name} is not supported")
    return ParallelizeVectorizeUnroll(**defaults_by_kind[kind_name])
129 changes: 129 additions & 0 deletions src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

bool IsRootBlock(const Schedule& sch, const BlockRV& block_rv) {
StmtSRef block_sref = sch->GetSRef(block_rv);
return block_sref->parent == nullptr;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

/*!
 * \brief Schedule rule that annotates the root block with parallelize, vectorize and unroll
 * settings. Blocks other than the root block are left untouched by this rule.
 */
class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    if (this->max_jobs_per_core == -1) {
      // Parallelism is disabled, so `max_parallel_extent_` keeps its current value.
      return;
    }
    Target target = context->target.value();
    this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core;
  }

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) {
    // Annotations are only placed on the root block; any other block is passed through.
    if (!tir::IsRootBlock(sch, root_rv)) {
      return {sch};
    }

    // Parallelization: record the parallel extent computed in `InitializeWithTuneContext`.
    if (max_jobs_per_core != -1) {
      sch->Annotate(root_rv, tir::attr::meta_schedule_parallel,
                    Integer(this->max_parallel_extent_));
    }
    // Vectorization: record the maximum vectorize extent.
    if (max_vectorize_extent != -1) {
      sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent));
    }
    // Unroll: sample one of the candidate step limits with equal probability for each.
    if (!unroll_max_steps.empty()) {
      int num_options = unroll_max_steps.size();
      double uniform_prob = 1.0 / num_options;
      Array<FloatImm> uniform_probs(num_options, FloatImm(DataType::Float(64), uniform_prob));
      PrimExpr sampled_step = sch->SampleCategorical(unroll_max_steps, uniform_probs);
      // The annotation key depends on whether explicit unrolling was requested.
      const char* unroll_key = unroll_explicit ? tir::attr::meta_schedule_unroll_explicit
                                               : tir::attr::meta_schedule_unroll_implicit;
      sch->Annotate(root_rv, unroll_key, sampled_step);
    }
    return {sch};
  }

 public:
  /*!
   * \brief Upper bound on the number of jobs launched per CPU core. CPU parallelism is capped at
   * `num_cores * max_jobs_per_core`. Use -1 to disable parallelism.
   */
  int64_t max_jobs_per_core;
  /*!
   * \brief Upper bound on the extent to be vectorized on the hardware target.
   * Use -1 to disable vectorization.
   */
  int max_vectorize_extent;
  /*!
   * \brief Candidate values for the maximum number of unroll steps.
   * Use an empty array to disable unroll.
   */
  Array<Integer> unroll_max_steps;
  /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */
  bool unroll_explicit;
  /*! \brief The number of maximum available jobs in CPU. */
  int64_t max_parallel_extent_;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_jobs_per_core", &max_jobs_per_core);
    v->Visit("max_vectorize_extent", &max_vectorize_extent);
    v->Visit("unroll_max_steps", &unroll_max_steps);
    v->Visit("unroll_explicit", &unroll_explicit);
    // `max_parallel_extent_` is derived from the tune context and is deliberately not visited
  }

  static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll";
  TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode);
};

// Factory: build a ParallelizeVectorizeUnrollNode from the user-facing options. The parallel
// extent starts at -1 and is set later by `InitializeWithTuneContext` when parallelism is enabled.
ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core,
                                                      int max_vectorize_extent,
                                                      Array<Integer> unroll_max_steps,
                                                      bool unroll_explicit) {
  auto node = make_object<ParallelizeVectorizeUnrollNode>();
  node->max_parallel_extent_ = -1;
  node->unroll_explicit = unroll_explicit;
  node->unroll_max_steps = unroll_max_steps;
  node->max_vectorize_extent = max_vectorize_extent;
  node->max_jobs_per_core = max_jobs_per_core;
  return ScheduleRule(node);
}

// Register the node type, and expose the factory as the global function
// "meta_schedule.ScheduleRuleParallelizeVectorizeUnroll".
TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll")
    .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll);

} // namespace meta_schedule
} // namespace tvm
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll
from tvm.meta_schedule.testing.space_generation import check_trace
from tvm.meta_schedule.tune_context import TuneContext
from tvm.script import tir as T
from tvm.target import Target

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks

# Workload used by the test below: a 1024x1024x1024 float32 matmul with a single block named
# "matmul" that accumulates C[vi, vj] += A[vi, vk] * B[vk, vj] after initializing C to 0.
@tvm.script.ir_module
class Matmul:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


# The same matmul, with an explicit root block carrying the "meta_schedule.parallel",
# "meta_schedule.vectorize" and "meta_schedule.unroll_explicit" annotations.
# NOTE(review): this module is not referenced by the test visible in this file, and its
# annotation values (128 / 16 / 2) differ from the values in the expected trace -- confirm
# whether it is still needed.
@tvm.script.ir_module
class ParallelizeVectorizeUnroll:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
        with T.block("root"):
            T.reads([])
            T.writes([])
            T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2})
            for i, j, k in T.grid(1024, 1024, 1024):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


def _create_context(mod, target, rule):
    """Build a TuneContext for `mod` on `target` with `rule` as its only schedule rule.

    The space generator and every schedule rule are initialized with the new context before
    it is returned.
    """
    context = TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
    context.space_generator.initialize_with_tune_context(context)
    for schedule_rule in context.sch_rules:
        schedule_rule.initialize_with_tune_context(context)
    return context


def test_parallel_vectorize_unroll():
    """Check the trace produced by the default llvm parallel-vectorize-unroll rule on Matmul."""
    workload = Matmul
    llvm_target = Target("llvm --num-cores=32")
    context = _create_context(
        mod=workload,
        target=llvm_target,
        rule=parallel_vectorize_unroll(target=llvm_target),
    )
    design_spaces = context.space_generator.generate_design_space(mod=workload)
    # The parallel extent is 32 cores * 16 jobs per core = 512.
    expected_trace = [
        'b0 = sch.get_block(name="root", func_name="main")',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)',
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)',
        "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])",
        'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
    ]
    assert len(design_spaces) == 1
    check_trace(design_spaces, [expected_trace])


# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_parallel_vectorize_unroll()