Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 87 additions & 46 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,68 @@ enum class CallingConv : int {
kDeviceKernelLaunch = 2,
};

/*!
* \brief Supported linkage types.
*/
enum class LinkageType : int {
/*!
* \brief Internal linkage.
*/
kInternal = 0,
/*!
* \brief External linkage.
- Function with external linkage should have a global symbol attached to it.
*/
kExternal = 1
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr

/*!
* \brief Base node of all functions.
*
Expand Down Expand Up @@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
/*!
* \brief Get the type of the linkage.
*
* Currently, we only consider external/internal linkage.
* This can be extended in the future when necessary.
*
* \return Linkage type.
*
* \code
*
* void Example(const BaseFunc& f) {
* if (f->GetLinkageType() == tvm::LinkageType::kExternal) {
* // Do not remove a function with external linkage
* }
* }
*
* \endcode
*/

LinkageType GetLinkageType() const {
if (GetAttr<String>(attr::kGlobalSymbol))
return LinkageType::kExternal;
else
return LinkageType::kInternal;
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
Expand All @@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
15 changes: 15 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape();
* \return The Pass.
*/
TVM_DLL Pass AttachGlobalSymbol();
/*!
* \brief Bind params of function of the module to constant tensors.
*
* \param func_name The name of the function to bind parameters.
* \param params The parameters to bind.
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);

/*!
* \brief Fold constant expressions.
*
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
} // namespace transform
} // namespace relax
} // namespace tvm
Expand Down
62 changes: 61 additions & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import functools
import inspect
import types
from typing import Callable, Union
from typing import Callable, Dict, Union, Optional, List
import numpy as np # type: ignore

import tvm.ir
from . import _ffi_api
Expand Down Expand Up @@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore


def BindParams(
func_name: str,
params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.

Parameters
----------

func_name: str
The function name to be bound

params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
The map from param name to constant tensors.

Returns
-------
ret: tvm.ir.transform.Pass
"""
tvm_params = {}
for k, v in params.items():
if isinstance(v, np.ndarray):
v = tvm.nd.array(v)
assert isinstance(
v, tvm.runtime.NDArray
), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}"
tvm_params[k] = v

return _ffi_api.BindParams(func_name, tvm_params) # type: ignore


def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
"""Remove unused relax/prim functions without external linkage in a IRModule.

Parameters
----------
entry_functions: Optional[List[str]]
The set of entry functions to start from.

Returns
-------
ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
entry_functions = ["main"]
return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore


def FoldConstant() -> tvm.ir.transform.Pass:
"""Fold constant expressions.

Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.FoldConstant() # type: ignore


def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
"""Annotate Op Pattern Kind for TIR functions

Expand Down
113 changes: 113 additions & 0 deletions src/relax/transform/bind_params.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/driver/driver_api.h>
#include <tvm/ir/function.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include <utility>

namespace tvm {
namespace relax {

/*!
* \brief Bind params to function by using name
* \param func Relax function
* \param params params dict
* \return Function
*/
inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
std::unordered_map<std::string, Var> name_dict;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = Constant(kv.second);
}
Expr bound_expr = Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
<< "\n";
return ret;
}

/*!
* \brief Bind params to a specific function in a module
* \param m The module
* \param func_name The name of the specific function
* \param param The param dict
* \return The module after binding params.
*/
IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
if (relax_f->GetLinkageType() == LinkageType::kExternal) {
// Use global_symbol if it's external linkage
Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
new_module->Update(func_pr.first, f_after_bind);
}
} else {
// Use global var's name_hint if it's internal linkage
if (func_pr.first->name_hint == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
new_module->Update(func_pr.first, f_after_bind);
}
}
}
}
return GetRef<IRModule>(new_module);
}

namespace transform {

Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
return CreateModulePass(pass_func, 0, "BindParams", {});
}

TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading