Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 12 additions & 41 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as _np # type: ignore

import tvm
import tvm._ffi
import tvm.relax
import tvm.ir
import tvm.relax
from tvm import DataType
from tvm._ffi import base as _base
from tvm.runtime import ndarray as _nd, Object
from tvm.runtime import Object
from tvm.runtime import ndarray as _nd

from ..ir import BaseFunc, Node, SourceName, Span
from ..runtime import String
from ..runtime import Scriptable, String
from ..tir import PrimExpr
from . import _ffi_api

Expand All @@ -55,7 +57,7 @@ def __init__(self):

# NOTE: place base struct info in expr to avoid cyclic dep
# from expr to struct info.
class StructInfo(Node):
class StructInfo(Node, Scriptable):
"""The base class of all StructInfo.

StructInfo contains both the static type
Expand Down Expand Up @@ -110,7 +112,7 @@ def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp":
raise TypeError(f"type {type(rhs)} not supported")


class ExprWithOp(Expr):
class ExprWithOp(Expr, Scriptable):
"""Basetype of all relax expressions that defines op overloading."""

def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp":
Expand Down Expand Up @@ -436,7 +438,7 @@ def __init__(


@tvm._ffi.register_object("relax.expr.PrimValue")
class PrimValue(Expr):
class PrimValue(Expr, Scriptable):
"""The prim expr representing the value."""

value: PrimExpr
Expand All @@ -448,7 +450,7 @@ def __init__(self, value: Union[PrimExpr, int], span: Span = None) -> None:


@tvm._ffi.register_object("relax.expr.StringImm")
class StringImm(Expr):
class StringImm(Expr, Scriptable):
"""Represent a string literal constant."""

value: str
Expand All @@ -458,7 +460,7 @@ def __init__(self, value: str, span: Span = None) -> None:


@tvm._ffi.register_object("relax.expr.DataTypeImm")
class DataTypeImm(Expr):
class DataTypeImm(Expr, Scriptable):
"""Represent a data type constant."""

value: DataType
Expand All @@ -468,11 +470,9 @@ def __init__(self, value: Union[DataType, str], span: Span = None) -> None:


@tvm._ffi.register_object("relax.expr.Binding")
class Binding(Node):
class Binding(Node, Scriptable):
"""The base class of a binding in Relax."""

...


@tvm._ffi.register_object("relax.expr.MatchCast")
class MatchCast(Binding):
Expand Down Expand Up @@ -548,7 +548,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) ->


@tvm._ffi.register_object("relax.expr.Function")
class Function(BaseFunc):
class Function(BaseFunc, Scriptable):
"""A Relax function."""

params: List[Var]
Expand Down Expand Up @@ -588,35 +588,6 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

def script(self, show_meta: bool = False) -> str:
"""Print relax.Function into TVMScript

Parameters
----------
show_meta : bool
Whether to show meta information

Returns
-------
script : str
The TVM Script of the relax.Function
"""
return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta) # type: ignore

def show(self, style: str = "light") -> None:
"""
A sugar for print highlighted TVM script.

