From 1035b126665b35064670163771a408a75eb434e7 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 4 May 2024 16:55:48 +0800 Subject: [PATCH] [TVMScript] Fix error reporting inside Macro func --- python/tvm/script/parser/core/parser.py | 53 ++++++++++++++++++------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index b41a05689d45..0ecf669566a2 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -145,26 +145,27 @@ def __call__(self, *args, **kwargs): local_vars = param_binding.arguments parser = self._find_parser_def() - if self.hygienic: - saved_var_table = parser.var_table - parser.var_table = VarTable() + with parser.with_diag_source(self.source): + if self.hygienic: + saved_var_table = parser.var_table + parser.var_table = VarTable() - with parser.var_table.with_frame(): - for k, v in self.closure_vars.items(): - parser.var_table.add(k, v) - for k, v in local_vars.items(): - parser.var_table.add(k, v) + with parser.var_table.with_frame(): + for k, v in self.closure_vars.items(): + parser.var_table.add(k, v) + for k, v in local_vars.items(): + parser.var_table.add(k, v) - parse_result = self.parse_macro(parser) + parse_result = self.parse_macro(parser) - parser.var_table = saved_var_table + parser.var_table = saved_var_table - else: - with parser.var_table.with_frame(): - for k, v in local_vars.items(): - parser.var_table.add(k, v) + else: + with parser.var_table.with_frame(): + for k, v in local_vars.items(): + parser.var_table.add(k, v) - parse_result = self.parse_macro(parser) + parse_result = self.parse_macro(parser) return parse_result @@ -415,6 +416,28 @@ def pop_token(): return _deferred(pop_token) + def with_diag_source(self, source: Source): + """Add a new source as with statement. + + Parameters + ---------- + source : Source + The source for diagnostics. + + Returns + ------- + res : Any + The context with new source. + """ + + last_diag = self.diag + self.diag = Diagnostics(source) + + def pop_source(): + self.diag = last_diag + + return _deferred(pop_source) + def eval_expr( self, node: Union[doc.Expression, doc.expr],