diff --git a/python/tvm/auto_scheduler/task_scheduler.py b/python/tvm/auto_scheduler/task_scheduler.py index 547e5a5833ea..58457daad0b6 100644 --- a/python/tvm/auto_scheduler/task_scheduler.py +++ b/python/tvm/auto_scheduler/task_scheduler.py @@ -358,6 +358,11 @@ def tune( self.best_ct = self.ct self.best_score = self.cur_score + # put task without schedule on warm up to dead state + for task_idx, cost in enumerate(self.best_costs): + if cost == 1e10: + self.dead_tasks.add(task_idx) + # use the specific strategy to choose workload to tune task_idx = -1 while self.ct < tune_option.num_measure_trials and len(self.dead_tasks) < len(self.tasks): @@ -367,6 +372,7 @@ def tune( task_idx = (task_idx + 1) % len(self.tasks) elif self.strategy == "gradient": gradients = [] + for i in range(len(self.tasks)): if i in self.dead_tasks: gradients.append(0)