Parameters
----------
style : str, optional
Pygments styles extended by "light" (default) and "dark", by default "light"
"""
from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel

# Use deferred import to avoid circular import while keeping cprint under tvm/script
cprint(self, style=style)


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
Expand Down
45 changes: 0 additions & 45 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ TVM_REGISTER_GLOBAL("relax.Call")
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<StructInfo> sinfo_args,
Span span) { return Call(op, args, attrs, sinfo_args, span); });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const CallNode*>(ref.get());
p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", "
<< node->sinfo_args << ")";
});

If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
ObjectPtr<IfNode> n = make_object<IfNode>();
n->cond = std::move(cond);
Expand Down Expand Up @@ -137,13 +130,6 @@ TVM_REGISTER_GLOBAL("relax.If")
return If(cond, true_branch, false_branch, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IfNode*>(ref.get());
p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", "
<< node->false_branch << ")";
});

Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
Expand Down Expand Up @@ -179,12 +165,6 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> o
return tuple;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleNode*>(ref.get());
p->stream << "Tuple(" << node->fields << ")";
});

TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
n->tuple = std::move(tuple);
Expand Down Expand Up @@ -216,12 +196,6 @@ TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int inde
return TupleGetItem(tuple, index);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TupleGetItemNode*>(ref.get());
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

TVM_REGISTER_NODE_TYPE(ShapeExprNode);

ShapeExpr::ShapeExpr(Array<PrimExpr> values, Span span) {
Expand All @@ -245,19 +219,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array<PrimExpr> values,
return ShapeExpr(values, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShapeExprNode>([](const ObjectRef& ref, ReprPrinter* p) {
const ShapeExprNode* node = static_cast<const ShapeExprNode*>(ref.get());
p->stream << "ShapeExpr(";
for (auto it = node->values.begin(); it != node->values.end(); it++) {
if (it != node->values.begin()) {
p->stream << ", ";
}
p->stream << *it;
}
p->stream << ")";
});

TVM_REGISTER_NODE_TYPE(VarNode);

Var::Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
Expand Down Expand Up @@ -572,12 +533,6 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol,
return ExternFunc(global_symbol, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ExternFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const ExternFuncNode*>(ref.get());
p->stream << "ExternFunc(\"" << node->global_symbol << "\")";
});

Expr GetShapeOf(const Expr& expr) {
// default case, to be normalized.
ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr";
Expand Down
43 changes: 0 additions & 43 deletions src/relax/ir/struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
return ObjectStructInfo(span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ObjectStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "ObjectStructInfo()";
});

// Prim
PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
Expand All @@ -60,12 +55,6 @@ TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Sp
return PrimStructInfo(dtype, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const PrimStructInfoNode*>(ref.get());
p->stream << "PrimStructInfo(" << node->dtype << ")";
});

// Shape
ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
Expand Down Expand Up @@ -102,16 +91,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
}
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShapeStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const ShapeStructInfoNode*>(ref.get());
if (node->values.defined()) {
p->stream << "ShapeStructInfo(" << node->values.value() << ")";
} else {
p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")";
}
});

// Tensor
TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
Expand Down Expand Up @@ -150,16 +129,6 @@ TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
}
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const TensorStructInfoNode*>(ref.get());
if (node->shape.defined()) {
p->stream << "TensorStructInfo(" << node->shape.value() << ", " << node->dtype << ")";
} else {
p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" << node->ndim << ")";
}
});

// Tuple
TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
ObjectPtr<TupleStructInfoNode> n = make_object<TupleStructInfoNode>();
Expand All @@ -175,12 +144,6 @@ TVM_REGISTER_GLOBAL("relax.TupleStructInfo")
return TupleStructInfo(fields, span);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const TupleStructInfoNode*>(ref.get());
p->stream << "TupleStructInfo(" << node->fields << ")";
});

// Func
FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span span) {
ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
Expand Down Expand Up @@ -223,12 +186,6 @@ TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc")
}
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FuncStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* node = static_cast<const FuncStructInfoNode*>(ref.get());
p->stream << "FuncStructInfo(" << node->params << ", " << node->ret << ")";
});

// Helper functions
void UpdateStructInfo(Expr expr, StructInfo struct_info) {
ICHECK(!expr->struct_info_.defined())
Expand Down
87 changes: 87 additions & 0 deletions src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace script {
namespace printer {

IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& d, //
const Optional<ExprDoc>& var, const Optional<ExprDoc>& ann) {
using relax::SeqExpr;
ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond"));
std::vector<Array<StmtDoc>> branches{
PrintSeqExpr(Downcast<SeqExpr>(n->true_branch), n_p->Attr("true_branch"), d, false),
PrintSeqExpr(Downcast<SeqExpr>(n->false_branch), n_p->Attr("false_branch"), d, false),
};
if (var.defined()) {
for (Array<StmtDoc>& stmts : branches) {
ExprDoc ret = Downcast<ExprStmtDoc>(stmts.back())->expr;
stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann));
}
}
return IfDoc(cond, branches[0], branches[1]);
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::MatchCast>(
"", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc {
using relax::StructInfo;
using relax::MatchStructInfo;
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
ExprDoc rhs = Relax(d, "match_cast")
->Call({d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")),
d->AsDoc<ExprDoc>(n->struct_info, n_p->Attr("struct_info_"))});
ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
return AssignDoc(lhs, rhs, ann);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::VarBinding>( //
"", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc {
if (const auto if_ = n->value.as<relax::IfNode>()) {
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
return PrintIfExpr(GetRef<relax::If>(if_), n_p->Attr("value"), d, lhs, ann);
} else if (n->value->IsInstance<tvm::BaseFuncNode>()) {
IdDoc lhs = DefineVar(n->var, d->frames.back(), d);
d->cfg->binding_names.push_back(lhs->name);
Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
d->cfg->binding_names.pop_back();
return ret;
} else {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
return AssignDoc(lhs, rhs, ann);
}
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::If>("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc {
return PrintIfExpr(n, n_p, d, NullOpt, NullOpt);
});

TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax);

} // namespace printer
} // namespace script
} // namespace tvm
Loading