diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 17d82558fb82..f4fc491286dd 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -196,8 +196,8 @@ class TaskSchedulerNode : public runtime::Object { * \param task_id The task id to be checked. */ void TouchTask(int task_id); - /*! \brief Returns a human-readable string of the tuning statistics. */ - std::string TuningStatistics() const; + /*! \brief Print out a human-readable format of the tuning statistics. */ + void PrintTuningStatistics(); static constexpr const char* _type_key = "meta_schedule.TaskScheduler"; TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); diff --git a/python/tvm/meta_schedule/logging.py b/python/tvm/meta_schedule/logging.py index 9d673266a3f2..53353e3aa907 100644 --- a/python/tvm/meta_schedule/logging.py +++ b/python/tvm/meta_schedule/logging.py @@ -39,7 +39,7 @@ def get_logger(name: str) -> Logger: return logging.getLogger(name) -def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]: +def get_logging_func(logger: Logger) -> Optional[Callable[[int, str, int, str], None]]: """Get the logging function. Parameters @@ -62,15 +62,15 @@ def get_logging_func(logger: Logger) -> Optional[Callable[[int, str], None]]: # logging.FATAL not included } - def logging_func(level: int, msg: str): - if level < 0: + def logging_func(level: int, filename: str, lineo: int, msg: str): + if level < 0: # clear the output in notebook / console from IPython.display import ( # type: ignore # pylint: disable=import-outside-toplevel clear_output, ) clear_output(wait=True) else: - level2log[level](msg) + level2log[level](f"[{os.path.basename(filename)}:{lineo}] " + msg) return logging_func @@ -94,12 +94,15 @@ def create_loggers( global_logger_name = "tvm.meta_schedule" global_logger = logging.getLogger(global_logger_name) if global_logger.level is logging.NOTSET: - global_logger.setLevel(logging.INFO) + global_logger.setLevel(logging.DEBUG) + console_logging_level = logging._levelToName[ # pylint: disable=protected-access + global_logger.level + ] config["loggers"].setdefault( global_logger_name, { - "level": logging._levelToName[global_logger.level], # pylint: disable=protected-access + "level": logging.DEBUG, "handlers": [handler.get_name() for handler in global_logger.handlers] + [global_logger_name + ".console", global_logger_name + ".file"], "propagate": False, @@ -108,7 +111,7 @@ def create_loggers( config["loggers"].setdefault( "{logger_name}", { - "level": "INFO", + "level": "DEBUG", "handlers": [ "{logger_name}.file", ], @@ -121,6 +124,7 @@ def create_loggers( "class": "logging.StreamHandler", "stream": "ext://sys.stdout", "formatter": "tvm.meta_schedule.standard_formatter", + "level": console_logging_level, }, ) config["handlers"].setdefault( @@ -129,7 +133,7 @@ def create_loggers( "class": "logging.FileHandler", "filename": "{log_dir}/" + __name__ + ".task_scheduler.log", "mode": "a", - "level": "INFO", + "level": "DEBUG", "formatter": "tvm.meta_schedule.standard_formatter", }, ) @@ -139,14 +143,14 @@ def create_loggers( "class": "logging.FileHandler", "filename": "{log_dir}/{logger_name}.log", "mode": "a", - "level": "INFO", + "level": "DEBUG", "formatter": "tvm.meta_schedule.standard_formatter", }, ) config["formatters"].setdefault( "tvm.meta_schedule.standard_formatter", { - "format": "%(asctime)s.%(msecs)03d %(levelname)s %(message)s", + "format": "%(asctime)s [%(levelname)s] %(message)s", "datefmt": "%Y-%m-%d %H:%M:%S", }, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index f06f4d911fa8..d56d944474e9 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -163,15 +163,9 @@ def touch_task(self, task_id: int) -> None: """ _ffi_api.TaskSchedulerTouchTask(self, task_id) # type: ignore # pylint: disable=no-member - def tuning_statistics(self) -> str: - """Returns a human-readable string of the tuning statistics. - - Returns - ------- - tuning_statistics : str - The tuning statistics. - """ - return _ffi_api.TaskSchedulerTuningStatistics(self) # type: ignore # pylint: disable=no-member + def print_tuning_statistics(self) -> None: + """Print out a human-readable format of the tuning statistics.""" + return _ffi_api.TaskSchedulerPrintTuningStatistics(self) # type: ignore # pylint: disable=no-member @staticmethod def create( # pylint: disable=keyword-arg-before-vararg diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index eb3c6437603c..401fdab08a26 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -188,13 +188,45 @@ def cpu_count(logical: bool = True) -> int: @register_func("meta_schedule.using_ipython") -def _using_ipython(): +def _using_ipython() -> bool: + """Return whether the current process is running in an IPython shell. + + Returns + ------- + result : bool + Whether the current process is running in an IPython shell. + """ try: return get_ipython().__class__.__name__ == "ZMQInteractiveShell" # type: ignore except NameError: return False +@register_func("meta_schedule.print_interactive_table") +def print_interactive_table(data: str) -> None: + """Print the dataframe interactive table in notebook. + + Parameters + ---------- + data : str + The serialized performance table from MetaSchedule table printer. + """ + import pandas as pd # type: ignore # pylint: disable=import-outside-toplevel + from IPython.display import display # type: ignore # pylint: disable=import-outside-toplevel + + pd.set_option("display.max_rows", None) + pd.set_option("display.max_colwidth", None) + parsed = [ + x.split("|")[1:] for x in list(filter(lambda x: set(x) != {"-"}, data.strip().split("\n"))) + ] + display( + pd.DataFrame( + parsed[1:], + columns=parsed[0], + ) + ) + + def get_global_func_with_default_on_worker( name: Union[None, str, Callable], default: Callable, diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index bae52573a0f9..e0470337b536 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -60,7 +60,8 @@ class GradientBasedNode final : public TaskSchedulerNode { int n_tasks = this->tasks_.size(); // Step 1. Check if it's in round robin mode. if (round_robin_rounds_ == 0) { - TVM_PY_LOG(INFO, this->logger) << "\n" << this->TuningStatistics(); + TVM_PY_LOG_CLEAR_SCREEN(this->logger); + this->PrintTuningStatistics(); } if (round_robin_rounds_ < n_tasks) { return round_robin_rounds_++; diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 21efde26d993..69a70f63c5c0 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -232,9 +232,8 @@ Array TaskSchedulerNode::JoinRunningTask(int task_id) { } TaskCleanUp(task, task_id, results); TVM_PY_LOG_CLEAR_SCREEN(this->logger); - TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name - << "\n" - << this->TuningStatistics(); + TVM_PY_LOG(INFO, this->logger) << "[Updated] Task #" << task_id << ": " << task->ctx->task_name; + this->PrintTuningStatistics(); return results; } @@ -257,12 +256,11 @@ void TaskSchedulerNode::TerminateTask(int task_id) { --this->remaining_tasks_; TVM_PY_LOG_CLEAR_SCREEN(this->logger); TVM_PY_LOG(INFO, this->logger) << "Task #" << task_id - << " has finished. Remaining task(s): " << this->remaining_tasks_ - << "\n" - << this->TuningStatistics(); + << " has finished. Remaining task(s): " << this->remaining_tasks_; + this->PrintTuningStatistics(); } -std::string TaskSchedulerNode::TuningStatistics() const { +void TaskSchedulerNode::PrintTuningStatistics() { std::ostringstream os; int n_tasks = this->tasks_.size(); int total_trials = 0; @@ -307,11 +305,18 @@ std::string TaskSchedulerNode::TuningStatistics() const { } } p.Separator(); - os << p.AsStr() // - << "\nTotal trials: " << total_trials // + + os << "\nTotal trials: " << total_trials // << "\nTotal latency (us): " << total_latency // << "\n"; - return os.str(); + + if (using_ipython()) { + print_interactive_table(p.AsStr()); + std::cout << os.str() << std::endl << std::flush; + TVM_PY_LOG(DEBUG, this->logger) << "\n" << p.AsStr() << os.str(); + } else { + TVM_PY_LOG(INFO, this->logger) << "\n" << p.AsStr() << os.str(); + } } TaskScheduler TaskScheduler::PyTaskScheduler( @@ -369,8 +374,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") .set_body_method(&TaskSchedulerNode::TerminateTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") .set_body_method(&TaskSchedulerNode::TouchTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTuningStatistics") - .set_body_method(&TaskSchedulerNode::TuningStatistics); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") + .set_body_method(&TaskSchedulerNode::PrintTuningStatistics); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 41d8ffde558c..b14717f4b29e 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -82,36 +82,32 @@ class PyLogMessage { // FATAL not included }; - explicit PyLogMessage(const char* file, int lineno, PackedFunc logger, Level logging_level) - : file_(file), lineno_(lineno), logger_(logger), logging_level_(logging_level) { - if (this->logger_ != nullptr) { - stream_ << "" << file_ << ":" << lineno_ << " "; - } - } + explicit PyLogMessage(const char* filename, int lineno, PackedFunc logger, Level logging_level) + : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {} TVM_NO_INLINE ~PyLogMessage() { ICHECK(logging_level_ != Level::CLEAR) << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN."; if (this->logger_ != nullptr) { - logger_(static_cast(logging_level_), stream_.str()); + logger_(static_cast(logging_level_), std::string(filename_), lineno_, stream_.str()); } else { if (logging_level_ == Level::INFO) { - runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str(); + runtime::detail::LogMessage(filename_, lineno_).stream() << stream_.str(); } else if (logging_level_ == Level::WARNING) { - runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str(); + runtime::detail::LogMessage(filename_, lineno_).stream() << "Warning: " << stream_.str(); } else if (logging_level_ == Level::ERROR) { - runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str(); + runtime::detail::LogMessage(filename_, lineno_).stream() << "Error: " << stream_.str(); } else if (logging_level_ == Level::DEBUG) { - runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str(); + runtime::detail::LogMessage(filename_, lineno_).stream() << "Debug: " << stream_.str(); } else { - runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str(); + runtime::detail::LogFatal(filename_, lineno_).stream() << stream_.str(); } } } std::ostringstream& stream() { return stream_; } private: - const char* file_; + const char* filename_; int lineno_; std::ostringstream stream_; PackedFunc logger_; @@ -131,6 +127,18 @@ inline bool using_ipython() { return flag; } +/*! + * \brief Print out the performance table interactively in jupyter notebook. + * \param str The serialized performance table. + */ +inline void print_interactive_table(const String& data) { + const auto* f_print_interactive_table = + runtime::Registry::Get("meta_schedule.print_interactive_table"); + ICHECK(f_print_interactive_table->defined()) + << "Cannot find print_interactive_table function in registry."; + (*f_print_interactive_table)(data); +} + /*! * \brief A helper function to clear logging output for ipython kernel and console. * \param file The file name. @@ -139,7 +147,7 @@ inline bool using_ipython() { */ inline void clear_logging(const char* file, int lineno, PackedFunc logging_func) { if (logging_func.defined() && using_ipython()) { - logging_func(static_cast(PyLogMessage::Level::CLEAR), ""); + logging_func(static_cast(PyLogMessage::Level::CLEAR), file, lineno, ""); } else { // this would clear all logging output in the console runtime::detail::LogMessage(file, lineno).stream() << "\033c\033[3J\033[2J\033[0m\033[H";