From dc21cd56a4712c3e06acbb290e6ab0e3c774ac73 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 23 Oct 2019 21:51:41 -0700 Subject: [PATCH 1/3] Add support for attaching params --- include/tvm/relay/expr.h | 14 +++++++++++++- python/tvm/relay/expr.py | 6 ++++++ src/relay/ir/expr.cc | 20 ++++++++++++++++++++ tests/python/relay/test_ir_nodes.py | 25 ++++++++++++++++++++++++- 4 files changed, 63 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 6df4273d34c0..ff075e3a8970 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -274,6 +274,19 @@ class FunctionNode : public ExprNode { tvm::Array ty_params, tvm::Attrs attrs = Attrs()); + /*! + * \brief Attach the function's parameters to its attributes for use in analysis. + * \return The function with its parameters attached. + */ + Function SetParams(const tvm::Map& parameters) const; + + /*! + * \brief Retrieve the function's parameters. + * + * \return The function's parameter. + */ + tvm::Map GetParams() const; + static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); }; @@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); - /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88779dfd76e0..3237ddfb1dcd 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -305,6 +305,12 @@ def __call__(self, *args): """ return Call(self, args, None, None) + def get_params(self, params): + return _expr.FunctionGet(self, params) + + def set_params(self, params): + return _expr.FunctionSetParams(self, params) + @register_relay_node class Call(Expr): diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 35e4f2b4ab13..58f125b13361 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const { return pval && pval->value != 0; } +Function FunctionNode::SetParams(const tvm::Map& parameters) const { + return FunctionSetAttr(GetRef(this), "__params__", parameters); +} + +TVM_REGISTER_API("relay._expr.FunctionSetParms") +.set_body_typed&)>( + [](const Function& func, const tvm::Map& parameters) { + return func->SetParams(parameters); +}); + +tvm::Map FunctionNode::GetParams() const { + auto node_ref = FunctionGetAttr(GetRef(this), "__params__"); + return Downcast>(node_ref); +} + +TVM_REGISTER_API("relay._expr.FunctionGetParms") +.set_body_typed(const Function&)>([](const Function& func) { + return func->GetParams(); +}); + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index b42a1e6d52c6..69f087608b97 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -20,7 +20,7 @@ from tvm.expr import * from tvm.relay import op from tvm.relay.analysis import graph_equal - +import numpy as np def check_json_roundtrip(node): json_str = tvm.save_json(node) @@ -175,6 +175,29 @@ def test_function(): str(fn) check_json_roundtrip(fn) +def test_function_attrs(): + param_names = ['a', 'b', 'c', 'd'] + params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names]) + ret_type = relay.TupleType(tvm.convert([])) + body = relay.Tuple(tvm.convert([])) + type_params = tvm.convert([]) + fn = relay.Function(params, body, ret_type, type_params) + model_params = {} + for param in params[:1]: + tensor = np.random.rand(*param.shape).astype(param.dtype) + model_params[param] = tvm.nd.array(tensor) + fn = fn.set_params(model_params) + assert fn.params == params + assert fn.body == body + assert fn.type_params == type_params + assert fn.span == None + str(fn) + check_json_roundtrip(fn) + json_str = tvm.save_json(fn) + fn_after = tvm.load_json(json_str) + model_params_after = fn_after.get_params() + for p1, p2 in zip(model_params, model_params_after): + assert p1.asnumpy() == p2.asnumpy() def test_call(): op = relay.Var('f') From 9cad1b6c591cdd09b9eb2b0445525d5bd2612237 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 23 Oct 2019 22:02:46 -0700 Subject: [PATCH 2/3] Fix types --- src/relay/ir/expr.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 58f125b13361..dfa20709c7fe 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -171,11 +171,11 @@ TVM_REGISTER_API("relay._expr.FunctionSetParms") tvm::Map FunctionNode::GetParams() const { auto node_ref = FunctionGetAttr(GetRef(this), "__params__"); - return Downcast>(node_ref); + return Downcast>(node_ref); } TVM_REGISTER_API("relay._expr.FunctionGetParms") -.set_body_typed(const Function&)>([](const Function& func) { +.set_body_typed(const Function&)>([](const Function& func) { return func->GetParams(); }); From 8f111b0c121d9e0383ad8afcf559f2fc754a8c99 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 24 Oct 2019 14:40:49 -0700 Subject: [PATCH 3/3] Fix test --- python/tvm/relay/expr.py | 10 ++++++++-- src/relay/ir/expr.cc | 4 ++-- tests/python/relay/test_ir_nodes.py | 14 ++++++++++---- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 3237ddfb1dcd..8d59e99d8388 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -27,6 +27,7 @@ from .._ffi import base as _base from .. import nd as _nd from .. import convert +from ..ndarray import NDArray # will be registered afterwards _op_make = None @@ -305,10 +306,15 @@ def __call__(self, *args): """ return Call(self, args, None, None) - def get_params(self, params): - return _expr.FunctionGet(self, params) + def get_params(self): + return _expr.FunctionGetParams(self) def set_params(self, params): + for key in params: + value = params[key] + if isinstance(value, NDArray): + params[key] = Constant(value) + return _expr.FunctionSetParams(self, params) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index dfa20709c7fe..c36b4c8566b8 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -163,7 +163,7 @@ Function FunctionNode::SetParams(const tvm::Map& parameters) cons return FunctionSetAttr(GetRef(this), "__params__", parameters); } -TVM_REGISTER_API("relay._expr.FunctionSetParms") +TVM_REGISTER_API("relay._expr.FunctionSetParams") .set_body_typed&)>( [](const Function& func, const tvm::Map& parameters) { return func->SetParams(parameters); @@ -174,7 +174,7 @@ tvm::Map FunctionNode::GetParams() const { return Downcast>(node_ref); } -TVM_REGISTER_API("relay._expr.FunctionGetParms") +TVM_REGISTER_API("relay._expr.FunctionGetParams") .set_body_typed(const Function&)>([](const Function& func) { return func->GetParams(); }); diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 69f087608b97..dec840a214a0 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -160,7 +160,6 @@ def test_global_var(): str(gv) check_json_roundtrip(gv) - def test_function(): param_names = ['a', 'b', 'c', 'd'] params = tvm.convert([relay.Var(n) for n in param_names]) @@ -184,7 +183,8 @@ def test_function_attrs(): fn = relay.Function(params, body, ret_type, type_params) model_params = {} for param in params[:1]: - tensor = np.random.rand(*param.shape).astype(param.dtype) + cty = param.type_annotation + tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype) model_params[param] = tvm.nd.array(tensor) fn = fn.set_params(model_params) assert fn.params == params @@ -196,8 +196,12 @@ def test_function_attrs(): json_str = tvm.save_json(fn) fn_after = tvm.load_json(json_str) model_params_after = fn_after.get_params() - for p1, p2 in zip(model_params, model_params_after): - assert p1.asnumpy() == p2.asnumpy() + after_keys = [item[0] for item in model_params_after.items()] + for key1, key2 in zip(model_params, after_keys): + assert key1.name_hint == key2.name_hint + p1 = model_params[key1] + p2 = model_params_after[key2] + np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy()) def test_call(): op = relay.Var('f') @@ -280,9 +284,11 @@ def test_conv2d_attrs(): test_local_var() test_global_var() test_function() + test_function_attrs() test_call() test_let() test_if() test_tuple_get_item() test_op() test_conv2d_attrs() +