diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 88d68408381c..eff86a950b90 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.3.0 \ + synr==0.4.0 \ six \ tornado diff --git a/python/gen_requirements.py b/python/gen_requirements.py index a9a86077816d..7470ccc92496 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -244,7 +244,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.3.0"), + ("synr", "==0.4.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 60fc49678866..51ee0aed982c 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -594,6 +594,19 @@ def transform_For(self, node): self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res + def transform_While(self, node): + """While visitor + AST abstract grammar: + While(expr condition, stmt* body) + """ + condition = self.transform(node.condition) + # body + self.context.enter_scope(nodes=node.body.stmts) + body = self.parse_body(node) + self.context.exit_scope() + + return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span)) + def transform_With(self, node): """With visitor AST abstract grammar: diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index eb200df0c599..44006239acfd 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -27,7 +27,7 @@ from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, Any -from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For +from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index be31961e1e11..906dc258560a 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -170,6 +170,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const WhileNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; @@ -830,6 +831,13 @@ Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) { + Doc doc; + doc << "while " << Print(op->condition) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + return doc; +} + Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { Doc doc; doc << "ty." << runtime::DLDataType2String(node->dtype); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index e0f0c6d8cc5b..7c123afdc4d0 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3066,5 +3066,22 @@ def test_same_name_var(): assert out_str.find("i_") == -1 +@tvm.script.tir +def while_loop(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16,), "float32") + B = tir.match_buffer(b, (16,), "float32") + i = tir.alloc_buffer((), "int32", scope="local") + with tir.block([16]) as [vi]: + B[vi] = 0 + while i[()] < 10: + for j in range(16): + B[j] += A[j] + + +def test_while_loop(): + rt_func = tvm.script.from_source(tvm.script.asscript(while_loop, True)) + tvm.ir.assert_structural_equal(while_loop, rt_func) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index 753d17d8afe5..01d5587e70ad 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.3.0 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib().