From f8a77140bf5eca38602b5e4c68c1a2b8b599028a Mon Sep 17 00:00:00 2001
From: Lite Ye
Date: Tue, 15 Nov 2022 23:52:58 -0500
Subject: [PATCH] Fix segfault in gradient based scheduler

The gradient-based scheduler would segfault if the search strategy could
not return any candidates for some workload. It is expected to keep
tuning the other workloads that do have candidates.
---
 .../task_scheduler/gradient_based.cc          | 17 ++--
 .../test_meta_schedule_task_scheduler.py      | 85 +++++++++++++++++++
 2 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc
index e0470337b536..5b261eec32a4 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -68,7 +68,9 @@ class GradientBasedNode final : public TaskSchedulerNode {
     }
     if (round_robin_rounds_ == n_tasks) {
       for (int i = 0; i < n_tasks; ++i) {
-        this->JoinRunningTask(i);
+        if (this->tasks_[i]->runner_futures.defined()) {
+          this->JoinRunningTask(i);
+        }
       }
       ++round_robin_rounds_;
     }
@@ -92,11 +94,10 @@ class GradientBasedNode final : public TaskSchedulerNode {
     for (int task_id : tasks_alive) {
       const std::vector<double>& best_latency = this->best_latency_history_.at(task_id);
       int n = best_latency.size();
-      ICHECK_GE(n, 1);
       double task_weight = this->tasks_[task_id]->task_weight;
       int w = this->window_size;
-      double best = best_latency[n - 1];
-      if (best < 1e9) {
+      if (n > 0 && best_latency[n - 1] < 1e9) {
+        double best = best_latency[n - 1];
         double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0;
         double g2 = best / n;
         double g = alpha * g1 + (1 - alpha) * g2;
@@ -124,9 +125,11 @@ class GradientBasedNode final : public TaskSchedulerNode {
   Array<RunnerResult> JoinRunningTask(int task_id) final {
     Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
     TaskRecordNode* task = this->tasks_[task_id].get();
-    this->best_latency_history_.at(task_id).push_back(
-        *std::min_element(task->latency_ms.begin(),  //
-                          task->latency_ms.end()));
+    if (task->latency_ms.size() > 0) {
+      this->best_latency_history_.at(task_id).push_back(
+          *std::min_element(task->latency_ms.begin(),  //
+                            task->latency_ms.end()));
+    }
     return results;
   }
 };
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 33a019e3c555..ab0e3f0123dd 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -20,6 +20,7 @@
 from typing import Set
 
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -352,6 +353,89 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based():
     )
 
 
+def test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy():
+    """
+    When the search strategy of one task returns an empty list of candidates or
+    None, the scheduler should keep working as normal on the other tasks.
+    """
+
+    @ms.derived_object
+    class NullSearchStrategy(ms.search_strategy.PySearchStrategy):
+        def __init__(self, rounds_with_empty_candidates):
+            self.rounds_with_empty_candidates = rounds_with_empty_candidates
+
+        def _initialize_with_tune_context(self, context: "TuneContext") -> None:
+            pass
+
+        def pre_tuning(self, *args, **kwargs):
+            pass
+
+        def post_tuning(self):
+            pass
+
+        def generate_measure_candidates(self):
+            """
+            Return an empty list to indicate that the search produced no
+            candidates this round but has not ended yet.
+            """
+            if self.rounds_with_empty_candidates:
+                self.rounds_with_empty_candidates -= 1
+                return []
+            return None
+
+        def notify_runner_results(self, *args, **kwargs):
+            pass
+
+        def clone(self):
+            return NullSearchStrategy(self.rounds_with_empty_candidates)
+
+    tasks = [
+        ms.TuneContext(
+            MatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_matmul,
+            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=5),
+            task_name="Matmul",
+            rand_state=42,
+        ),
+        ms.TuneContext(
+            BatchMatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_batch_matmul,
+            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=0),
+            task_name="BatchMatmul",
+            rand_state=0x114514,
+        ),
+        ms.TuneContext(
+            MatmulReluModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_matmul,
+            search_strategy=ms.search_strategy.ReplayTrace(),
+            task_name="MatmulRelu",
+            rand_state=0xDEADBEEF,
+        ),
+    ]
+    database = ms.database.MemoryDatabase()
+    gradient_based = ms.task_scheduler.GradientBased()
+    gradient_based.tune(
+        tasks,
+        task_weights=[1.0, 1.0, 1.0],
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
+        measure_callbacks=[ms.measure_callback.AddToDatabase()],
+        max_trials_global=30,
+        max_trials_per_task=10,
+        num_trials_per_iter=6,
+        cost_model=None,
+    )
+
+    assert len(database) == 10
+    assert len(database.get_top_k(database.commit_workload(MatmulModule), 100)) == 0
+    assert len(database.get_top_k(database.commit_workload(BatchMatmulModule), 100)) == 0
+    assert len(database.get_top_k(database.commit_workload(MatmulReluModule), 100)) == 10
+
+
 if __name__ == "__main__":
     test_meta_schedule_task_scheduler_single()
     test_meta_schedule_task_scheduler_multiple()
@@ -359,3 +443,4 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based():
     test_meta_schedule_task_scheduler_avoid_cyclic()
     test_meta_schedule_task_scheduler_override_next_task_id_only()
     test_meta_schedule_task_scheduler_multiple_gradient_based()
+    test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy()
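Why the guards in gradient_based.cc matter, shown as a minimal standalone C++ sketch (not taken from the scheduler itself; a plain std::vector<double> stands in for a task's best-latency history): when a task never yields measurement candidates, its history stays empty, and both indexing element n - 1 and dereferencing std::min_element over the empty range are undefined behavior, which is the crash this patch prevents.

// Standalone sketch: the vector plays the role of one task's latency history.
#include <algorithm>
#include <vector>

int main() {
  std::vector<double> best_latency;                 // task produced no candidates
  int n = static_cast<int>(best_latency.size());    // n == 0

  // Unguarded (old) pattern, undefined behavior when n == 0:
  //   double best = best_latency[n - 1];                       // out-of-bounds read
  //   double m = *std::min_element(best_latency.begin(),
  //                                best_latency.end());         // dereferences end()

  // Guarded (new) pattern: skip the task until it has at least one sample.
  if (n > 0 && best_latency[n - 1] < 1e9) {
    double best = best_latency[n - 1];
    (void)best;  // gradient computation would go here
  }
  if (!best_latency.empty()) {
    double best_ms = *std::min_element(best_latency.begin(), best_latency.end());
    (void)best_ms;  // only recorded when there is something to record
  }
  return 0;
}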