From b130533f023b901882c315895d8c9473e1b136ea Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 9 Aug 2022 17:05:12 -0400 Subject: [PATCH 1/4] Add frame definition --- include/tvm/script/printer/frame.h | 106 +++++++++++++++++++++++++++++ src/script/printer/frame.cc | 35 ++++++++++ 2 files changed, 141 insertions(+) create mode 100644 include/tvm/script/printer/frame.h create mode 100644 src/script/printer/frame.cc diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h new file mode 100644 index 000000000000..0be4e644fe8f --- /dev/null +++ b/include/tvm/script/printer/frame.h @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_FRAME_H_ +#define TVM_SCRIPT_PRINTER_FRAME_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class FrameNode : public Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + virtual ~FrameNode() = default; + + template + void AddCallback(TCallback&& cb) { + callbacks_.emplace_back(std::forward(cb)); + } + + virtual void EnterWithScope() {} + + virtual void ExitWithScope() { + for (const std::function& callback : callbacks_) { + callback(); + } + } + + static constexpr const char* _type_key = "script.printer.Frame"; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); + + private: + std::vector> callbacks_; +}; + +class Frame : public ObjectRef { + protected: + Frame() = default; + + public: + virtual ~Frame() = default; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); +}; + +class MetadataFrameNode : public FrameNode { + public: + Array metadata; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("metadata", &metadata); + } + + static constexpr const char* _type_key = "script.printer.MetadataFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode); +}; + +class MetadataFrame : public Frame { + public: + MetadataFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode); +}; + +class VarDefFrameNode : public FrameNode { + public: + Array stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.printer.VarDefFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode); +}; + +class VarDefFrame : public Frame { + public: + explicit VarDefFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_FRAME_H_ diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc new file mode 100644 index 000000000000..bf8e494f0341 --- /dev/null +++ b/src/script/printer/frame.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +namespace tvm { +namespace script { +namespace printer { + +MetadataFrame::MetadataFrame() : MetadataFrame(make_object()) {} + +VarDefFrame::VarDefFrame() : VarDefFrame(make_object()) {} + +TVM_REGISTER_NODE_TYPE(FrameNode); +TVM_REGISTER_NODE_TYPE(MetadataFrameNode); +TVM_REGISTER_NODE_TYPE(VarDefFrameNode); + +} // namespace printer +} // namespace script +} // namespace tvm From 5b0977c03fa00ebffb8bc1819f27ae23003f7dc0 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 10 Aug 2022 16:35:04 -0400 Subject: [PATCH 2/4] Add doc and Python binding --- include/tvm/script/printer/frame.h | 36 ++++++++- python/tvm/script/printer/frame.py | 81 +++++++++++++++++++ src/script/printer/frame.cc | 17 ++++ .../unittest/test_tvmscript_printer_frame.py | 60 ++++++++++++++ 4 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 python/tvm/script/printer/frame.py create mode 100644 tests/python/unittest/test_tvmscript_printer_frame.py diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h index 0be4e644fe8f..017fd10fb311 100644 --- a/include/tvm/script/printer/frame.h +++ b/include/tvm/script/printer/frame.h @@ -22,27 +22,45 @@ #include #include +#include +#include + namespace tvm { namespace script { namespace printer { +/*! + * Frame is the core data structure for semantic information + * when printing IR graph into TVMScript code. + */ class FrameNode : public Object { public: void VisitAttrs(tvm::AttrVisitor* v) {} virtual ~FrameNode() = default; + /*! + * \brief Add a callback function to be called when this frame exits. + * \param cb The callback function. It should have signature void(). + */ template void AddCallback(TCallback&& cb) { callbacks_.emplace_back(std::forward(cb)); } + /*! + * \brief Method that's called when Frame enters the scope. + */ virtual void EnterWithScope() {} + /*! + * \brief Method that's called when Frame exits the scope. + */ virtual void ExitWithScope() { for (const std::function& callback : callbacks_) { callback(); } + callbacks_.clear(); } static constexpr const char* _type_key = "script.printer.Frame"; @@ -52,6 +70,9 @@ class FrameNode : public Object { std::vector> callbacks_; }; +/*! + * \brief Reference type of FrameNode + */ class Frame : public ObjectRef { protected: Frame() = default; @@ -61,6 +82,9 @@ class Frame : public ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); }; +/*! + * \brief MetadataFrame contains information like contant parameter array. + */ class MetadataFrameNode : public FrameNode { public: Array metadata; @@ -74,12 +98,19 @@ class MetadataFrameNode : public FrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode); }; +/*! + * \brief Reference type of MetadataFrameNode + */ class MetadataFrame : public Frame { public: MetadataFrame(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode); }; +/*! + * \brief VarDefFrame contains information about the free variables that needs to be defined + * at the beginning of the printed snippet. + */ class VarDefFrameNode : public FrameNode { public: Array stmts; @@ -93,9 +124,12 @@ class VarDefFrameNode : public FrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode); }; +/*! + * \brief Reference type of VarDefFrameNode + */ class VarDefFrame : public Frame { public: - explicit VarDefFrame(); + VarDefFrame(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode); }; diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py new file mode 100644 index 000000000000..2458d52235c8 --- /dev/null +++ b/python/tvm/script/printer/frame.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Frame is the core data structure for semantic information when printing +IR graph into TVMScript code. +""" + +from typing import Callable, Sequence + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.script.printer.doc import StmtDoc + +from . import _ffi_api + + +class Frame(Object): + """ + Frame is the core data structure for semantic information + when printing IR graph into TVMScript code. + + Frame base class manages a list of callbacks to be executed + when frame goes out of scope. + """ + + def add_callback(self, callback: Callable[[], None]) -> None: + """ + Adds a callback function to be executed when frame goes out of scope. + + Parameters + ---------- + callback : Callable[[], None] + The callback function. + """ + _ffi_api.FrameAddCallback(self, callback) + + def __enter__(self): + _ffi_api.FrameEnterWithScope(self) + return self + + def __exit__(self, *exception_info): + _ffi_api.FrameExitWithScope(self) + + +@register_object("script.printer.MetadataFrame") +class MetadataFrame(Frame): + """ + MetadataFrame contains information like contant parameter array. + """ + + metadata: Sequence[Object] + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) + + +@register_object("script.printer.VarDefFrame") +class VarDefFrame(Frame): + """ + VarDefFrame contains information about the free variables that needs to + be defined at the beginning of the printed snippet. + """ + + stmts: Sequence[StmtDoc] + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc index bf8e494f0341..c96723a490dc 100644 --- a/src/script/printer/frame.cc +++ b/src/script/printer/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include namespace tvm { @@ -27,8 +28,24 @@ MetadataFrame::MetadataFrame() : MetadataFrame(make_object()) VarDefFrame::VarDefFrame() : VarDefFrame(make_object()) {} TVM_REGISTER_NODE_TYPE(FrameNode); +TVM_REGISTER_GLOBAL("script.printer.FrameAddCallback") + .set_body_typed([](Frame frame, runtime::TypedPackedFunc callback) { + frame->AddCallback(callback); + }); +TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope").set_body_typed([](Frame frame) { + frame->EnterWithScope(); +}); +TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope").set_body_typed([](Frame frame) { + frame->ExitWithScope(); +}); + TVM_REGISTER_NODE_TYPE(MetadataFrameNode); +TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() { + return MetadataFrame(); +}); + TVM_REGISTER_NODE_TYPE(VarDefFrameNode); +TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); }); } // namespace printer } // namespace script diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py new file mode 100644 index 000000000000..3455889c0119 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_frame.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm.script.printer.frame import MetadataFrame + + +def test_frame_add_callback(): + frame = MetadataFrame() + + flag = 0 + + def callback1(): + nonlocal flag + flag += 1 + + def callback2(): + nonlocal flag + flag += 5 + + frame.add_callback(callback1) + with frame: + frame.add_callback(callback2) + assert flag == 0 + + assert flag == 6 + + +def test_frame_clear_callbacks_after_exit(): + frame = MetadataFrame() + + flag = 0 + + def callback(): + nonlocal flag + flag += 1 + + frame.add_callback(callback) + + with frame: + pass + + assert flag == 1 + + with frame: + pass + + assert flag == 1 From 38ebb163accc84b3297baee2326fd82a625d8dce Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 10 Aug 2022 23:47:35 -0400 Subject: [PATCH 3/4] Rename add_callback to add_exit_callback --- include/tvm/script/printer/frame.h | 2 +- python/tvm/script/printer/frame.py | 4 ++-- src/script/printer/frame.cc | 4 ++-- tests/python/unittest/test_tvmscript_printer_frame.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h index 017fd10fb311..407ad16007e9 100644 --- a/include/tvm/script/printer/frame.h +++ b/include/tvm/script/printer/frame.h @@ -44,7 +44,7 @@ class FrameNode : public Object { * \param cb The callback function. It should have signature void(). */ template - void AddCallback(TCallback&& cb) { + void AddExitCallback(TCallback&& cb) { callbacks_.emplace_back(std::forward(cb)); } diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py index 2458d52235c8..9df33c5fb5bc 100644 --- a/python/tvm/script/printer/frame.py +++ b/python/tvm/script/printer/frame.py @@ -37,7 +37,7 @@ class Frame(Object): when frame goes out of scope. """ - def add_callback(self, callback: Callable[[], None]) -> None: + def add_exit_callback(self, callback: Callable[[], None]) -> None: """ Adds a callback function to be executed when frame goes out of scope. @@ -46,7 +46,7 @@ def add_callback(self, callback: Callable[[], None]) -> None: callback : Callable[[], None] The callback function. """ - _ffi_api.FrameAddCallback(self, callback) + _ffi_api.FrameAddExitCallback(self, callback) def __enter__(self): _ffi_api.FrameEnterWithScope(self) diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc index c96723a490dc..6582a5be50b9 100644 --- a/src/script/printer/frame.cc +++ b/src/script/printer/frame.cc @@ -28,9 +28,9 @@ MetadataFrame::MetadataFrame() : MetadataFrame(make_object()) VarDefFrame::VarDefFrame() : VarDefFrame(make_object()) {} TVM_REGISTER_NODE_TYPE(FrameNode); -TVM_REGISTER_GLOBAL("script.printer.FrameAddCallback") +TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback") .set_body_typed([](Frame frame, runtime::TypedPackedFunc callback) { - frame->AddCallback(callback); + frame->AddExitCallback(callback); }); TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope").set_body_typed([](Frame frame) { frame->EnterWithScope(); diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py index 3455889c0119..bd98d6445644 100644 --- a/tests/python/unittest/test_tvmscript_printer_frame.py +++ b/tests/python/unittest/test_tvmscript_printer_frame.py @@ -30,9 +30,9 @@ def callback2(): nonlocal flag flag += 5 - frame.add_callback(callback1) + frame.add_exit_callback(callback1) with frame: - frame.add_callback(callback2) + frame.add_exit_callback(callback2) assert flag == 0 assert flag == 6 @@ -47,7 +47,7 @@ def callback(): nonlocal flag flag += 1 - frame.add_callback(callback) + frame.add_exit_callback(callback) with frame: pass From 4a4cd19d1cb0fbd3e224d752ee698ed8790d7754 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 11 Aug 2022 19:27:42 -0400 Subject: [PATCH 4/4] Use set_body_method and fix lints --- python/tvm/script/printer/frame.py | 10 +++++----- src/script/printer/frame.cc | 10 ++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py index 9df33c5fb5bc..c967382b8b5d 100644 --- a/python/tvm/script/printer/frame.py +++ b/python/tvm/script/printer/frame.py @@ -46,14 +46,14 @@ def add_exit_callback(self, callback: Callable[[], None]) -> None: callback : Callable[[], None] The callback function. """ - _ffi_api.FrameAddExitCallback(self, callback) + _ffi_api.FrameAddExitCallback(self, callback) # type: ignore # pylint: disable=no-member def __enter__(self): - _ffi_api.FrameEnterWithScope(self) + _ffi_api.FrameEnterWithScope(self) # type: ignore # pylint: disable=no-member return self def __exit__(self, *exception_info): - _ffi_api.FrameExitWithScope(self) + _ffi_api.FrameExitWithScope(self) # type: ignore # pylint: disable=no-member @register_object("script.printer.MetadataFrame") @@ -65,7 +65,7 @@ class MetadataFrame(Frame): metadata: Sequence[Object] def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) + self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type: ignore # pylint: disable=no-member @register_object("script.printer.VarDefFrame") @@ -78,4 +78,4 @@ class VarDefFrame(Frame): stmts: Sequence[StmtDoc] def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) + self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type: ignore # pylint: disable=no-member diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc index 6582a5be50b9..b342c7c886c7 100644 --- a/src/script/printer/frame.cc +++ b/src/script/printer/frame.cc @@ -32,12 +32,10 @@ TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback") .set_body_typed([](Frame frame, runtime::TypedPackedFunc callback) { frame->AddExitCallback(callback); }); -TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope").set_body_typed([](Frame frame) { - frame->EnterWithScope(); -}); -TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope").set_body_typed([](Frame frame) { - frame->ExitWithScope(); -}); +TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope") + .set_body_method(&FrameNode::EnterWithScope); +TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope") + .set_body_method(&FrameNode::ExitWithScope); TVM_REGISTER_NODE_TYPE(MetadataFrameNode); TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() {