Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/meta_schedule/task_scheduler/gradient_based.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class GradientBasedNode final : public TaskSchedulerNode {
}
if (round_robin_rounds_ == n_tasks) {
for (int i = 0; i < n_tasks; ++i) {
this->JoinRunningTask(i);
if (this->tasks_[i]->runner_futures.defined()) {
this->JoinRunningTask(i);
}
}
++round_robin_rounds_;
}
Expand All @@ -92,11 +94,10 @@ class GradientBasedNode final : public TaskSchedulerNode {
for (int task_id : tasks_alive) {
const std::vector<double>& best_latency = this->best_latency_history_.at(task_id);
int n = best_latency.size();
ICHECK_GE(n, 1);
double task_weight = this->tasks_[task_id]->task_weight;
int w = this->window_size;
double best = best_latency[n - 1];
if (best < 1e9) {
if (n > 0 && best_latency[n - 1] < 1e9) {
double best = best_latency[n - 1];
double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0;
double g2 = best / n;
double g = alpha * g1 + (1 - alpha) * g2;
Expand Down Expand Up @@ -124,9 +125,11 @@ class GradientBasedNode final : public TaskSchedulerNode {
// Joins the runner futures of `task_id` and records the task's best latency.
//
// The diff paste had retained BOTH the pre-change unguarded push_back and the
// post-change guarded one; this is the resolved (post-merge) version.
//
// \param task_id Index of the task whose running measurements should be joined.
// \return The runner results collected for the task.
//
// NOTE(review): when the search strategy produced no candidates this round,
// `task->latency_ms` is empty and `std::min_element` on it would dereference
// `end()` — hence the emptiness guard before recording the best latency.
Array<RunnerResult> JoinRunningTask(int task_id) final {
  Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
  TaskRecordNode* task = this->tasks_[task_id].get();
  // Only record a "best latency" sample if this round actually measured
  // something; otherwise the history must not be polluted with garbage.
  if (task->latency_ms.size() > 0) {
    this->best_latency_history_.at(task_id).push_back(
        *std::min_element(task->latency_ms.begin(),  //
                          task->latency_ms.end()));
  }
  return results;
}
};
Expand Down
85 changes: 85 additions & 0 deletions tests/python/unittest/test_meta_schedule_task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Set

import pytest

import tvm
import tvm.testing
from tvm import meta_schedule as ms
Expand Down Expand Up @@ -352,10 +353,94 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based():
)


def test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy():
    """Gradient-based scheduler tolerates tasks whose strategy yields nothing.

    When the search strategy of one task returns an empty list of candidates
    (search not finished, nothing to measure this round) or ``None`` (search
    finished), the scheduler should continue working as normal for the other
    tasks instead of crashing on the empty measurement round.
    """

    @ms.derived_object
    class NullSearchStrategy(ms.search_strategy.PySearchStrategy):
        def __init__(self, rounds_with_empty_candidates):
            # Number of rounds for which generate_measure_candidates() will
            # return an empty list before signalling end-of-search with None.
            self.rounds_with_empty_candidates = rounds_with_empty_candidates

        def _initialize_with_tune_context(self, context: "TuneContext") -> None:
            pass

        def pre_tuning(self, *args, **kwargs):
            pass

        def post_tuning(self):
            pass

        def generate_measure_candidates(self):
            """
            Returns empty list to indicate there is no result from search, while
            the search isn't ended.
            """
            if self.rounds_with_empty_candidates:
                self.rounds_with_empty_candidates -= 1
                return []
            return None

        def notify_runner_results(self, *args, **kwargs):
            pass

        def clone(self):
            # BUG FIX: the original passed `n=self.n`, but the constructor has
            # no `n` parameter and the instance has no `n` attribute, so any
            # clone() call raised. Copy the one real piece of state instead.
            return NullSearchStrategy(
                rounds_with_empty_candidates=self.rounds_with_empty_candidates
            )

    tasks = [
        ms.TuneContext(
            MatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=_schedule_matmul,
            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=5),
            task_name="Matmul",
            rand_state=42,
        ),
        ms.TuneContext(
            BatchMatmulModule,
            target=tvm.target.Target("llvm"),
            space_generator=_schedule_batch_matmul,
            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=0),
            task_name="BatchMatmul",
            rand_state=0x114514,
        ),
        ms.TuneContext(
            MatmulReluModule,
            target=tvm.target.Target("llvm"),
            space_generator=_schedule_matmul,
            search_strategy=ms.search_strategy.ReplayTrace(),
            task_name="MatmulRelu",
            rand_state=0xDEADBEEF,
        ),
    ]
    database = ms.database.MemoryDatabase()
    gradient_based = ms.task_scheduler.GradientBased()
    gradient_based.tune(
        tasks,
        task_weights=[1.0, 1.0, 1.0],
        builder=DummyBuilder(),
        runner=DummyRunner(),
        database=database,
        measure_callbacks=[ms.measure_callback.AddToDatabase()],
        max_trials_global=30,
        max_trials_per_task=10,
        num_trials_per_iter=6,
        cost_model=None,
    )

    # Only the task with a real search strategy contributes tuning records;
    # the two null-strategy tasks must leave the database untouched.
    assert len(database) == 10
    assert len(database.get_top_k(database.commit_workload(MatmulModule), 100)) == 0
    assert len(database.get_top_k(database.commit_workload(BatchMatmulModule), 100)) == 0
    assert len(database.get_top_k(database.commit_workload(MatmulReluModule), 100)) == 10


if __name__ == "__main__":
    # Execute every scheduler test in sequence when run as a script.
    for _test_fn in (
        test_meta_schedule_task_scheduler_single,
        test_meta_schedule_task_scheduler_multiple,
        test_meta_schedule_task_scheduler_NIE,
        test_meta_schedule_task_scheduler_avoid_cyclic,
        test_meta_schedule_task_scheduler_override_next_task_id_only,
        test_meta_schedule_task_scheduler_multiple_gradient_based,
        test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy,
    ):
        _test_fn()