From 9ab2484bd2b463d06111776997bf689cd3afadac Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Aug 2022 10:36:05 -0700 Subject: [PATCH 1/3] Amend RFC0070 with DeclBuffer TVMScript syntax updates --- rfcs/0070-introducing-decl-buffer.md | 29 +++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/rfcs/0070-introducing-decl-buffer.md b/rfcs/0070-introducing-decl-buffer.md index 6ada7308..a07a40b7 100644 --- a/rfcs/0070-introducing-decl-buffer.md +++ b/rfcs/0070-introducing-decl-buffer.md @@ -146,9 +146,36 @@ Specifically, the updated flow of buffer flattening using `DeclBuffer` will be: with flattened indices. ## TVM script updates +* New statement `T.decl_buffer` will be introduced. It has the same interface as `T.buffer_decl`. +```python +def decl_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + axis_separators=None) -> Buffer: + ... +``` +It will be parsed to `DeclBuffer` node. + * `T.allocate` will return data variable instead of a buffer. If the subsequent program need to access the data variable as a buffer, it should use `T.decl_buffer` to declare the buffer. -* `T.buffer_decl` will be renamed to `T.decl_buffer`. +* As a syntax sugar to avoid writing both `T.allocate` and `T.decl_buffer` at the same time, +when the `data` parameter is not specified for `T.decl_buffer`, the buffer data will be +allocated implicitly. This means the following code snippets are equivalent: +``` +A_data = T.allocate(shape=[16], dtype="float32") +A = T.decl_buffer(shape=[16], dtype="float32", data=A_data) +``` +``` +A = T.decl_buffer(shape=[16], dtype="float32") +``` +* `T.buffer_decl` will be deprecated in favor of the explicit `T.decl_buffer`. ## TIR validation With `DeclBuffer` introduced, we can implement utilities for TIR validation. It will enforce that: From 7ce0fc137418c7b0095ff158586f4852d55faf7d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 9 Aug 2022 10:50:11 -0700 Subject: [PATCH 2/3] Update 0070-introducing-decl-buffer.md --- rfcs/0070-introducing-decl-buffer.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/rfcs/0070-introducing-decl-buffer.md b/rfcs/0070-introducing-decl-buffer.md index a07a40b7..d38f9053 100644 --- a/rfcs/0070-introducing-decl-buffer.md +++ b/rfcs/0070-introducing-decl-buffer.md @@ -149,17 +149,17 @@ with flattened indices. * New statement `T.decl_buffer` will be introduced. It has the same interface as `T.buffer_decl`. ```python def decl_buffer( - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None) -> Buffer: - ... + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", + axis_separators: Optional[List[int]] = None, +) -> Buffer: ... ``` It will be parsed to `DeclBuffer` node. From 588b71cd688ed1c8430c82b6c37e3d3d7adcbb87 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 15 Aug 2022 16:24:31 -0700 Subject: [PATCH 3/3] Update 0070-introducing-decl-buffer.md --- rfcs/0070-introducing-decl-buffer.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rfcs/0070-introducing-decl-buffer.md b/rfcs/0070-introducing-decl-buffer.md index d38f9053..c79f172e 100644 --- a/rfcs/0070-introducing-decl-buffer.md +++ b/rfcs/0070-introducing-decl-buffer.md @@ -87,7 +87,7 @@ Allocate { This can also be represented in TVMScript: ``` A_data = T.allocate(shape=..., dtype=...) -A = T.decl_buffer(data=A_data) +A = T.decl_buffer(shape=..., dtype=..., data=A_data) ``` ## Declaration of buffer alias @@ -132,8 +132,8 @@ After flattening: ``` @T.prim_func def elemwise(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): - A_flattened = T.decl_buffer(A.data, (256,), "float32") - C_flattened = T.decl_buffer(C.data, (256,), "float32") + A_flattened = T.decl_buffer(shape=(256,), dtype="float32", data=A.data) + C_flattened = T.decl_buffer(shape=(256,), dtype="float32", data=C.data) for i, j in T.grid(16, 16): C_flattened[i * 16 + j] = A[i * 16 + j] ```