From 4e5a2ac8a8bc9b39d02ac2c803a049d1c4d24b2c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 24 Oct 2022 19:16:50 -0700 Subject: [PATCH] [TVMScript] TIR parser This PR introduces parser for TIR as part of #12442 Co-authored-by: yongwww --- python/tvm/script/_parser/__init__.py | 3 +- python/tvm/script/_parser/core/parser.py | 15 + python/tvm/script/_parser/tir/__init__.py | 25 + python/tvm/script/_parser/tir/entry.py | 108 ++++ python/tvm/script/_parser/tir/operation.py | 85 ++++ python/tvm/script/_parser/tir/parser.py | 468 ++++++++++++++++++ .../unittest/test_tvmscript_parser_tir.py | 63 +++ 7 files changed, 766 insertions(+), 1 deletion(-) create mode 100644 python/tvm/script/_parser/tir/__init__.py create mode 100644 python/tvm/script/_parser/tir/entry.py create mode 100644 python/tvm/script/_parser/tir/operation.py create mode 100644 python/tvm/script/_parser/tir/parser.py create mode 100644 tests/python/unittest/test_tvmscript_parser_tir.py diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/__init__.py index fd4e45818c20..38c8b88cc7ca 100644 --- a/python/tvm/script/_parser/__init__.py +++ b/python/tvm/script/_parser/__init__.py @@ -15,5 +15,6 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import _core, ir +from . import _core, ir, tir from .ir import ir_module +from .tir import prim_func diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py index daf95cb3cd1b..c6d43f11cbf5 100644 --- a/python/tvm/script/_parser/core/parser.py +++ b/python/tvm/script/_parser/core/parser.py @@ -571,6 +571,21 @@ def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name """ return _dispatch(self, "Assign")(self, node) + def visit_AnnAssign(self, node: doc.AnnAssign) -> Any: # pylint: disable=invalid-name + """The general annotated assign visiting method. + + Parameters + ---------- + node : doc.Assign + The doc AST annotated assign node. + + Returns + ------- + res : Any + The visiting result. + """ + return _dispatch(self, "AnnAssign")(self, node) + def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name """The general expression visiting method. diff --git a/python/tvm/script/_parser/tir/__init__.py b/python/tvm/script/_parser/tir/__init__.py new file mode 100644 index 000000000000..7754baf087f5 --- /dev/null +++ b/python/tvm/script/_parser/tir/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The tir parser""" + +from ...ir_builder.tir import * # pylint: disable=redefined-builtin +from ...ir_builder.tir import ir as _tir +from . import operation as _operation +from . import parser as _parser +from .entry import Buffer, Ptr, prim_func + +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py new file mode 100644 index 000000000000..632b87aa24dc --- /dev/null +++ b/python/tvm/script/_parser/tir/entry.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The entry point of TVM parser for tir.""" + +import inspect +from typing import Callable, Union + +from tvm.tir import Buffer, PrimFunc + +from ...ir_builder.tir import buffer_decl, ptr +from .._core import parse, utils +from ..ir import is_defined_in_class + + +def prim_func(func: Callable) -> Union[PrimFunc, Callable]: + """The parsing method for tir prim func, by using `@prim_func` as decorator. + + Parameters + ---------- + func : Callable + The function to be parsed as prim func. + + Returns + ------- + res : Union[PrimFunc, Callable] + The parsed tir prim func. + """ + if not inspect.isfunction(func): + raise TypeError(f"Expect a function, but got: {func}") + if is_defined_in_class(inspect.stack()): + return func + return parse(func, utils.inspect_function_capture(func)) + + +setattr(prim_func, "dispatch_token", "tir") + + +class BufferProxy: + """Buffer proxy class for constructing tir buffer. + Overload __call__ and __getitem__ to support syntax as T.Buffer() and T.Buffer[]. + """ + + def __call__( + self, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Buffer: + return buffer_decl( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, keys) -> Buffer: + if not isinstance(keys, tuple): + return self(keys) + if len(keys) >= 2 and not isinstance(keys[1], str): + return self(keys) + return self(*keys) # pylint: disable=no-member # type: ignore + + +class PtrProxy: + """Ptr proxy class for constructing tir pointer. + Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[]. + """ + + def __call__(self, dtype, storage_scope="global"): + if callable(dtype): + dtype = dtype().dtype + return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + + def __getitem__(self, keys): + if not isinstance(keys, tuple): + return self(keys) + return self(*keys) + + +Buffer = BufferProxy() # pylint: disable=invalid-name +Ptr = PtrProxy() # pylint: disable=invalid-name diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py new file mode 100644 index 000000000000..ed8f07a06369 --- /dev/null +++ b/python/tvm/script/_parser/tir/operation.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The tir expression operation registration""" + +from typing import Type + +from tvm import tir +from tvm.tir import IntImm + +from .._core import OpMethod, doc, register_op + + +def _register_expr_op(ty: Type): # pylint: disable=invalid-name + ty._dispatch_type = ty # pylint: disable=protected-access + + def _and(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.And(a, b) + + def _or(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.Or(a, b) + + def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name + register_op(ty, op, i)(m) + + for i in [0, 1]: + # Case 1. binop + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) + r(doc.LShift, i, lambda a, b: a << b) + r(doc.RShift, i, lambda a, b: a >> b) + r(doc.BitOr, i, lambda a, b: a | b) + r(doc.BitXor, i, lambda a, b: a ^ b) + r(doc.BitAnd, i, lambda a, b: a & b) + # doc.MatMult <-- not implemented + # doc.Pow <-- not implemented + # Case 2. cmpop + r(doc.Eq, i, tir.EQ) + r(doc.NotEq, i, tir.NE) + r(doc.Lt, i, tir.LT) + r(doc.LtE, i, tir.LE) + r(doc.Gt, i, tir.GT) + r(doc.GtE, i, tir.GE) + # doc.Is <-- not implemented + # doc.IsNot <-- not implemented + # doc.In <-- not implemented + # doc.NotIn <-- not implemented + # Case 3. boolop + r(doc.And, i, _and) + r(doc.Or, i, _or) + for i in [0]: + # Case 4. unaryop + r(doc.Invert, i, lambda a: ~a) + r(doc.Not, i, tir.Not) + r(doc.UAdd, i, lambda a: +a) + r(doc.USub, i, lambda a: -a) + + +_register_expr_op(tir.PrimExpr) +_register_expr_op(tir.IterVar) diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py new file mode 100644 index 000000000000..909238563fab --- /dev/null +++ b/python/tvm/script/_parser/tir/parser.py @@ -0,0 +1,468 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The base parser for tir""" + +import contextlib +from functools import partial +from typing import Any + +from tvm.ir import PrimType +from tvm.tir import Buffer, IterVar, PrimExpr, Var + +from ...ir_builder import tir as T +from ...ir_builder.base import IRBuilder +from ...ir_builder.base import IRBuilderFrame as Frame +from .._core import Parser, dispatch, doc + + +def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing with statement. + e.g. binding i, j, k with T.grid(128, 128, 128), when parsing + with T.grid(128, 128, 18) as i, j, k. + + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, (Buffer, Var)): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement") + raise NotImplementedError + + +def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing for statement. + e.g. binding i, j, k with T.grid(128, 128, 128), when parsing + for i, j, k in T.grid(128, 128, 128). + + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_for_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Var): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement") + raise NotImplementedError + + +def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + """Value binding methods when parsing assign statement. + e.g. binding vi, vj, vk with T.axis.remap("SSR", [i, j, k]), when parsing + vi, vj, vk = T.axis.remap("SSR", [i, j, k]). + + Parameters + ---------- + self : Parser + The current parser. + + node : doc.expr + The doc AST expression node for error reporting. + + var_name : str + The variable name. + + value : Any + The value to be bound with. + + Returns + ------- + res : Any + The bound value. + """ + if isinstance(value, T.inline): + return value.value + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_assign_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Frame): + value.add_callback(partial(value.__exit__, None, None, None)) + res = value.__enter__() + IRBuilder.name(var_name, res) + return res + elif isinstance(value, (Buffer, IterVar)) or ( + isinstance(value, Var) and not self.var_table.exist(value) + ): + IRBuilder.name(var_name, value) + return value + elif isinstance(value, PrimExpr): + var = T.var(value.dtype) + IRBuilder.name(var_name, var) + frame = T.let(var, value) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + return var + return value + + +@dispatch.register(token="tir", type_name="For") +def visit_for(self: Parser, node: doc.For) -> None: + """The for visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.For + The doc AST for node. + """ + for_frame = self.eval_expr(node.iter) + if not isinstance(for_frame, T.frame.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with for_frame as iters: + self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="While") +def visit_while(self: Parser, node: doc.While) -> None: + """The while visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.While + The doc AST while node. + """ + with self.var_table.with_frame(): + cond = self.eval_expr(node.test) + with T.While(cond): + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + """The assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Assign + The doc AST assign node. + """ + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AugAssign") +def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: + """The augmented assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.AugAssign + The doc AST augmented assign node. + """ + lhs_pos = ( + node.target.lineno, + node.target.col_offset, + node.target.end_lineno, + node.target.end_col_offset, + ) + rhs_pos = ( + node.value.lineno, + node.value.col_offset, + node.value.end_lineno, + node.value.end_col_offset, + ) + node.target.ctx = doc.Load(*lhs_pos) + with self.var_table.with_frame(): + lhs_name = "__tvm_tmp_value_aug_assign_lhs" + rhs_name = "__tvm_tmp_value_aug_assign_rhs" + lhs_expr = self.eval_expr(node.target) + rhs_expr = self.eval_expr(node.value) + self.var_table.add(lhs_name, lhs_expr) + self.var_table.add(rhs_name, rhs_expr) + op = doc.BinOp( + doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos), + node.op, + doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos), + *lhs_pos, + ) + rhs = self.eval_expr(op) + lhs = node.target + lhs.ctx = doc.Store(*lhs_pos) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + """The annotated assign visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.AnnAssign + The doc AST annotated assign node. + """ + lhs = node.target + rhs = self.eval_expr(node.value) + ann_var = self.visit_tvm_annotation(node.annotation) + if not isinstance(ann_var, Var): + self.report_error(node.annotation, "Annotation should be Var") + self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) + frame = T.let(ann_var, rhs) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + """The with visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.With + The doc AST with node. + """ + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, Frame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." + ) + rhs = stack.enter_context(frame) + if item.optional_vars is not None: + self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + """The function definition visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.FunctionDef + The doc AST function definition node. + """ + with self.var_table.with_frame(): + self.var_table.add("range", T.serial) + with T.prim_func(): + T.func_name(node.name) + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + T.func_ret(ret_type) + with self.with_dispatch_token("tir"): + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + """The arguments visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.arguments + The doc AST arguments node. + """ + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation)) + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="tir", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + """The TVM annotation visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.expr + The doc AST expr node. + """ + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + return annotation + + +@dispatch.register(token="tir", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + """The expr statement visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Expr + The doc AST Expr node. + """ + res = self.eval_expr(node.value) + if isinstance(res, Frame): + res.add_callback(partial(res.__exit__, None, None, None)) + res.__enter__() + + +@dispatch.register(token="tir", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + """The if visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.If + The doc AST if node. + """ + with self.var_table.with_frame(): + with T.If(self.eval_expr(node.test)): + with T.Then(): + self.visit_body(node.body) + if node.orelse: + with T.Else(): + self.visit_body(node.orelse) + + +@dispatch.register(token="tir", type_name="Assert") +def visit_assert(self: Parser, node: doc.Assert) -> None: + """The assert visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Assert + The doc AST assert node. + """ + cond = self.eval_expr(node.test) + msg = self.eval_expr(node.msg) + frame = T.Assert(cond, msg) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="Return") +def visit_return(self: Parser, node: doc.Return) -> None: + """The return visiting method for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.Return + The doc AST return node. + """ + self.report_error(node, "Return is not allowed.") diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py new file mode 100644 index 000000000000..cfa1dc62b31b --- /dev/null +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unittests for tvm.script.parser.tir""" + +import pytest +import inspect +import tvm.testing +from tvm.script._parser import tir as T +from tvm import ir, tir + + +def test_tir_buffer_proxy(): + buffer_0 = T.Buffer((128, 128), "float32") + assert ( + isinstance(buffer_0, tir.Buffer) + and list(buffer_0.shape) == [128, 128] + and buffer_0.dtype == "float32" + ) + + buffer_1 = T.Buffer[(64, 64, 64), "int32"] + assert ( + isinstance(buffer_1, tir.Buffer) + and list(buffer_1.shape) == [64, 64, 64] + and buffer_1.dtype == "int32" + ) + + +def test_tir_ptr_proxy(): + ptr_0 = T.Ptr("int32", "global") + assert ( + isinstance(ptr_0, tir.Var) + and ptr_0.dtype == "handle" + and isinstance(ptr_0.type_annotation, ir.PointerType) + and ptr_0.type_annotation.element_type == ir.PrimType("int32") + and ptr_0.type_annotation.storage_scope == "global" + ) + + ptr_1 = T.Ptr["float32", "shared"] + assert ( + isinstance(ptr_1, tir.Var) + and ptr_1.dtype == "handle" + and isinstance(ptr_1.type_annotation, ir.PointerType) + and ptr_1.type_annotation.element_type == ir.PrimType("float32") + and ptr_1.type_annotation.storage_scope == "shared" + ) + + +if __name__ == "__main__": + tvm.testing.main()