Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions include/tvm/script/printer/frame.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_FRAME_H_
#define TVM_SCRIPT_PRINTER_FRAME_H_

#include <tvm/node/node.h>
#include <tvm/script/printer/doc.h>

#include <utility>
#include <vector>

namespace tvm {
namespace script {
namespace printer {

/*!
* Frame is the core data structure for semantic information
* when printing IR graph into TVMScript code.
*/
class FrameNode : public Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

virtual ~FrameNode() = default;

/*!
* \brief Add a callback function to be called when this frame exits.
* \param cb The callback function. It should have signature void().
*/
template <typename TCallback>
void AddExitCallback(TCallback&& cb) {
callbacks_.emplace_back(std::forward<TCallback>(cb));
}

/*!
* \brief Method that's called when Frame enters the scope.
*/
virtual void EnterWithScope() {}

/*!
* \brief Method that's called when Frame exits the scope.
*/
virtual void ExitWithScope() {
for (const std::function<void()>& callback : callbacks_) {
callback();
}
callbacks_.clear();
}

static constexpr const char* _type_key = "script.printer.Frame";
TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);

private:
std::vector<std::function<void()>> callbacks_;
};

/*!
* \brief Reference type of FrameNode
*/
class Frame : public ObjectRef {
protected:
Frame() = default;

public:
virtual ~Frame() = default;
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
};

/*!
* \brief MetadataFrame contains information like contant parameter array.
*/
class MetadataFrameNode : public FrameNode {
public:
Array<ObjectRef> metadata;

void VisitAttrs(tvm::AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("metadata", &metadata);
}

static constexpr const char* _type_key = "script.printer.MetadataFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode);
};

/*!
* \brief Reference type of MetadataFrameNode
*/
class MetadataFrame : public Frame {
public:
MetadataFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode);
};

/*!
* \brief VarDefFrame contains information about the free variables that needs to be defined
* at the beginning of the printed snippet.
*/
class VarDefFrameNode : public FrameNode {
public:
Array<StmtDoc> stmts;

void VisitAttrs(tvm::AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("stmts", &stmts);
}

static constexpr const char* _type_key = "script.printer.VarDefFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode);
};

/*!
* \brief Reference type of VarDefFrameNode
*/
class VarDefFrame : public Frame {
public:
VarDefFrame();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode);
};

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_FRAME_H_
81 changes: 81 additions & 0 deletions python/tvm/script/printer/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Frame is the core data structure for semantic information when printing
IR graph into TVMScript code.
"""

from typing import Callable, Sequence

from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.script.printer.doc import StmtDoc

from . import _ffi_api


class Frame(Object):
"""
Frame is the core data structure for semantic information
when printing IR graph into TVMScript code.

Frame base class manages a list of callbacks to be executed
when frame goes out of scope.
"""

def add_exit_callback(self, callback: Callable[[], None]) -> None:
"""
Adds a callback function to be executed when frame goes out of scope.

Parameters
----------
callback : Callable[[], None]
The callback function.
"""
_ffi_api.FrameAddExitCallback(self, callback) # type: ignore # pylint: disable=no-member

def __enter__(self):
_ffi_api.FrameEnterWithScope(self) # type: ignore # pylint: disable=no-member
return self

def __exit__(self, *exception_info):
_ffi_api.FrameExitWithScope(self) # type: ignore # pylint: disable=no-member


@register_object("script.printer.MetadataFrame")
class MetadataFrame(Frame):
"""
MetadataFrame contains information like contant parameter array.
"""

metadata: Sequence[Object]

def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type: ignore # pylint: disable=no-member


@register_object("script.printer.VarDefFrame")
class VarDefFrame(Frame):
"""
VarDefFrame contains information about the free variables that needs to
be defined at the beginning of the printed snippet.
"""

stmts: Sequence[StmtDoc]

def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type: ignore # pylint: disable=no-member
50 changes: 50 additions & 0 deletions src/script/printer/frame.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/frame.h>

namespace tvm {
namespace script {
namespace printer {

MetadataFrame::MetadataFrame() : MetadataFrame(make_object<MetadataFrameNode>()) {}

VarDefFrame::VarDefFrame() : VarDefFrame(make_object<VarDefFrameNode>()) {}

TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback")
.set_body_typed([](Frame frame, runtime::TypedPackedFunc<void()> callback) {
frame->AddExitCallback(callback);
});
TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope")
.set_body_method<Frame>(&FrameNode::EnterWithScope);
TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope")
.set_body_method<Frame>(&FrameNode::ExitWithScope);

TVM_REGISTER_NODE_TYPE(MetadataFrameNode);
TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() {
return MetadataFrame();
});

TVM_REGISTER_NODE_TYPE(VarDefFrameNode);
TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); });

} // namespace printer
} // namespace script
} // namespace tvm
60 changes: 60 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tvm.script.printer.frame import MetadataFrame


def test_frame_add_callback():
frame = MetadataFrame()

flag = 0

def callback1():
nonlocal flag
flag += 1

def callback2():
nonlocal flag
flag += 5

frame.add_exit_callback(callback1)
with frame:
frame.add_exit_callback(callback2)
assert flag == 0

assert flag == 6


def test_frame_clear_callbacks_after_exit():
frame = MetadataFrame()

flag = 0

def callback():
nonlocal flag
flag += 1

frame.add_exit_callback(callback)

with frame:
pass

assert flag == 1

with frame:
pass

assert flag == 1