From 260e96dc402b047f9634ef50d1137e10fe7a28b2 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 3 Oct 2022 18:07:58 -0700 Subject: [PATCH 1/5] [TVMScript] AST, Source and diagnostics for Parser This PR introduces AST, Source and diagnostics for Parser Co-authored-by: yongwww --- python/tvm/script/_parser/__init__.py | 18 + python/tvm/script/_parser/_core.py | 19 + python/tvm/script/_parser/core/__init__.py | 18 + python/tvm/script/_parser/core/diagnostics.py | 211 ++++++++++ python/tvm/script/_parser/core/doc.py | 361 ++++++++++++++++++ .../{printer => _parser/core}/doc_core.py | 0 python/tvm/script/_parser/core/utils.py | 60 +++ .../unittest/test_tvmscript_parser_source.py | 86 +++++ 8 files changed, 773 insertions(+) create mode 100644 python/tvm/script/_parser/__init__.py create mode 100644 python/tvm/script/_parser/_core.py create mode 100644 python/tvm/script/_parser/core/__init__.py create mode 100644 python/tvm/script/_parser/core/diagnostics.py create mode 100644 python/tvm/script/_parser/core/doc.py rename python/tvm/script/{printer => _parser/core}/doc_core.py (100%) create mode 100644 python/tvm/script/_parser/core/utils.py create mode 100644 tests/python/unittest/test_tvmscript_parser_source.py diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/__init__.py new file mode 100644 index 000000000000..d885b405257b --- /dev/null +++ b/python/tvm/script/_parser/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The parser""" +from . import _core diff --git a/python/tvm/script/_parser/_core.py b/python/tvm/script/_parser/_core.py new file mode 100644 index 000000000000..a2dcc5b531dc --- /dev/null +++ b/python/tvm/script/_parser/_core.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The core parser infra""" +# pylint: disable=unused-import +from .core import doc, utils diff --git a/python/tvm/script/_parser/core/__init__.py b/python/tvm/script/_parser/core/__init__.py new file mode 100644 index 000000000000..ae1521006d9b --- /dev/null +++ b/python/tvm/script/_parser/core/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The core parser infra""" +from . import diagnostics, doc, doc_core, utils diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/_parser/core/diagnostics.py new file mode 100644 index 000000000000..26a6c47f9df3 --- /dev/null +++ b/python/tvm/script/_parser/core/diagnostics.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import inspect +import re +import sys +from typing import Union + +from tvm.ir import IRModule, SourceName, Span, diagnostics + +from . 
import doc + + +class Source: + """Source code class for TVMScript.""" + + source_name: str + start_line: int + start_column: int + source: str + full_source: str + + def __init__(self, program: Union[str, doc.AST]): + if isinstance(program, str): + self.source_name = "" + self.start_line = 1 + self.start_column = 0 + self.source = program + self.full_source = program + return + + self.source_name = inspect.getsourcefile(program) # type: ignore + lines, self.start_line = getsourcelines(program) # type: ignore + if lines: + self.start_column = len(lines[0]) - len(lines[0].lstrip()) + else: + self.start_column = 0 + if self.start_column and lines: + self.source = "\n".join([l[self.start_column :].rstrip() for l in lines]) + else: + self.source = "".join(lines) + try: + # It will cause a problem when running in Jupyter Notebook. + # `mod` will be <module '__main__'>, which is a built-in module + # and `getsource` will throw a TypeError + mod = inspect.getmodule(program) + if mod: + self.full_source = inspect.getsource(mod) + else: + self.full_source = self.source + except TypeError: + # It's a workaround for the Jupyter problem. + # Since `findsource` is an internal API of inspect, we just use it + # as a fallback method. + src, _ = inspect.findsource(program) # type: ignore + self.full_source = "".join(src) + + def as_ast(self) -> doc.AST: + """Parse the source code into AST. + + Returns + ------- + res : doc.AST + The AST of source code. 
+ """ + return doc.parse(self.source) + + +_getfile = inspect.getfile # pylint: disable=invalid-name +_findsource = inspect.findsource # pylint: disable=invalid-name + + +def _patched_inspect_getfile(obj): + """Work out which source or compiled file an object was defined in.""" + if not inspect.isclass(obj): + return _getfile(obj) + mod = getattr(obj, "__module__", None) + if mod is not None: + file = getattr(sys.modules[mod], "__file__", None) + if file is not None: + return file + for _, member in inspect.getmembers(obj): + if inspect.isfunction(member): + if obj.__qualname__ + "." + member.__name__ == member.__qualname__: + return inspect.getfile(member) + raise TypeError(f"Source for {obj:!r} not found") + + +def findsource(obj): + """Return the entire source file and starting line number for an object.""" + import linecache # pylint: disable=import-outside-toplevel + + if not inspect.isclass(obj): + return _findsource(obj) + + file = inspect.getsourcefile(obj) + if file: + linecache.checkcache(file) + else: + file = inspect.getfile(obj) + if not (file.startswith("<") and file.endswith(">")): + raise OSError("source code not available") + + module = inspect.getmodule(obj, file) + if module: + lines = linecache.getlines(file, module.__dict__) + else: + lines = linecache.getlines(file) + if not lines: + raise OSError("could not get source code") + qual_names = obj.__qualname__.replace(".", "").split(".") + pattern_list = [] + for name in qual_names: + if name.endswith(""): + pattern_list.append(re.compile(r"^(\s*)def\s*" + name[:-8] + r"\b")) + else: + pattern_list.append(re.compile(r"^(\s*)class\s*" + name + r"\b")) + for i, line in enumerate(lines): + match = pattern_list[0].match(line) + if match: + pattern_list.pop(0) + if not pattern_list: + return lines, i + raise OSError("could not find class definition") + + +def getsourcelines(obj): + """Extract the block of code at the top of the given list of lines.""" + obj = inspect.unwrap(obj) + lines, l_num = 
findsource(obj) + return inspect.getblock(lines[l_num:]), l_num + 1 + + +inspect.getfile = _patched_inspect_getfile + + +class Diagnostics: + """Diagnostics class for error reporting in parser.""" + + source: Source + ctx: diagnostics.DiagnosticContext + + def __init__(self, source: Source): + mod = IRModule() + mod.source_map.add(source.source_name, source.full_source) + self.source = source + self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer()) + + def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None: + """Emit a diagnostic. + + Parameters + ---------- + node : doc.AST + The node with diagnostic information. + + message : str + The diagnostic message. + + level : diagnostics.DiagnosticLevel + The diagnostic level. + """ + lineno = node.lineno or self.source.start_line + col_offset = node.col_offset or self.source.start_column + end_lineno = node.end_lineno or lineno + end_col_offset = node.end_col_offset or col_offset + lineno += self.source.start_line - 1 + end_lineno += self.source.start_line - 1 + col_offset += self.source.start_column + 1 + end_col_offset += self.source.start_column + 1 + self.ctx.emit( + diagnostics.Diagnostic( + level=level, + span=Span( + source_name=SourceName(self.source.source_name), + line=lineno, + end_line=end_lineno, + column=col_offset, + end_column=end_col_offset, + ), + message=message, + ) + ) + + def error(self, node: doc.AST, message: str) -> None: + """Emit a diagnostic error. + + Parameters + ---------- + node : doc.AST + The node with diagnostic error. + + message : str + The diagnostic message. 
+ """ + self._emit(node, message, diagnostics.DiagnosticLevel.ERROR) + self.ctx.render() diff --git a/python/tvm/script/_parser/core/doc.py b/python/tvm/script/_parser/core/doc.py new file mode 100644 index 000000000000..f6a641cb6422 --- /dev/null +++ b/python/tvm/script/_parser/core/doc.py @@ -0,0 +1,361 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import ast +import inspect +import sys +import typing +from collections import defaultdict + +from . 
import doc_core as doc +from .doc_core import * # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614 + +FnToDoc = typing.Callable[[ast.AST], doc.AST] +FnFromDoc = typing.Callable[[doc.AST], ast.AST] + + +class Entry: + to_doc: typing.Optional[FnToDoc] + from_doc: typing.Optional[FnFromDoc] + + def __init__(self): + self.to_doc = None + self.from_doc = None + + +class Registry: + _inst: typing.Optional["Registry"] = None + table: typing.Dict[str, Entry] + + def __init__(self): + self.table = defaultdict(Entry) + + +def register_to_doc(name: str): + def f(to_doc: FnToDoc): # pylint: disable=redefined-outer-name + reg = Registry._inst # pylint: disable=protected-access + reg.table[name].to_doc = to_doc + + return f + + +def register_from_doc(name: str): + def f(to_doc: FnFromDoc): # pylint: disable=redefined-outer-name + reg = Registry._inst # pylint: disable=protected-access + reg.table[name].from_doc = to_doc + + return f + + +def _is_atomic_type(node): + return ( + node is None + or node in [..., True, False] + or isinstance( + node, + ( + int, + float, + str, + bool, + bytes, + complex, + ), + ) + ) + + +def _get_registry_entry(cls_name, attr): + cls_name = cls_name.split(".")[-1] + reg = Registry._inst # pylint: disable=protected-access + if cls_name in reg.table: + entry = reg.table[cls_name] + return getattr(entry, attr, None) + return None + + +def from_doc(node): + if _is_atomic_type(node): + return node + if isinstance(node, tuple): + return tuple(from_doc(n) for n in node) + if isinstance(node, list): + return [from_doc(n) for n in node] + func = _get_registry_entry(node.__class__.__name__, "from_doc") + if not func: + raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}") + return func(node) + + +def to_doc(node): + if _is_atomic_type(node): + return node + if isinstance(node, tuple): + return tuple(to_doc(n) for n in node) + if isinstance(node, list): + return [to_doc(n) for n in node] + func = 
_get_registry_entry(node.__class__.__name__, "to_doc") + if not func: + raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}") + return func(node) + + +def parse( + source, + filename="", + mode="exec", +) -> doc.AST: + try: + program = ast.parse( # pylint: disable=unexpected-keyword-arg + source=source, + filename=filename, + mode=mode, + feature_version=(3, 8), + ) + except: # pylint: disable=bare-except + program = ast.parse( + source=source, + filename=filename, + mode=mode, + ) + return to_doc(program) + + +class NodeVisitor: + def visit(self, node: doc.AST) -> None: + if isinstance(node, (list, tuple)): + for item in node: + self.visit(item) + return + if not isinstance(node, doc.AST): + return + getattr( + self, + "visit_" + node.__class__.__name__.split(".")[-1], + self.generic_visit, + )(node) + + def generic_visit(self, node: doc.AST) -> None: + for field in node.__class__._FIELDS: # pylint: disable=protected-access + value = getattr(node, field, None) + if value is None: + pass + elif isinstance(value, (doc.AST, list, tuple)): + self.visit(value) + + +class NodeTransformer: + def visit(self, node: doc.AST) -> doc.AST: + if isinstance(node, list): + return [self.visit(item) for item in node] + if isinstance(node, tuple): + return tuple(self.visit(item) for item in node) + if not isinstance(node, doc.AST): + return node + return getattr( + self, + "visit_" + node.__class__.__name__.split(".")[-1], + self.generic_visit, + )(node) + + def generic_visit(self, node: doc.AST) -> doc.AST: + kv: typing.Dict[str, typing.Any] = {} + for field in node.__class__._FIELDS: # pylint: disable=protected-access + value = getattr(node, field, None) + if value is None: + pass + elif isinstance(value, (doc.AST, list, tuple)): + value = self.visit(value) + kv[field] = value + return node.__class__(**kv) + + +def _register_default(): + class DefaultTranslator: + def __init__(self, doc_cls, func, fields): + self.doc_cls = doc_cls # getattr(doc, 
name) + self.func = func + self.fields = fields + + def __call__(self, node): + kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields} + return self.doc_cls(**kv) + + Registry._inst = Registry() # pylint: disable=protected-access + for cls_name in dir(doc): + doc_cls = getattr(doc, cls_name) + if not hasattr(ast, cls_name): + continue + if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST): + assert "." not in cls_name + register_to_doc(cls_name)( + DefaultTranslator( + getattr(doc, cls_name), + to_doc, + doc_cls._FIELDS, # pylint: disable=protected-access + ) + ) + register_from_doc(cls_name)( + DefaultTranslator( + getattr(ast, cls_name), + from_doc, + doc_cls._FIELDS, # pylint: disable=protected-access + ) + ) + + +def _py_version() -> typing.Tuple[int, int]: + return (sys.version_info.major, sys.version_info.minor) + + +def _register_constant_handling(): + if _py_version() not in [(3, 6), (3, 7)]: + return + + def as_constant(f) -> doc.Constant: + def to_doc_func(x: ast.AST) -> doc.Constant: + return doc.Constant( + value=getattr(x, f) if isinstance(f, str) else f(x), + kind=None, + s=None, + n=None, + lineno=x.lineno, + col_offset=x.col_offset, + end_lineno=x.lineno, + end_col_offset=x.col_offset, + ) + + return to_doc_func + + register_to_doc("Str")(as_constant("s")) + register_to_doc("NameConstant")(as_constant("value")) + register_to_doc("Num")(as_constant("n")) + register_to_doc("Bytes")(as_constant("s")) + register_to_doc("Ellipsis")(as_constant(lambda _: ...)) + + +def _register_subscription_handling(): + if _py_version() >= (3, 9): + return + + def subscript_to_doc(x: ast.Subscript) -> doc.Subscript: + if isinstance(x.slice, ast.Slice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Slice( + lower=to_doc(x.slice.lower), + upper=to_doc(x.slice.upper), + step=to_doc(x.slice.step), + lineno=getattr(x.slice, "lineno", None), + col_offset=getattr(x.slice, "col_offset", None), + end_lineno=getattr(x.slice, "end_lineno", 
None), + end_col_offset=getattr(x.slice, "end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.ExtSlice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Tuple( + elts=[to_doc(i) for i in x.slice.dims], + ctx=doc.Load( + lineno=None, + col_offset=None, + end_lineno=None, + end_col_offset=None, + ), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.Index): + return doc.Subscript( + value=to_doc(x.value), + slice=to_doc(x.slice.value), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + raise TypeError(f"Unknown subscript type: {type(x.slice)}") + + def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: + if isinstance(x.slice, doc.Slice): + result = ast.Subscript( + value=from_doc(x.value), + slice=from_doc(x.slice), + ctx=from_doc(x.ctx), + ) + elif isinstance(x.slice, doc.Tuple): + result = ast.Subscript( + value=from_doc(x.value), + slice=ast.ExtSlice( + dims=[from_doc(i) for i in x.slice.elts], + ), + ctx=from_doc(x.ctx), + ) + else: + result = ast.Subscript( + value=from_doc(x.value), + slice=ast.Index(value=from_doc(x.slice)), + ctx=from_doc(x.ctx), + ) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + 
register_to_doc("Subscript")(subscript_to_doc) + register_from_doc("Subscript")(subscript_from_doc) + + +def _register_index_handling(): + if _py_version() >= (3, 9): + return + + def index_to_doc(x: ast.Index) -> doc.Expr: + return to_doc(x.value) + + def index_from_doc(x: doc.Expr) -> ast.Index: + result = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx)) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + register_to_doc("Index")(index_to_doc) + register_from_doc("Index")(index_from_doc) + + +_register_default() +_register_constant_handling() +_register_subscription_handling() +_register_index_handling() diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/_parser/core/doc_core.py similarity index 100% rename from python/tvm/script/printer/doc_core.py rename to python/tvm/script/_parser/core/doc_core.py diff --git a/python/tvm/script/_parser/core/utils.py b/python/tvm/script/_parser/core/utils.py new file mode 100644 index 000000000000..8f88de42a85a --- /dev/null +++ b/python/tvm/script/_parser/core/utils.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import inspect +from typing import Any, Callable, Dict + + +def inspect_function_capture(func: Callable) -> Dict[str, Any]: + """Capture function non-locals and global variables. + + Parameters + ---------- + func : Callable + The function to inspect. + + Returns + ------- + res : Dict[str, Any] + The function variables map with non-local or global variables. + """ + captured = { + **inspect.getclosurevars(func).nonlocals, + **func.__globals__, # type: ignore + } + return captured + + +def inspect_class_capture(cls: type) -> Dict[str, Any]: + """Capture class non-locals and global variables. + + Parameters + ---------- + cls : type + The class to inspect. + + Returns + ------- + res : Dict[str, Any] + The class variables map with non-local or global variables. + """ + result: Dict[str, Any] = {} + for _, v in cls.__dict__.items(): + if inspect.isfunction(v): + func_vars = inspect_function_capture(v) + result.update(**func_vars) + return result diff --git a/tests/python/unittest/test_tvmscript_parser_source.py b/tests/python/unittest/test_tvmscript_parser_source.py new file mode 100644 index 000000000000..c638ab1ac84a --- /dev/null +++ b/tests/python/unittest/test_tvmscript_parser_source.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.parser.core""" +import pytest +import inspect +import tvm.testing +from tvm.script._parser.core.diagnostics import Source +from tvm.script._parser.core import doc_core as doc +from tvm.script import tir as T + + +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_source_base(): + source = Source(matmul) + assert ( + source.source_name == inspect.getsourcefile(matmul) + and source.start_line == 26 + and source.start_column == 0 + and source.source == inspect.getsource(matmul) + and source.full_source == inspect.getsource(inspect.getmodule(matmul)) + ) + + +def test_source_ast(): + source = Source(matmul) + mod = source.as_ast() + assert isinstance(mod, doc.Module) + func_def = mod.body[0] + assert isinstance(func_def, doc.FunctionDef) + assert func_def.name == "matmul" + func_args = func_def.args + assert ( + len(func_args.args) == 3 + and func_args.args[0].arg == "a" + and func_args.args[1].arg == "b" + and func_args.args[2].arg == "c" + ) + func_body = func_def.body + assert len(func_body) == 4 + func_assigns = func_body[:3] + assert ( + isinstance(func_assigns[0], doc.Assign) + and func_assigns[0].targets[0].id == "A" + and isinstance(func_assigns[1], doc.Assign) + and func_assigns[1].targets[0].id == "B" + and isinstance(func_assigns[2], doc.Assign) + and func_assigns[2].targets[0].id == "C" + ) + func_for = func_body[3] + assert ( + len(func_for.target.elts) == 3 + and func_for.target.elts[0].id == "i" + and func_for.target.elts[1].id == "j" + and func_for.target.elts[2].id == "k" + ) + for_body = func_for.body + assert len(for_body) == 1 + 
for_block = for_body[0] + assert isinstance(for_block, doc.With) and len(for_block.body) == 2 + + +if __name__ == "__main__": + tvm.testing.main() From 720859659357963e014adedd188f05a8f1b65baa Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 4 Oct 2022 13:11:14 -0700 Subject: [PATCH 2/5] apply code review suggestion --- tests/python/unittest/test_tvmscript_parser_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_parser_source.py b/tests/python/unittest/test_tvmscript_parser_source.py index c638ab1ac84a..cb93a2dcf62b 100644 --- a/tests/python/unittest/test_tvmscript_parser_source.py +++ b/tests/python/unittest/test_tvmscript_parser_source.py @@ -37,7 +37,7 @@ def test_source_base(): source = Source(matmul) assert ( source.source_name == inspect.getsourcefile(matmul) - and source.start_line == 26 + and source.start_line is not None and source.start_column == 0 and source.source == inspect.getsource(matmul) and source.full_source == inspect.getsource(inspect.getmodule(matmul)) From d41a85333a0faf189df4654e205fa3e155eb9284 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 4 Oct 2022 16:59:55 -0700 Subject: [PATCH 3/5] add doc for public APIs --- python/tvm/script/_parser/core/diagnostics.py | 1 - python/tvm/script/_parser/core/doc.py | 29 +++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/_parser/core/diagnostics.py index 26a6c47f9df3..a41b0d694fb7 100644 --- a/python/tvm/script/_parser/core/diagnostics.py +++ b/python/tvm/script/_parser/core/diagnostics.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=missing-docstring import inspect import re import sys diff --git a/python/tvm/script/_parser/core/doc.py b/python/tvm/script/_parser/core/doc.py index f6a641cb6422..3a4865f2e86d 100644 --- a/python/tvm/script/_parser/core/doc.py +++ b/python/tvm/script/_parser/core/doc.py @@ -29,6 +29,8 @@ class Entry: + """Mapping entry between str and doc AST.""" + to_doc: typing.Optional[FnToDoc] from_doc: typing.Optional[FnFromDoc] @@ -38,6 +40,8 @@ def __init__(self): class Registry: + """Registration map for str and doc AST""" + _inst: typing.Optional["Registry"] = None table: typing.Dict[str, Entry] @@ -115,10 +119,29 @@ def to_doc(node): def parse( - source, - filename="", - mode="exec", + source: str, + filename: str = "", + mode: str = "exec", ) -> doc.AST: + """Parse TVMScript source code to doc AST. + + Parameters + ---------- + source : str + The TVMScript source code. + + filename : str + The optional filename of source code. + + mode : str + The parsing mode. + + Returns + ------- + res : doc.AST + The parsed doc AST. + """ + try: program = ast.parse( # pylint: disable=unexpected-keyword-arg source=source, From 7af4bc7d44c65c14360bfa33d8b70cee68113e70 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 4 Oct 2022 17:29:45 -0700 Subject: [PATCH 4/5] add module doc and function doc --- python/tvm/script/_parser/core/diagnostics.py | 2 ++ python/tvm/script/_parser/core/doc.py | 32 +++++++++++++++++-- python/tvm/script/_parser/core/utils.py | 3 +- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/_parser/core/diagnostics.py index a41b0d694fb7..30091821015c 100644 --- a/python/tvm/script/_parser/core/diagnostics.py +++ b/python/tvm/script/_parser/core/diagnostics.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+"""TVM Script Parser Source and diagnostics""" + import inspect import re import sys diff --git a/python/tvm/script/_parser/core/doc.py b/python/tvm/script/_parser/core/doc.py index 3a4865f2e86d..71f16a93b278 100644 --- a/python/tvm/script/_parser/core/doc.py +++ b/python/tvm/script/_parser/core/doc.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring +"""TVM Script Parser doc AST""" + import ast import inspect import sys @@ -93,6 +94,18 @@ def _get_registry_entry(cls_name, attr): def from_doc(node): + """Get AST node from doc AST node. + + Parameters + ---------- + node : doc.AST + The doc AST node. + + Returns + ------- + res : ast.AST + The corresponding AST node. + """ if _is_atomic_type(node): return node if isinstance(node, tuple): @@ -106,6 +119,18 @@ def from_doc(node): def to_doc(node): + """Get doc AST node from AST node. + + Parameters + ---------- + node : ast.AST + The AST node. + + Returns + ------- + res : doc.AST + The corresponding doc AST node. + """ if _is_atomic_type(node): return node if isinstance(node, tuple): @@ -141,7 +166,6 @@ def parse( res : doc.AST The parsed doc AST. 
""" - try: program = ast.parse( # pylint: disable=unexpected-keyword-arg source=source, @@ -159,6 +183,8 @@ def parse( class NodeVisitor: + """ "Node visitor for doc AST""" + def visit(self, node: doc.AST) -> None: if isinstance(node, (list, tuple)): for item in node: @@ -182,6 +208,8 @@ def generic_visit(self, node: doc.AST) -> None: class NodeTransformer: + """ "Node transformer for doc AST""" + def visit(self, node: doc.AST) -> doc.AST: if isinstance(node, list): return [self.visit(item) for item in node] diff --git a/python/tvm/script/_parser/core/utils.py b/python/tvm/script/_parser/core/utils.py index 8f88de42a85a..65e7166bfcc2 100644 --- a/python/tvm/script/_parser/core/utils.py +++ b/python/tvm/script/_parser/core/utils.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring +"""TVM Script Parser utils""" + import inspect from typing import Any, Callable, Dict From 400ea021d2bb34878007fab4b876dcf570031a25 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 7 Oct 2022 22:05:27 -0700 Subject: [PATCH 5/5] add more detailed doc --- python/tvm/script/_parser/core/diagnostics.py | 33 ++++++++- python/tvm/script/_parser/core/doc.py | 68 ++++++++++++++++--- 2 files changed, 90 insertions(+), 11 deletions(-) diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/_parser/core/diagnostics.py index 30091821015c..b077d221424c 100644 --- a/python/tvm/script/_parser/core/diagnostics.py +++ b/python/tvm/script/_parser/core/diagnostics.py @@ -27,7 +27,27 @@ class Source: - """Source code class for TVMScript.""" + """Source code class for TVMScript. + + It is constructed by source code str or doc AST tree. + + Parameters + ---------- + source_name : str + The filename of the file where the source code locates. + + start_line : int + The first line number of the source code. 
+ + start_column : int + The first column number of the first line of the source code. + + source : str + The source code string. + + full_source : str + The complete source code of the file in which the source code is located. + """ source_name: str start_line: int @@ -150,7 +170,16 @@ def getsourcelines(obj): class Diagnostics: - """Diagnostics class for error reporting in parser.""" + """Diagnostics class for error reporting in parser. + + Parameters + ---------- + source : Source + The source code. + + ctx : diagnostics.DiagnosticContext + The diagnostic context for diagnostics. + """ source: Source ctx: diagnostics.DiagnosticContext diff --git a/python/tvm/script/_parser/core/doc.py b/python/tvm/script/_parser/core/doc.py index 71f16a93b278..5ea83749eadf 100644 --- a/python/tvm/script/_parser/core/doc.py +++ b/python/tvm/script/_parser/core/doc.py @@ -30,7 +30,16 @@ class Entry: - """Mapping entry between str and doc AST.""" + """Mapping entry between python AST node type str and doc AST. + + Parameters + ---------- + to_doc : typing.Optional[FnToDoc] + The callable methods for converting python AST node to doc AST. + + from_doc : typing.Optional[FnFromDoc] + The callable methods for converting doc AST to python AST node. + """ to_doc: typing.Optional[FnToDoc] from_doc: typing.Optional[FnFromDoc] @@ -41,7 +50,18 @@ def __init__(self): class Registry: - """Registration map for str and doc AST""" + """Registration map from python AST node type str to methods of conversion + between python AST node and doc AST node. + + Parameters + ---------- + _inst : typing.Optional["Registry"] + The instance of Registry. + + table : typing.Dict[str, Entry] + The registration map from python AST node type str to methods of conversion + between python AST node and doc AST node.
+ """ _inst: typing.Optional["Registry"] = None table: typing.Dict[str, Entry] @@ -51,6 +71,19 @@ def __init__(self): def register_to_doc(name: str): + """Register the to_doc method for python AST node type. + + Parameters + ---------- + name : str + The type of python AST node. + + Returns + ------- + f : Callable[[FnToDoc], None] + The function of registering the to_doc method for python AST node type. + """ + def f(to_doc: FnToDoc): # pylint: disable=redefined-outer-name reg = Registry._inst # pylint: disable=protected-access reg.table[name].to_doc = to_doc @@ -59,6 +92,19 @@ def f(to_doc: FnToDoc): # pylint: disable=redefined-outer-name def register_from_doc(name: str): + """Register the from_doc method for python AST node type. + + Parameters + ---------- + name : str + The type of python AST node. + + Returns + ------- + f : Callable[[FnFromDoc], None] + The function of registering the from_doc method for python AST node type. + """ + def f(to_doc: FnFromDoc): # pylint: disable=redefined-outer-name reg = Registry._inst # pylint: disable=protected-access reg.table[name].from_doc = to_doc @@ -94,7 +140,7 @@ def _get_registry_entry(cls_name, attr): def from_doc(node): - """Get AST node from doc AST node. + """Get original python AST node from doc AST node. Parameters ---------- @@ -119,7 +165,7 @@ def from_doc(node): def to_doc(node): - """Get doc AST node from AST node. + """Get doc AST node from python AST node. Parameters ---------- @@ -148,7 +194,11 @@ def parse( filename: str = "", mode: str = "exec", ) -> doc.AST: - """Parse TVMScript source code to doc AST. + """Parse TVMScript source code str to doc AST. + + Its interface is consistent with Python's built-in ast.parse. + It first tries to parse with the Python 3.8 grammar if possible, + and otherwise falls back to the Python version of the current environment. Parameters ---------- @@ -156,10 +206,10 @@ def parse( The TVMScript source code. filename : str - The optional filename of source code.
+ The optional filename of the file in which the source code is located. mode : str - The parsing mode. + The parsing mode for ast.parse. Returns ------- @@ -183,7 +233,7 @@ def parse( class NodeVisitor: - """ "Node visitor for doc AST""" + """Node visitor for doc AST""" def visit(self, node: doc.AST) -> None: if isinstance(node, (list, tuple)): @@ -208,7 +258,7 @@ def generic_visit(self, node: doc.AST) -> None: class NodeTransformer: - """ "Node transformer for doc AST""" + """Node transformer for doc AST""" def visit(self, node: doc.AST) -> doc.AST: if isinstance(node, list):