diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 06537e2cdc4d..3ed7e57cb758 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -278,18 +278,24 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: self, tir_prefix, show_meta ) # type: ignore - def show(self, style: Optional[str] = None) -> None: - """ - A sugar for print highlighted TVM script. + def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + """A sugar for print highlighted TVM script. + Parameters ---------- style : str, optional - Pygments styles extended by "light" (default) and "dark", by default "light" + + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript """ from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style) + cprint(self, style=style, black_format=black_format) def get_attr(self, attr_key): """Get the IRModule attribute. diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py index dc45b5a3f1cd..d12f6c276767 100644 --- a/python/tvm/script/highlight.py +++ b/python/tvm/script/highlight.py @@ -17,6 +17,7 @@ """Highlight printed TVM script. """ +import os import sys import warnings from typing import Optional, Union @@ -25,17 +26,30 @@ from tvm.tir import PrimFunc -def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = None) -> None: - """ - Print highlighted TVM script string with Pygments +def cprint( + printable: Union[IRModule, PrimFunc, str], + style: Optional[str] = None, + black_format: bool = True, +) -> None: + """Print TVMScript string with Pygments highlight and Black auto-formatting. + Parameters ---------- printable : Union[IRModule, PrimFunc, str] - The TVM script to be printed + + The TVMScript to be printed + style : str, optional - Printing style, auto-detected if None. + + Pygmentize printing style, auto-detected if None. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript + Notes ----- + The style parameter follows the Pygments style names or Style objects. Three built-in styles are extended: "light", "dark" and "ansi". By default, "light" will be used for notebook environment and terminal style will be "ansi" for @@ -43,16 +57,103 @@ def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = Non not installed, plain text will be printed with a one-time warning to suggest installing the Pygment library. Other Pygment styles can be found in https://pygments.org/styles/ + + The default pygmentize style can also be set with the environment + variable "TVM_PYGMENTIZE_STYLE". """ if isinstance(printable, (IRModule, PrimFunc)): printable = printable.script() + + if black_format: + printable = _format(printable) + + is_in_notebook = "ipykernel" in sys.modules # in notebook env (support html display). + + style = _get_pygments_style(style, is_in_notebook) + + if style is None: + print(printable) + return + + # pylint: disable=import-outside-toplevel + from pygments import highlight + from pygments.formatters import HtmlFormatter, Terminal256Formatter + from pygments.lexers.python import Python3Lexer + + if is_in_notebook: + from IPython import display # pylint: disable=import-outside-toplevel + + formatter = HtmlFormatter(style=style) + formatter.noclasses = True # inline styles + html = highlight(printable, Python3Lexer(), formatter) + display.display(display.HTML(html)) + else: + print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style))) + + +def _format(code_str: str) -> str: + """Format a code string using Black. + + Parameters + ---------- + code_str: str + + The string containing Python/TVMScript code to format + + Returns + ------- + formatted: str + + The formatted Python/TVMScript code + """ + try: + # pylint: disable=import-outside-toplevel + import black + except ImportError as err: + with warnings.catch_warnings(): + warnings.simplefilter("once", UserWarning) + install_cmd = sys.executable + ' -m pip install "black==22.3.0" --upgrade --user' + warnings.warn( + str(err) + + "\n" + + "To print formatted TVM script, please install the formatter 'Black':\n" + + install_cmd, + category=UserWarning, + ) + return code_str + else: + return black.format_str(code_str, mode=black.FileMode()) + + +def _get_pygments_style( + style: Optional[str], is_in_notebook: bool +) -> Optional[Union["pygments.style.Style", str]]: + """Select a pygments style to use + + Parameters + ---------- + style: str + + The style specifier to use. If None, auto-select a style. + + is_in_notebook: bool + + Whether python is currently running in a jupyter notebook. + Used for automatic selection. + + Returns + ------- + style: Optional[Union['pygments.style.Style',str]] + + If pygments is installed, the style object or string, suitable + for use as the "style" argument to pygments formatters. If + pygments is not installed, returns None. + + """ try: # pylint: disable=import-outside-toplevel import pygments from packaging import version - from pygments import highlight - from pygments.formatters import HtmlFormatter, Terminal256Formatter - from pygments.lexers.python import Python3Lexer from pygments.style import Style from pygments.token import Comment, Keyword, Name, Number, Operator, String @@ -69,82 +170,75 @@ def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = Non + install_cmd, category=UserWarning, ) - print(printable) - else: + return None - class JupyterLight(Style): - """A Jupyter-Notebook-like Pygments style configuration (aka. "light")""" - - background_color = "" - styles = { - Keyword: "bold #008000", - Keyword.Type: "nobold #008000", - Name.Function: "#0000FF", - Name.Class: "bold #0000FF", - Name.Decorator: "#AA22FF", - String: "#BA2121", - Number: "#008000", - Operator: "bold #AA22FF", - Operator.Word: "bold #008000", - Comment: "italic #007979", - } - - class VSCDark(Style): - """A VSCode-Dark-like Pygments style configuration (aka. "dark")""" - - background_color = "" - styles = { - Keyword: "bold #c586c0", - Keyword.Type: "#82aaff", - Keyword.Namespace: "#4ec9b0", - Name.Class: "bold #569cd6", - Name.Function: "bold #dcdcaa", - Name.Decorator: "italic #fe4ef3", - String: "#ce9178", - Number: "#b5cea8", - Operator: "#bbbbbb", - Operator.Word: "#569cd6", - Comment: "italic #6a9956", - } - - class AnsiTerminalDefault(Style): - """The default style for terminal display with ANSI colors (aka. "ansi")""" - - background_color = "" - styles = { - Keyword: "bold ansigreen", - Keyword.Type: "nobold ansigreen", - Name.Class: "bold ansiblue", - Name.Function: "bold ansiblue", - Name.Decorator: "italic ansibrightmagenta", - String: "ansiyellow", - Number: "ansibrightgreen", - Operator: "bold ansimagenta", - Operator.Word: "bold ansigreen", - Comment: "italic ansibrightblack", - } - - is_in_notebook = "ipykernel" in sys.modules # in notebook env (support html display). - - if style is None: - # choose style automatically according to the environment: - style = JupyterLight if is_in_notebook else AnsiTerminalDefault - elif style == "light": - style = JupyterLight - elif style == "dark": - style = VSCDark - elif style == "ansi": - style = AnsiTerminalDefault - - if is_in_notebook: # print with HTML display - from IPython.display import ( # pylint: disable=import-outside-toplevel - HTML, - display, - ) + class JupyterLight(Style): + """A Jupyter-Notebook-like Pygments style configuration (aka. "light")""" + + background_color = "" + styles = { + Keyword: "bold #008000", + Keyword.Type: "nobold #008000", + Name.Function: "#0000FF", + Name.Class: "bold #0000FF", + Name.Decorator: "#AA22FF", + String: "#BA2121", + Number: "#008000", + Operator: "bold #AA22FF", + Operator.Word: "bold #008000", + Comment: "italic #007979", + } + + class VSCDark(Style): + """A VSCode-Dark-like Pygments style configuration (aka. "dark")""" + + background_color = "" + styles = { + Keyword: "bold #c586c0", + Keyword.Type: "#82aaff", + Keyword.Namespace: "#4ec9b0", + Name.Class: "bold #569cd6", + Name.Function: "bold #dcdcaa", + Name.Decorator: "italic #fe4ef3", + String: "#ce9178", + Number: "#b5cea8", + Operator: "#bbbbbb", + Operator.Word: "#569cd6", + Comment: "italic #6a9956", + } + + class AnsiTerminalDefault(Style): + """The default style for terminal display with ANSI colors (aka. "ansi")""" + + background_color = "" + styles = { + Keyword: "bold ansigreen", + Keyword.Type: "nobold ansigreen", + Name.Class: "bold ansiblue", + Name.Function: "bold ansiblue", + Name.Decorator: "italic ansibrightmagenta", + String: "ansiyellow", + Number: "ansibrightgreen", + Operator: "bold ansimagenta", + Operator.Word: "bold ansigreen", + Comment: "italic ansibrightblack", + } + + if style == "light": + return JupyterLight + elif style == "dark": + return VSCDark + elif style == "ansi": + return AnsiTerminalDefault + + if style is not None: + return style + + style_from_environment = os.environ.get("TVM_PYGMENTIZE_STYLE", "").strip() + if style_from_environment: + return style_from_environment + + if is_in_notebook: + return JupyterLight - formatter = HtmlFormatter(style=JupyterLight) - formatter.noclasses = True # inline styles - html = highlight(printable, Python3Lexer(), formatter) - display(HTML(html)) - else: - print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style))) + return AnsiTerminalDefault diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index c5cc922a3e48..082faeb456d3 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -189,18 +189,24 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: self, tir_prefix, show_meta ) # type: ignore - def show(self, style: Optional[str] = None) -> None: - """ - A sugar for print highlighted TVM script. + def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + """A sugar for print highlighted TVM script. + Parameters ---------- style : str, optional - Pygments styles extended by "light" (default) and "dark", by default "light" + + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript """ from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel # Use deferred import to avoid circular import while keeping cprint under tvm/script - cprint(self, style=style) + cprint(self, style=style, black_format=black_format) @tvm._ffi.register_object("tir.TensorIntrin") diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index da599081df3b..99e48debfefd 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -259,16 +259,22 @@ def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None: """ _ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member - def show(self, style: Optional[str] = None) -> None: - """A sugar for print highlighted trace. + def show(self, style: Optional[str] = None, black_format: bool = True) -> None: + """A sugar for print highlighted TVM script. Parameters ---------- style : str, optional - Pygments styles extended by "light" (default) and "dark", by default "light" + + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + + black_format: bool + + If true (default), use the formatter Black to format the TVMScript """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) - cprint(str(self), style=style) + cprint(str(self), style=style, black_format=black_format)