4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/task_scheduler.h
@@ -196,8 +196,8 @@ class TaskSchedulerNode : public runtime::Object {
* \param task_id The task id to be checked.
*/
void TouchTask(int task_id);
/*! \brief Returns a human-readable string of the tuning statistics. */
std::string TuningStatistics() const;
/*! \brief Print the tuning statistics in a human-readable format. */
void PrintTuningStatistics();

static constexpr const char* _type_key = "meta_schedule.TaskScheduler";
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
24 changes: 14 additions & 10 deletions python/tvm/meta_schedule/logging.py
@@ -39,7 +39,7 @@ def get_logger(name: str) -> Logger:
return logging.getLogger(name)


def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
def get_logging_func(logger: Logger) -> Optional[Callable[[int, str, int, str], None]]:
"""Get the logging function.

Parameters
@@ -62,15 +62,15 @@ def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]:
# logging.FATAL not included
}

def logging_func(level: int, msg: str):
if level < 0:
def logging_func(level: int, filename: str, lineno: int, msg: str):
if level < 0: # clear the output in notebook / console
from IPython.display import ( # type: ignore # pylint: disable=import-outside-toplevel
clear_output,
)

clear_output(wait=True)
else:
level2log[level](msg)
level2log[level](f"[{os.path.basename(filename)}:{lineo}] " + msg)

return logging_func

@@ -94,12 +94,15 @@ def create_loggers(
global_logger_name = "tvm.meta_schedule"
global_logger = logging.getLogger(global_logger_name)
if global_logger.level is logging.NOTSET:
global_logger.setLevel(logging.INFO)
global_logger.setLevel(logging.DEBUG)
console_logging_level = logging._levelToName[ # pylint: disable=protected-access
global_logger.level
]

config["loggers"].setdefault(
global_logger_name,
{
"level": logging._levelToName[global_logger.level], # pylint: disable=protected-access
"level": logging.DEBUG,
"handlers": [handler.get_name() for handler in global_logger.handlers]
+ [global_logger_name + ".console", global_logger_name + ".file"],
"propagate": False,
@@ -108,7 +111,7 @@ def create_loggers(
config["loggers"].setdefault(
"{logger_name}",
{
"level": "INFO",
"level": "DEBUG",
"handlers": [
"{logger_name}.file",
],
@@ -121,6 +124,7 @@ def create_loggers(
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"formatter": "tvm.meta_schedule.standard_formatter",
"level": console_logging_level,
},
)
config["handlers"].setdefault(
@@ -129,7 +133,7 @@ def create_loggers(
"class": "logging.FileHandler",
"filename": "{log_dir}/" + __name__ + ".task_scheduler.log",
"mode": "a",
"level": "INFO",
"level": "DEBUG",
"formatter": "tvm.meta_schedule.standard_formatter",
},
)
@@ -139,14 +143,14 @@ def create_loggers(
"class": "logging.FileHandler",
"filename": "{log_dir}/{logger_name}.log",
"mode": "a",
"level": "INFO",
"level": "DEBUG",
"formatter": "tvm.meta_schedule.standard_formatter",
},
)
config["formatters"].setdefault(
"tvm.meta_schedule.standard_formatter",
{
"format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
"format": "%(asctime)s [%(levelname)s] %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
)
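
The change above widens the logging callback from (level, msg) to (level, filename, lineno, msg), so each record carries the originating C++ file and line, and a negative level clears the notebook/console output instead of logging. A minimal sketch of driving such a callback from Python, assuming the numeric levels mirror the standard logging module (the filename, line number, and message below are made-up values):

import logging

from tvm.meta_schedule.logging import get_logger, get_logging_func

logger = get_logger("tvm.meta_schedule")
log = get_logging_func(logger)  # may be None, per the Optional return type

if log is not None:
    # Mirrors what the C++ PyLogMessage destructor passes: (level, filename, lineno, msg).
    log(logging.INFO, "task_scheduler.cc", 242, "Task #0 has finished.")
    # A negative level clears the output instead of logging (IPython/notebook only):
    # log(-1, "task_scheduler.cc", 242, "")
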
12 changes: 3 additions & 9 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -163,15 +163,9 @@ def touch_task(self, task_id: int) -> None:
"""
_ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore # pylint: disable=no-member

def tuning_statistics(self) -> str:
"""Returns a human-readable string of the tuning statistics.

Returns
-------
tuning_statistics : str
The tuning statistics.
"""
return _ffi_api.TaskSchedulerTuningStatistics(self) # type: ignore # pylint: disable=no-member
def print_tuning_statistics(self) -> None:
"""Print out a human-readable format of the tuning statistics."""
return _ffi_api.TaskSchedulerPrintTuningStatistics(self) # type: ignore # pylint: disable=no-member

@staticmethod
def create( # pylint: disable=keyword-arg-before-vararg
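
For reference, the renamed Python method is called like the old tuning_statistics but prints the table instead of returning it. A hypothetical minimal use, assuming a gradient-based scheduler has already been created and tuned elsewhere (the "gradient" kind and the elided tuning setup are placeholders, not a complete example):

from tvm import meta_schedule as ms

# Create a scheduler of the "gradient" kind; the actual tuning (tasks, builder,
# runner, database, etc.) is assumed to have happened via scheduler.tune(...).
scheduler = ms.task_scheduler.TaskScheduler.create("gradient")
scheduler.print_tuning_statistics()  # prints the per-task table; returns None
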
34 changes: 33 additions & 1 deletion python/tvm/meta_schedule/utils.py
@@ -188,13 +188,45 @@ def cpu_count(logical: bool = True) -> int:


@register_func("meta_schedule.using_ipython")
def _using_ipython():
def _using_ipython() -> bool:
"""Return whether the current process is running in an IPython shell.

Returns
-------
result : bool
Whether the current process is running in an IPython shell.
"""
try:
return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore
except NameError:
return False


@register_func("meta_schedule.print_interactive_table")
def print_interactive_table(data: str) -> None:
"""Print the dataframe interactive table in notebook.

Parameters
----------
data : str
The serialized performance table from the MetaSchedule table printer.
"""
import pandas as pd # type: ignore # pylint: disable=import-outside-toplevel
from IPython.display import display # type: ignore # pylint: disable=import-outside-toplevel

pd.set_option("display.max_rows", None)
pd.set_option("display.max_colwidth", None)
parsed = [
x.split("|")[1:] for x in list(filter(lambda x: set(x) != {"-"}, data.strip().split("\n")))
]
display(
pd.DataFrame(
parsed[1:],
columns=parsed[0],
)
)


def get_global_func_with_default_on_worker(
name: Union[None, str, Callable],
default: Callable,
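
The new print_interactive_table helper expects the serialized table emitted by the MetaSchedule table printer: pipe-delimited rows, with separator rows made up entirely of '-' characters, which the parser drops before building a pandas DataFrame. A quick sketch with a made-up two-row table (the column names and numbers are illustrative only; pandas and IPython must be installed):

from tvm.meta_schedule.utils import print_interactive_table

sample = """
-------------------------------------------------
| ID | Name    | FLOP | Weight | Speed (GFLOPS) |
-------------------------------------------------
|  0 | fused_0 | 1024 |      1 |           12.3 |
|  1 | fused_1 | 2048 |      2 |           45.6 |
-------------------------------------------------
"""
# Renders an interactive DataFrame in a notebook; falls back to plain text elsewhere.
print_interactive_table(sample)
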
3 changes: 2 additions & 1 deletion src/meta_schedule/task_scheduler/gradient_based.cc
@@ -60,7 +60,8 @@ class GradientBasedNode final : public TaskSchedulerNode {
int n_tasks = this->tasks_.size();
// Step 1. Check if it's in round robin mode.
if (round_robin_rounds_ == 0) {
TVM_PY_LOG(INFO, this->logger) << "\n" << this->TuningStatistics();
TVM_PY_LOG_CLEAR_SCREEN(this->logger);
this->PrintTuningStatistics();
}
if (round_robin_rounds_ < n_tasks) {
return round_robin_rounds_++;
29 changes: 17 additions & 12 deletions src/meta_schedule/task_scheduler/task_scheduler.cc
@@ -232,9 +232,8 @@ Array<RunnerResult> TaskSchedulerNode::JoinRunningTask(int task_id) {
}
TaskCleanUp(task, task_id, results);
TVM_PY_LOG_CLEAR_SCREEN(this->logger);
TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name
<< "\n"
<< this->TuningStatistics();
TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name;
this->PrintTuningStatistics();
return results;
}

@@ -257,12 +256,11 @@ void TaskSchedulerNode::TerminateTask(int task_id) {
--this->remaining_tasks_;
TVM_PY_LOG_CLEAR_SCREEN(this->logger);
TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id
<< " has finished. Remaining task(s): " << this->remaining_tasks_
<< "\n"
<< this->TuningStatistics();
<< " has finished. Remaining task(s): " << this->remaining_tasks_;
this->PrintTuningStatistics();
}

std::string TaskSchedulerNode::TuningStatistics() const {
void TaskSchedulerNode::PrintTuningStatistics() {
std::ostringstream os;
int n_tasks = this->tasks_.size();
int total_trials = 0;
@@ -307,11 +305,18 @@ std::string TaskSchedulerNode::TuningStatistics() const {
}
}
p.Separator();
os << p.AsStr() //
<< "\nTotal trials: " << total_trials //

os << "\nTotal trials: " << total_trials //
<< "\nTotal latency (us): " << total_latency //
<< "\n";
return os.str();

if (using_ipython()) {
print_interactive_table(p.AsStr());
std::cout << os.str() << std::endl << std::flush;
TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str();
} else {
TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str();
}
}

TaskScheduler TaskScheduler::PyTaskScheduler(
@@ -369,8 +374,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::TerminateTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::TouchTask);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTuningStatistics")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::TuningStatistics);
TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics")
.set_body_method<TaskScheduler>(&TaskSchedulerNode::PrintTuningStatistics);

} // namespace meta_schedule
} // namespace tvm
36 changes: 22 additions & 14 deletions src/meta_schedule/utils.h
@@ -82,36 +82,32 @@ class PyLogMessage {
// FATAL not included
};

explicit PyLogMessage(const char* file, int lineno, PackedFunc logger, Level logging_level)
: file_(file), lineno_(lineno), logger_(logger), logging_level_(logging_level) {
if (this->logger_ != nullptr) {
stream_ << "" << file_ << ":" << lineno_ << " ";
}
}
explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level)
: filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {}

TVM_NO_INLINE ~PyLogMessage() {
ICHECK(logging_level_ != Level::CLEAR)
<< "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN.";
if (this->logger_ != nullptr) {
logger_(static_cast<int>(logging_level_), stream_.str());
logger_(static_cast<int>(logging_level_), std::string(filename_), lineno_, stream_.str());
} else {
if (logging_level_ == Level::INFO) {
runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
runtime::detail::LogMessage(filename_, lineno_).stream() << stream_.str();
} else if (logging_level_ == Level::WARNING) {
runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str();
runtime::detail::LogMessage(filename_, lineno_).stream() << "Warning: " << stream_.str();
} else if (logging_level_ == Level::ERROR) {
runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str();
runtime::detail::LogMessage(filename_, lineno_).stream() << "Error: " << stream_.str();
} else if (logging_level_ == Level::DEBUG) {
runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str();
runtime::detail::LogMessage(filename_, lineno_).stream() << "Debug: " << stream_.str();
} else {
runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
runtime::detail::LogFatal(filename_, lineno_).stream() << stream_.str();
}
}
}
std::ostringstream& stream() { return stream_; }

private:
const char* file_;
const char* filename_;
int lineno_;
std::ostringstream stream_;
PackedFunc logger_;
@@ -131,6 +127,18 @@ inline bool using_ipython() {
return flag;
}

/*!
* \brief Print out the performance table interactively in a Jupyter notebook.
* \param data The serialized performance table.
*/
inline void print_interactive_table(const String& data) {
const auto* f_print_interactive_table =
runtime::Registry::Get("meta_schedule.print_interactive_table");
ICHECK(f_print_interactive_table != nullptr)
<< "Cannot find print_interactive_table function in registry.";
(*f_print_interactive_table)(data);
}

/*!
* \brief A helper function to clear logging output for ipython kernel and console.
* \param file The file name.
@@ -139,7 +147,7 @@ inline bool using_ipython() {
*/
inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) {
if (logging_func.defined() && using_ipython()) {
logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), "");
logging_func(static_cast<int>(PyLogMessage::Level::CLEAR), file, lineno, "");
} else {
// this would clear all logging output in the console
runtime::detail::LogMessage(file, lineno).stream() << "\033c\033[3J\033[2J\033[0m\033[H";
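
Both C++ helpers above resolve their Python counterparts through the TVM global-function registry, where utils.py registers them as "meta_schedule.using_ipython" and "meta_schedule.print_interactive_table". A quick way to confirm the registration from Python, assuming a TVM build with these changes and that importing tvm.meta_schedule pulls in its utils module (which it does by default):

import tvm
import tvm.meta_schedule  # importing the package executes the @register_func decorators

f_ipython = tvm.get_global_func("meta_schedule.using_ipython")
f_table = tvm.get_global_func("meta_schedule.print_interactive_table")
print(bool(f_ipython()))  # False in a plain interpreter, True in a Jupyter (ZMQ) kernel
# f_table(serialized_table) would render the table; it requires pandas and IPython.
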