-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TensorIR][M1a] TVMScript Parser/Printer #7630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cd45ed1
e3abef4
4306a6f
f77f36e
5a8984d
6fe3e1d
2af767c
f356cb6
9be5abf
27bee1f
9a2e55f
5d29045
0a2146a
496383d
83ca0e1
e4835db
b7a240f
702a0b0
f6806a0
77df062
c9f9cc4
39aea89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,59 +16,217 @@ | |
| # under the License. | ||
| """TVM Script Context Maintainer for TIR""" | ||
|
|
||
| from tvm.te import schedule | ||
| from typing import List, Mapping, Union, Optional, Dict, Callable | ||
| import synr | ||
|
|
||
|
|
||
| import tvm | ||
| from tvm.ir import Span | ||
| from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion | ||
| from tvm.runtime import Object | ||
| from .node import BufferSlice | ||
|
|
||
|
|
||
| class BlockInfo: | ||
| """Information for block and block_realize signature | ||
|
|
||
| Examples | ||
| ---------- | ||
| .. code-block:: python | ||
|
|
||
| @tvm.script.tir | ||
| def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: | ||
| A = tir.match_buffer(a, (16, 16), "float32") | ||
| B = tir.match_buffer(b, (16, 16), "float32") | ||
| C = tir.match_buffer(c, (16, 16), "float32") | ||
|
|
||
| for i, j, k in tir.grid(16, 16, 16): | ||
| with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as [vi, vj, vk]: | ||
| tir.bind(vi, i) | ||
| tir.bind(vj, j) | ||
| tir.bind(vk, k) # iter_bindings = {vi: i, vj: j, vk: k} | ||
|
|
||
| tir.where(True) # predicate of the block_realize | ||
|
|
||
| tir.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block | ||
| tir.writes(C[0: 16, 0: 16]) # writes region of the block | ||
| tir.block_attr({"attr_key": "attr_value"}) # block annotations | ||
|
|
||
| # alloc_buffers inside the block | ||
| CC = tir.alloc_buffer((1, 1), dtype="float32") | ||
|
|
||
| # match_buffers of the block, | ||
| # which bind a sub-region of source buffer into a new buffer | ||
| D = tir.match_buffer_region(C[vi, vj]) | ||
|
|
||
| # init part of the block, executed when all reduce axes are at their beginning value | ||
| with tir.init(): | ||
| C[vi, vj] = tir.float32(0) | ||
|
|
||
| # block body | ||
| CC[0, 0] = A[vi, vk] * B[vj, vk] | ||
| D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] | ||
| """ | ||
|
|
||
| alloc_buffers: List[Buffer] = [] | ||
| """List[Buffer]: list of tir.alloc_buffer statements in the block signature""" | ||
| match_buffers: List[MatchBufferRegion] = [] | ||
| """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature""" | ||
| iter_bindings: Mapping[Var, PrimExpr] = {} | ||
| """Mapping[Var, PrimExpr]: map of block iter var to its values""" | ||
| reads: Optional[List[BufferSlice]] = None | ||
| """Optional[List[BufferSlice]]: | ||
| list of tir.reads statements in the block signature, None for not-visited""" | ||
| writes: Optional[List[BufferSlice]] = None | ||
| """Optional[List[BufferSlice]]: | ||
| list of tir.writes statements in the block signature, None for not-visited""" | ||
| annotations: Optional[Mapping[str, Object]] = None | ||
| """Optional[Mapping[str, Object]]: | ||
| list of tir.block_attr statements in the block signature, None for not-visited""" | ||
| predicate: Optional[PrimExpr] = None | ||
| """Optional[PrimExpr]: block realize predicate, None for not-visited""" | ||
| init: Optional[Stmt] = None | ||
| """Optional[Stmt]: init part of the block, None for not-visited""" | ||
|
|
||
| def __init__(self): | ||
| self.alloc_buffers = [] | ||
| self.match_buffers = [] | ||
| self.iter_bindings = {} | ||
| self.reads = None | ||
| self.writes = None | ||
| self.annotations = None | ||
| self.predicate = None | ||
| self.init = None | ||
|
|
||
|
|
||
| class ContextMaintainer: | ||
| """Maintain all the necessary context info""" | ||
| """Maintain all the necessary context info | ||
| Parameters | ||
| ---------- | ||
| _report_error : Callable[[str, Union[Span, synr.ast.Span]], None] | ||
| The report error function handle | ||
| """ | ||
|
|
||
| # scope context | ||
| node_stack: List[List[synr.ast.Node]] = [] | ||
| """List[List[synr.ast.Node]]: The AST nodes inside the current scope""" | ||
| block_info_stack: List[BlockInfo] = [] | ||
| """List[BlockInfo]: The block info for the current block scope""" | ||
| loop_stack: List[List[Var]] = [] | ||
| """List[List[Var]]: List of loop vars inside the current block scope""" | ||
| symbols: List[Dict[str, Union[Var, Buffer]]] = [] | ||
| """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" | ||
|
|
||
| def __init__(self, parser): | ||
| # function context | ||
| func_params: List[Var] = [] | ||
| """List[Var]: The function parameters""" | ||
| func_buffer_map: Mapping[Var, Buffer] = {} | ||
| """Mapping[Var, Buffer]: The function buffer map""" | ||
| func_dict_attr: Mapping[str, Object] = {} | ||
| """Mapping[str, Object]: The function attrs""" | ||
| func_var_env_dict: Mapping[Var, str] = {} | ||
| """Mapping[Var, str]: The map from var to env thread""" | ||
|
|
||
| # parser and analyzer | ||
| analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer() | ||
| """tvm.arith.Analyzer: The analyzer for simplifying""" | ||
| _report_error: Callable[[str, Union[Span, synr.ast.Span]], None] | ||
| """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle""" | ||
|
|
||
| def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): | ||
| # scope context | ||
| self.node_stack = [] # AST nodes of scopes | ||
| self.symbols = [] # symbols of scopes | ||
| self.node_stack = [] | ||
| self.block_info_stack = [] | ||
| self.loop_stack = [] | ||
| self.symbols = [] | ||
| # function context | ||
| self.func_params = [] # parameter list of function | ||
| self.func_buffer_map = {} # buffer_map of function | ||
| self.func_dict_attr = {} # func_attr of function | ||
| self.func_var_env_dict = {} # map from var to env_name | ||
| # parser | ||
| self.parser = parser | ||
|
|
||
| def pop_scope(self): | ||
| """Pop the inner most scope""" | ||
| self.symbols.pop() | ||
| self.node_stack.pop() | ||
| self.func_params = [] | ||
| self.func_buffer_map = {} | ||
| self.func_dict_attr = {} | ||
| self.func_var_env_dict = {} | ||
| # parser and analyzer | ||
| self._report_error = _report_error | ||
| self.analyzer = tvm.arith.Analyzer() | ||
|
|
||
| def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None): | ||
| """Creates a new scope | ||
|
|
||
| def new_scope(self, nodes=None): | ||
| """Creating a new scope""" | ||
| Note | ||
| ---- | ||
| This function is used for normal scopes that do not involve | ||
| a `with block` scope. Use `enter_block_scope` | ||
| for block scope cases. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| nodes : Optional[List[synr.ast.Node]] | ||
| The synr AST nodes in new scope | ||
| """ | ||
| if nodes is None: | ||
| nodes = [] | ||
| self.node_stack.append(list(reversed(nodes))) | ||
| self.symbols.append(dict()) | ||
|
|
||
| def update_symbol(self, name, symbol): | ||
| def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Document the difference between a regular scope and a block scope from the user's perspective. When should a user choose one over the other? |
||
| """Creates a new block scope, the function will call `enter_scope` implicitly | ||
| Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack | ||
| to maintain block info. | ||
|
|
||
| Note | ||
| ---- | ||
| This function should be used to handle a block scope, | ||
| aka the blocks that involve a `with block` scope. | ||
|
Comment on lines
+177
to
+178
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is still confusing. What is a block? How does it differ from a regular scope?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The block structure is clearly described in the RFC https://discuss.tvm.apache.org/t/rfc-tensorir-a-schedulable-ir-for-tvm/7872
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you put that in the codebase somewhere then? It'll be hard for people to understand if they have to go to Discuss to get the full docs.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There will be docs about the TensorIR TVMScript language, but that should come as a separate PR. Additionally, this PR already contains test cases that cover the cases needed. As in our previous parser code, there is less of a description of the language itself. While I agree some examples would be helpful, they may not be necessary, assuming the maintainers have a good understanding of the Block structure itself and of the future docs of the language.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with Tristan on this. I think that assuming that a maintainer has a good understanding of block structure already is not a good assumption to make. Having examples in the codebase makes code easily understandable and accessible to anyone who wants to read it, not just people who are familiar with the code. Since the project is rapidly growing and getting new contributors, it's important to make code understandable and accessible. Scaling the number of developers in TVM isn't sustainable without good documentation -- and good documentation includes having good comments. Ideally, the comments and the formal documentation would even be a little bit redundant.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @electriclilies. I agree with all you said in particular wrt to code readability. We already followed the principle of enforcing heavy documentation in the case of user facing code and making sure the overall logic flows well. Code readability also goes beyond the comments, a lot of efforts needs to be spent on API naming, intuitive callings and error handling. This PR does a lot of that, for example:
There is of course a tradeoff between the time we spend on comments, the amount of comments to be added, and other efforts. On one hand, we certainly want to add as many comments as possible. On the other hand, commenting every code block may not be the best way of investing time -- we could spend more time on overall architectural correctness, the scaffolding (APIs and components), and other elements that make the code more readable and maintainable. Coming back to a related example (e.g. reviewing the quantization code): it is certainly helpful to add examples about the network patterns that occur during the quantization process, the values involved, and so on. But that may not be the most important thing for now, since we can focus on more important readability and maintainability issues -- e.g. clarifying the key APIs, making sure they compose well, and so on. Examples can then be added to places that contain subtle logic to help clarify things. Right now we are prioritizing adding examples to developer code paths that are more subtle, like the arithmetic modules. In this particular case, an example suggested and checked in as a follow-up PR is more than welcome.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think any of us disagree with the extreme importance of comprehensive docs, but we need to think a bit deeper about what kind of docs really help. I would like to reiterate specifically my understanding of the categorization of helpful docs:
D1 directly helps a user to better understand how to make things work and how they work. It is clear that D1 is desirable. D2 helps developers to understand the design philosophy, how the codebase is structured, etc., so that more people could help better maintain the codebase. Without D2, people are unable to understand the key concepts - that is why Tianqi redirects us to the RFC, so that we could all understand what a "Block" is, etc. Without D2, no matter how many words a developer uses to document private APIs, it is still painful to understand some data structures. On D3, we always insist that APIs be at least somewhat documented, so that maintainers could get a brief sense of what will be going on if we call an API. Assuming we have good D1 and D2, we could substantially lower the steep learning curve and make understanding D3 much easier - the prerequisite is that maintainers should read D1 and D2 first, in our specific case, the RFC and related materials. I think everybody totally agrees with Lily's words, and trust me, nobody wants to bring trouble to future maintainers :-) With D1 and D2 ready in place (after M1s and M2s are merged), it will be much easier to understand the design philosophy. This is what we are doing in Ansor upstreaming too - we upstream many tutorials after the codebase is fully functioning. It is totally understandable that reviewers are feeling frustrated when not understanding the design philosophy, and that is what RFCs are for, aren't they :-) Next time, we would love to see such frustration converted into clear questions and answers. For design philosophy-related questions, I would love to redirect everybody to the RFC from the very beginning, and we shouldn't debate outside the RFC about topics like "why two scopes". Then should we include the RFC text into the inlined docs? I think it is debatable, and I prefer not to replicate. The basic reason is that we will have D1 and D2 in the end, and maintaining two copies is quite error-prone. 
An easier solution for maintenance is to add links to D1 and D2 on critical data structures. |
||
|
|
||
| Parameters | ||
| ---------- | ||
| nodes : Optional[List[synr.ast.Node]] | ||
| The synr AST nodes in new scope | ||
| """ | ||
| self.enter_scope(nodes) | ||
| # Create a new loop stack for the new block | ||
| self.loop_stack.append([]) | ||
| # Create a new BlockInfo for the new block | ||
| self.block_info_stack.append(BlockInfo()) | ||
|
|
||
| def exit_scope(self): | ||
| """Pop the inner most scope""" | ||
| self.symbols.pop() | ||
| self.node_stack.pop() | ||
|
|
||
| def exit_block_scope(self): | ||
| """Pop the inner most block scope, the function will call `exit_scope` implicitly""" | ||
| self.exit_scope() | ||
| # Pop loop stack | ||
| self.loop_stack.pop() | ||
| # Pop block_info | ||
| self.block_info_stack.pop() | ||
|
|
||
| def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node): | ||
| """Append a symbol into current scope""" | ||
| if isinstance(symbol, schedule.Buffer): | ||
| if isinstance(symbol, Buffer): | ||
| if name in self.symbols[0]: | ||
| self.parser.report_error("Duplicate Buffer name") | ||
| self.report_error("Duplicate Buffer name: " + symbol.name, node.span) | ||
| self.symbols[0][name] = symbol | ||
| else: | ||
| self.symbols[-1][name] = symbol | ||
|
|
||
| def remove_symbol(self, name): | ||
| def remove_symbol(self, name: str): | ||
| """Remove a symbol""" | ||
| for symbols in reversed(self.symbols): | ||
| if name in symbols: | ||
| symbols.pop(name) | ||
| return | ||
| raise RuntimeError("Internal error of tvm script parser: no symbol named" + name) | ||
| raise RuntimeError("Internal error of tvm script parser: no symbol named " + name) | ||
|
|
||
| def lookup_symbol(self, name): | ||
| def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: | ||
| """Look up symbol by name""" | ||
| for symbols in reversed(self.symbols): | ||
| if name in symbols: | ||
| return symbols[name] | ||
| return None | ||
|
|
||
| def report_error(self, message, span): | ||
| self.parser.report_error(message, span) | ||
| def report_error(self, message: str, span: Union[Span, synr.ast.Span]): | ||
| self._report_error(message, span) | ||
|
|
||
| def current_block_scope(self) -> BlockInfo: | ||
| return self.block_info_stack[-1] | ||
Uh oh!
There was an error while loading. Please reload this page.