Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL bool ContainGlobalVar(const std::string& name) const;

/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
Expand Down Expand Up @@ -198,13 +205,6 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;

/*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;

/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
TVM_DLL Pass CanonicalizeCast();

/*!
* \brief Add abstraction over a function
* \brief Add abstraction over a constructor or global variable bound to a function.
*
* For example: `square` is transformed to
* `fun x -> square x`.
* `fn (%x: int32) -> int32 { square(x) }`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param expand_constructor Whether to expand constructors.
* \param expand_global_var Whether to expand global variables.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();
TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);

/*!
* \brief Print the IR for a module to help debugging.
Expand Down
15 changes: 3 additions & 12 deletions python/tvm/relay/std/prelude.rly
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
/*
* Concatenates two lists.
*/

def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
Cons(%x, %xss)
};
@foldr(%updater, %ys, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldr(Cons, %ys, %xs)
@foldr(Cons, %ys, %xs)
}

/*
Expand Down Expand Up @@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
* Reverses a list.
*/
def @rev[A](%xs: List[A]) -> List[A] {
let %updater = fn(%xss: List[A], %x: A) -> List[A] {
Cons(%x, %xss)
};
@foldl(%updater, Nil, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldl(@flip(Cons), Nil, %xs)
@foldl(@flip(Cons), Nil, %xs)
}

/*
Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,15 +513,23 @@ def ToCPS(expr, mod=None):
return _transform.to_cps(expr, mod)


def EtaExpand():
"""Add abstraction over a function
def EtaExpand(expand_constructor=False, expand_global_var=False):
"""Add abstraction over a constructor or global variable bound to a function

Parameters
----------
expand_constructor: bool
Whether to expand constructors.

expand_global_var: bool
Whether to expand global variables.

Returns
-------
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand()
return _transform.EtaExpand(expand_constructor, expand_global_var)


def ToGraphNormalForm():
Expand Down Expand Up @@ -938,6 +946,7 @@ def create_function_pass(pass_arg):
return create_function_pass(pass_func)
return create_function_pass


@function_pass(opt_level=1)
class ChangeBatch:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
Expand Down Expand Up @@ -789,6 +790,16 @@ CreateInterpreter(
Module mod,
DLContext context,
Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
transform::Sequential seq({
transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false)});
transform::PassContext pass_ctx = transform::PassContext::Current();
tvm::With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);
}

auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
auto f = DetectFeature(expr);
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
pass_seqs.push_back(transform::Legalize());
}

// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false));

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
Expand Down
22 changes: 16 additions & 6 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const Module& module) : module_(module) {}

Expr VisitExpr_(const FunctionNode* func_node) final {
Expand Down Expand Up @@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
if (free_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
Expand All @@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
auto name = GenerateName(lifted_func);
auto global = GlobalVarNode::make(name);

// Add the lifted function to the module.
module_->Add(global, lifted_func);
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
// Add the lifted function to the module.
module_->Add(global, lifted_func);
}

if (free_vars.size() == 0) {
return std::move(global);
Expand Down Expand Up @@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
}
return module_;
}

private:
Module module_;
};

} // namespace vm
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->HasDef(p.first->var->name_hint) ||
if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
Expand Down
12 changes: 6 additions & 6 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}

bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}

GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
Expand Down Expand Up @@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}

bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}

Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
Expand Down Expand Up @@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand()(mod_copy);
mod_copy = transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update);
} else {
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ class PrettyPrinter :
Doc VisitExpr_(const ConstructorNode* n) final {
Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
if (in_adt_def_ && n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
Expand Down Expand Up @@ -775,6 +775,7 @@ class PrettyPrinter :
}

Doc VisitType_(const TypeDataNode* node) final {
in_adt_def_ = true;
Doc doc;
doc << "type " << Print(node->header);

Expand Down Expand Up @@ -802,6 +803,7 @@ class PrettyPrinter :
adt_body << ",";
}
doc << Brace(adt_body);
in_adt_def_ = false;
return doc;
}

Expand Down Expand Up @@ -876,6 +878,8 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief whether the printer is currently in an ADT definition */
bool in_adt_def_;
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
Expand Down
Loading