diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index e7846c0680d7..dff542911139 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" +import os from typing import Dict, List, Optional, Sequence from tvm._ffi import get_global_func, register_object @@ -199,7 +200,7 @@ def script( def show( self, style: Optional[str] = None, - black_format: bool = False, + black_format: Optional[bool] = None, *, name: Optional[str] = None, show_meta: bool = False, @@ -226,8 +227,26 @@ def show( style : str, optional Pygmentize printing style, auto-detected if None. See `tvm.script.highlight.cprint` for more details. - black_format: bool - If true, use the formatter Black to format the TVMScript + + black_format: Optional[bool] + + If true, use the formatter Black to format the TVMScript. + If false, do not apply the auto-formatter. + + If None (default), determine the behavior based on the + environment variable "TVM_BLACK_FORMAT". If this + environment variable is unset, set to the empty string, or + set to the integer zero, black auto-formatting will be + disabled. If the environment variable is set to a + non-zero integer, black auto-formatting will be enabled. + + Note that the "TVM_BLACK_FORMAT" environment variable only + applies to the `.show()` method, and not the underlying + `.script()` method. The `.show()` method is intended for + human-readable output based on individual user + preferences, while the `.script()` method is intended to + provided a consistent output regardless of environment. + name : Optional[str] = None The name of the object show_meta : bool = False @@ -263,11 +282,16 @@ def show( Object to be underlined obj_to_annotate : Optional[Dict[Object, str]] = None Object to be annotated + """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = env and int(env) + cprint( self.script( name=name, diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 85d1f19bba00..5df3a486ca2e 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" +import inspect from typing import Callable, Dict, List, Optional, Tuple, Union from tvm._ffi import register_object as _register_object @@ -268,26 +269,24 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member - def show(self, style: Optional[str] = None, black_format: bool = False) -> None: + def show(self, *args, **kwargs) -> None: """A sugar for print highlighted TVM script. - Parameters - ---------- - style : str, optional - - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - - black_format: bool - - If true, use the formatter Black to format the TVMScript + All parameters are forwarded to the underlying `Module.show` + and `Trace.show` methods. """ mod = self.mod if mod is not None: - mod.show(style=style, black_format=black_format) + mod.show(*args, **kwargs) + trace = self.trace if trace is not None: - trace.show(style=style, black_format=black_format) + # Trace.show only supports the style and black_format arguments + param_binding = inspect.signature(mod.show).bind(*args, **kwargs) + param_binding.apply_defaults() + bound_args = param_binding.arguments + + trace.show(style=bound_args["style"], black_format=bound_args["black_format"]) ########## Lookup ########## diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index e3173049931d..cb8d5ce9973e 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """An execution trace of a scheduling program""" +import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from tvm._ffi import register_object as _register_object @@ -274,10 +275,16 @@ def show(self, style: Optional[str] = None, black_format: bool = False) -> None: black_format: bool - If true, use the formatter Black to format the TVMScript + If true, use the formatter Black to format the TVMScript. + If None, determine based on the "TVM_BLACK_FORMAT" environment + variable. """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = bool(env and int(env)) + cprint(str(self), style=style, black_format=black_format)