diff --git a/rfcs/0070-introducing-decl-buffer.md b/rfcs/0070-introducing-decl-buffer.md
index 6ada7308..c79f172e 100644
--- a/rfcs/0070-introducing-decl-buffer.md
+++ b/rfcs/0070-introducing-decl-buffer.md
@@ -87,7 +87,7 @@ Allocate {
 This can also be represented in TVMScript:
 ```
 A_data = T.allocate(shape=..., dtype=...)
-A = T.decl_buffer(data=A_data)
+A = T.decl_buffer(shape=..., dtype=..., data=A_data)
 ```
 
 ## Declaration of buffer alias
@@ -132,8 +132,8 @@ After flattening:
 ```
 @T.prim_func
 def elemwise(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
-    A_flattened = T.decl_buffer(A.data, (256,), "float32")
-    C_flattened = T.decl_buffer(C.data, (256,), "float32")
+    A_flattened = T.decl_buffer(shape=(256,), dtype="float32", data=A.data)
+    C_flattened = T.decl_buffer(shape=(256,), dtype="float32", data=C.data)
     for i, j in T.grid(16, 16):
         C_flattened[i * 16 + j] = A[i * 16 + j]
 ```
@@ -146,9 +146,36 @@ Specifically, the updated flow of buffer flattening using `DeclBuffer` will be:
 with flattened indices.
 
 ## TVM script updates
+* New statement `T.decl_buffer` will be introduced. It has the same interface as `T.buffer_decl`.
+```python
+def decl_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
+) -> Buffer: ...
+```
+It will be parsed to `DeclBuffer` node.
+
 * `T.allocate` will return data variable instead of a buffer. If the subsequent program need to access
 the data variable as a buffer, it should use `T.decl_buffer` to declare the buffer.
-* `T.buffer_decl` will be renamed to `T.decl_buffer`.
+* As a syntax sugar to avoid writing both `T.allocate` and `T.decl_buffer` at the same time,
+when the `data` parameter is not specified for `T.decl_buffer`, the buffer data will be
+allocated implicitly. This means the following code snippets are equivalent:
+```
+A_data = T.allocate(shape=[16], dtype="float32")
+A = T.decl_buffer(shape=[16], dtype="float32", data=A_data)
+```
+```
+A = T.decl_buffer(shape=[16], dtype="float32")
+```
+* `T.buffer_decl` will be deprecated in favor of the explicit `T.decl_buffer`.
 
 ## TIR validation
 With `DeclBuffer` introduced, we can implement utilities for TIR validation. It will enforce that: