diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc
index e0470337b536..5b261eec32a4 100644
--- a/src/meta_schedule/task_scheduler/gradient_based.cc
+++ b/src/meta_schedule/task_scheduler/gradient_based.cc
@@ -68,7 +68,9 @@ class GradientBasedNode final : public TaskSchedulerNode {
     }
     if (round_robin_rounds_ == n_tasks) {
       for (int i = 0; i < n_tasks; ++i) {
-        this->JoinRunningTask(i);
+        if (this->tasks_[i]->runner_futures.defined()) {
+          this->JoinRunningTask(i);
+        }
       }
       ++round_robin_rounds_;
     }
@@ -92,11 +94,10 @@ class GradientBasedNode final : public TaskSchedulerNode {
       for (int task_id : tasks_alive) {
         const std::vector<double>& best_latency = this->best_latency_history_.at(task_id);
         int n = best_latency.size();
-        ICHECK_GE(n, 1);
         double task_weight = this->tasks_[task_id]->task_weight;
         int w = this->window_size;
-        double best = best_latency[n - 1];
-        if (best < 1e9) {
+        if (n > 0 && best_latency[n - 1] < 1e9) {
+          double best = best_latency[n - 1];
           double g1 = (n >= 1 + w) ? (best_latency[n - 1 - w] - best) / w : 0.0;
           double g2 = best / n;
           double g = alpha * g1 + (1 - alpha) * g2;
@@ -124,9 +125,11 @@ class GradientBasedNode final : public TaskSchedulerNode {
   Array<RunnerResult> JoinRunningTask(int task_id) final {
     Array<RunnerResult> results = TaskSchedulerNode::JoinRunningTask(task_id);
     TaskRecordNode* task = this->tasks_[task_id].get();
-    this->best_latency_history_.at(task_id).push_back(
-        *std::min_element(task->latency_ms.begin(),  //
-                          task->latency_ms.end()));
+    if (task->latency_ms.size() > 0) {
+      this->best_latency_history_.at(task_id).push_back(
+          *std::min_element(task->latency_ms.begin(),  //
+                            task->latency_ms.end()));
+    }
     return results;
   }
 };
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 33a019e3c555..ab0e3f0123dd 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -20,6 +20,7 @@
 from typing import Set
 
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -352,6 +353,89 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based():
     )
 
 
+def test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy():
+    """
+    When the search strategy of one task returns an empty list of candidates or
+    None, the scheduler should continue working as normal for the other tasks.
+    """
+
+    @ms.derived_object
+    class NullSearchStrategy(ms.search_strategy.PySearchStrategy):
+        def __init__(self, rounds_with_empty_candidates):
+            self.rounds_with_empty_candidates = rounds_with_empty_candidates
+
+        def _initialize_with_tune_context(self, context: "TuneContext") -> None:
+            pass
+
+        def pre_tuning(self, *args, **kwargs):
+            pass
+
+        def post_tuning(self):
+            pass
+
+        def generate_measure_candidates(self):
+            """
+            Returns an empty list to indicate that this round produced no
+            candidates, while the search has not ended yet.
+            """
+            if self.rounds_with_empty_candidates:
+                self.rounds_with_empty_candidates -= 1
+                return []
+            return None
+
+        def notify_runner_results(self, *args, **kwargs):
+            pass
+
+        def clone(self):
+            return NullSearchStrategy(self.rounds_with_empty_candidates)
+
+    tasks = [
+        ms.TuneContext(
+            MatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_matmul,
+            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=5),
+            task_name="Matmul",
+            rand_state=42,
+        ),
+        ms.TuneContext(
+            BatchMatmulModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_batch_matmul,
+            search_strategy=NullSearchStrategy(rounds_with_empty_candidates=0),
+            task_name="BatchMatmul",
+            rand_state=0x114514,
+        ),
+        ms.TuneContext(
+            MatmulReluModule,
+            target=tvm.target.Target("llvm"),
+            space_generator=_schedule_matmul,
+            search_strategy=ms.search_strategy.ReplayTrace(),
+            task_name="MatmulRelu",
+            rand_state=0xDEADBEEF,
+        ),
+    ]
+    database = ms.database.MemoryDatabase()
+    gradient_based = ms.task_scheduler.GradientBased()
+    gradient_based.tune(
+        tasks,
+        task_weights=[1.0, 1.0, 1.0],
+        builder=DummyBuilder(),
+        runner=DummyRunner(),
+        database=database,
+        measure_callbacks=[ms.measure_callback.AddToDatabase()],
+        max_trials_global=30,
+        max_trials_per_task=10,
+        num_trials_per_iter=6,
+        cost_model=None,
+    )
+
+    assert len(database) == 10
+    assert len(database.get_top_k(database.commit_workload(MatmulModule), 100)) == 0
+    assert len(database.get_top_k(database.commit_workload(BatchMatmulModule), 100)) == 0
+    assert len(database.get_top_k(database.commit_workload(MatmulReluModule), 100)) == 10
+
+
 if __name__ == "__main__":
     test_meta_schedule_task_scheduler_single()
     test_meta_schedule_task_scheduler_multiple()
@@ -359,3 +443,4 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based():
     test_meta_schedule_task_scheduler_avoid_cyclic()
     test_meta_schedule_task_scheduler_override_next_task_id_only()
     test_meta_schedule_task_scheduler_multiple_gradient_based()
+    test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy()