Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/autotvm/tuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
from .index_based_tuner import GridSearchTuner, RandomTuner
from .ga_tuner import GATuner
from .xgboost_tuner import XGBTuner
from .droplet_turner import DropletTuner
from .droplet_tuner import DropletTuner
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Tuner with droplet algorithm"""

import logging
import os
import numpy as np
from scipy import stats
from .tuner import Tuner
Expand Down Expand Up @@ -44,13 +45,15 @@ def __init__(self, task, start_position=None, pvalue=0.05):

for _, v in self.space.space_map.items():
self.dims.append(len(v))
if len(self.dims) == 0:
self.dims.append(1)

# start position
start_position = [0] * len(self.dims) if start_position is None else start_position
self.best_choice = (-1, [0] * len(self.dims), [99999])
self.visited = set([self.space.knob2point(start_position)])
self.execution, self.total_execution, self.batch = 1, max(self.dims), 16
self.pvalue, self.step = pvalue, 1
self.execution, self.total_execution, self.pvalue = 1, max(self.dims), pvalue
self.step, self.iter, self.batch = 1, 0, max(16, os.cpu_count())
self.next = [(self.space.knob2point(start_position), start_position)]

def num_to_bin(self, value, factor=1):
Expand Down Expand Up @@ -100,14 +103,15 @@ def speculation(self):
self.next += self.next_pos(self.search_space(self.execution))

def update(self, inputs, results):
found_best_pos = False
found_best_pos, count_valids = False, 0
for i, (_, res) in enumerate(zip(inputs, results)):
try:
if np.mean(self.best_choice[2]) > np.mean(res.costs) and self.p_value(
self.best_choice[2], res.costs
):
self.best_choice = (self.next[i][0], self.next[i][1], res.costs)
found_best_pos = True
count_valids += 1
except TypeError:
LOGGER.debug("Solution is not valid")
continue
Expand All @@ -119,6 +123,13 @@ def update(self, inputs, results):
self.next += self.next_pos(self.search_space())
self.execution = 1
self.speculation()
# stop, because all neighborhoods are invalid.
if count_valids == 0 and self.iter > 3:
self.next = []
LOGGER.warning(
f"Warning: early termination due to an all-invalid neighborhood \
after {self.iter} iterations"
)

def has_next(self):
return len(self.next) > 0
Expand Down