diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h new file mode 100644 index 000000000000..407ad16007e9 --- /dev/null +++ b/include/tvm/script/printer/frame.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_FRAME_H_ +#define TVM_SCRIPT_PRINTER_FRAME_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * Frame is the core data structure for semantic information + * when printing IR graph into TVMScript code. + */ +class FrameNode : public Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + virtual ~FrameNode() = default; + + /*! + * \brief Add a callback function to be called when this frame exits. + * \param cb The callback function. It should have signature void(). + */ + template + void AddExitCallback(TCallback&& cb) { + callbacks_.emplace_back(std::forward(cb)); + } + + /*! + * \brief Method that's called when Frame enters the scope. + */ + virtual void EnterWithScope() {} + + /*! + * \brief Method that's called when Frame exits the scope. + */ + virtual void ExitWithScope() { + for (const std::function& callback : callbacks_) { + callback(); + } + callbacks_.clear(); + } + + static constexpr const char* _type_key = "script.printer.Frame"; + TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object); + + private: + std::vector> callbacks_; +}; + +/*! + * \brief Reference type of FrameNode + */ +class Frame : public ObjectRef { + protected: + Frame() = default; + + public: + virtual ~Frame() = default; + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); +}; + +/*! + * \brief MetadataFrame contains information like contant parameter array. + */ +class MetadataFrameNode : public FrameNode { + public: + Array metadata; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("metadata", &metadata); + } + + static constexpr const char* _type_key = "script.printer.MetadataFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode); +}; + +/*! + * \brief Reference type of MetadataFrameNode + */ +class MetadataFrame : public Frame { + public: + MetadataFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode); +}; + +/*! + * \brief VarDefFrame contains information about the free variables that needs to be defined + * at the beginning of the printed snippet. + */ +class VarDefFrameNode : public FrameNode { + public: + Array stmts; + + void VisitAttrs(tvm::AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("stmts", &stmts); + } + + static constexpr const char* _type_key = "script.printer.VarDefFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode); +}; + +/*! + * \brief Reference type of VarDefFrameNode + */ +class VarDefFrame : public Frame { + public: + VarDefFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_FRAME_H_ diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py new file mode 100644 index 000000000000..c967382b8b5d --- /dev/null +++ b/python/tvm/script/printer/frame.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Frame is the core data structure for semantic information when printing +IR graph into TVMScript code. +""" + +from typing import Callable, Sequence + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.script.printer.doc import StmtDoc + +from . import _ffi_api + + +class Frame(Object): + """ + Frame is the core data structure for semantic information + when printing IR graph into TVMScript code. + + Frame base class manages a list of callbacks to be executed + when frame goes out of scope. + """ + + def add_exit_callback(self, callback: Callable[[], None]) -> None: + """ + Adds a callback function to be executed when frame goes out of scope. + + Parameters + ---------- + callback : Callable[[], None] + The callback function. + """ + _ffi_api.FrameAddExitCallback(self, callback) # type: ignore # pylint: disable=no-member + + def __enter__(self): + _ffi_api.FrameEnterWithScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, *exception_info): + _ffi_api.FrameExitWithScope(self) # type: ignore # pylint: disable=no-member + + +@register_object("script.printer.MetadataFrame") +class MetadataFrame(Frame): + """ + MetadataFrame contains information like contant parameter array. + """ + + metadata: Sequence[Object] + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type: ignore # pylint: disable=no-member + + +@register_object("script.printer.VarDefFrame") +class VarDefFrame(Frame): + """ + VarDefFrame contains information about the free variables that needs to + be defined at the beginning of the printed snippet. + """ + + stmts: Sequence[StmtDoc] + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type: ignore # pylint: disable=no-member diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc new file mode 100644 index 000000000000..b342c7c886c7 --- /dev/null +++ b/src/script/printer/frame.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +MetadataFrame::MetadataFrame() : MetadataFrame(make_object()) {} + +VarDefFrame::VarDefFrame() : VarDefFrame(make_object()) {} + +TVM_REGISTER_NODE_TYPE(FrameNode); +TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback") + .set_body_typed([](Frame frame, runtime::TypedPackedFunc callback) { + frame->AddExitCallback(callback); + }); +TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope") + .set_body_method(&FrameNode::EnterWithScope); +TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope") + .set_body_method(&FrameNode::ExitWithScope); + +TVM_REGISTER_NODE_TYPE(MetadataFrameNode); +TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() { + return MetadataFrame(); +}); + +TVM_REGISTER_NODE_TYPE(VarDefFrameNode); +TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py new file mode 100644 index 000000000000..bd98d6445644 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_frame.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm.script.printer.frame import MetadataFrame + + +def test_frame_add_callback(): + frame = MetadataFrame() + + flag = 0 + + def callback1(): + nonlocal flag + flag += 1 + + def callback2(): + nonlocal flag + flag += 5 + + frame.add_exit_callback(callback1) + with frame: + frame.add_exit_callback(callback2) + assert flag == 0 + + assert flag == 6 + + +def test_frame_clear_callbacks_after_exit(): + frame = MetadataFrame() + + flag = 0 + + def callback(): + nonlocal flag + flag += 1 + + frame.add_exit_callback(callback) + + with frame: + pass + + assert flag == 1 + + with frame: + pass + + assert flag == 1