-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[RELAY/PASS] Simplify inference. #2033
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
5905d35
[PASS] Simplify inference.
ZihengJiang c2e217c
[PASS] Update.
ZihengJiang a684e52
Merge branch 'master' into relay
ZihengJiang c7f0d99
[PASS] Fix lint.
ZihengJiang fd8df19
[PASS] Update.
ZihengJiang 28cda83
[PASS] Update.
ZihengJiang f756515
[PASS] Update.
ZihengJiang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| /*! | ||
| * Copyright (c) 2018 by Contributors | ||
| * \file simplify_inference.cc | ||
| */ | ||
| #include <tvm/relay/pass.h> | ||
| #include <tvm/relay/expr_functor.h> | ||
| #include <tvm/relay/attrs/nn.h> | ||
| #include "./pattern_util.h" | ||
|
|
||
| namespace tvm { | ||
| namespace relay { | ||
|
|
||
| Expr BatchNormToInferUnpack(const Attrs attrs, | ||
| Expr data, | ||
| Expr gamma, | ||
| Expr beta, | ||
| Expr moving_mean, | ||
| Expr moving_var) { | ||
| const auto param = attrs.as<BatchNormAttrs>(); | ||
| Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon)); | ||
| Expr var_add_eps = Add(moving_var, epsilon); | ||
| Expr sqrt_var = Sqrt(var_add_eps); | ||
| Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var); | ||
|
|
||
| if (param->scale) { | ||
| scale = Multiply(scale, gamma); | ||
| } | ||
| Expr neg_mean = Negative(moving_mean); | ||
| Expr shift = Multiply(neg_mean, scale); | ||
| if (param->center) { | ||
| shift = Add(shift, beta); | ||
| } | ||
|
|
||
| int axis = param->axis; | ||
| const auto* tdata = data->type_as<TensorTypeNode>(); | ||
| scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis}); | ||
| shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis}); | ||
|
|
||
| Expr out = Multiply(data, scale); | ||
| out = Add(out, shift); | ||
| return out; | ||
| } | ||
|
|
||
| class InferenceSimplifier : public ExprMutator { | ||
| public: | ||
| Expr VisitExpr_(const TupleGetItemNode* n) final { | ||
| static const Op& batch_norm = Op::Get("nn.batch_norm"); | ||
| static const Op& dropout = Op::Get("nn.dropout"); | ||
|
|
||
| Expr new_e = ExprMutator::VisitExpr_(n); | ||
| const auto* new_n = new_e.as<TupleGetItemNode>(); | ||
ZihengJiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (new_n->index != 0) { | ||
| return new_e; | ||
| } | ||
| if (const auto* call = new_n->tuple.as<CallNode>()) { | ||
| if (call->op.same_as(batch_norm)) { | ||
| return BatchNormToInferUnpack(call->attrs, | ||
| call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]); | ||
| } else if (call->op.same_as(dropout)) { | ||
| return call->args[0]; | ||
| } | ||
| } | ||
| return new_e; | ||
| } | ||
| }; | ||
|
|
||
| Expr SimplifyInference(const Expr& e) { | ||
| return InferenceSimplifier().Mutate(e); | ||
| } | ||
|
|
||
| TVM_REGISTER_API("relay._ir_pass.simplify_inference") | ||
| .set_body([](TVMArgs args, TVMRetValue* ret) { | ||
| *ret = SimplifyInference(args[0]); | ||
| }); | ||
|
|
||
| } // namespace relay | ||
| } // namespace tvm | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from tvm import relay as rly | ||
| from tvm.relay.ir_pass import simplify_inference, alpha_equal | ||
|
|
||
| def test_simplify_batchnorm(): | ||
| def simple_bn(x, gamma, beta, moving_mean, moving_var, | ||
| axis=1, epsilon=1e-5, shape=None): | ||
| # expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta | ||
| scale = rly.multiply(rly.const(1, 'float32') / | ||
| rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma) | ||
| shift = rly.add( | ||
| rly.multiply(rly.negative(moving_mean), scale), beta) | ||
| num_newaxis = len(shape) - (axis + 1) | ||
| if num_newaxis: | ||
| scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis) | ||
| shift = rly.expand_dims(shift, axis=1, num_newaxis=num_newaxis) | ||
| return x * scale + shift | ||
|
|
||
| def check(dim, axis, nstep): | ||
| eps = 0.01 | ||
| ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32') | ||
| ttype2 = rly.TensorType((10,), 'float32') | ||
| x = rly.var("x", ttype1) | ||
| beta = rly.var("beta", ttype2) | ||
| gamma = rly.var("gamma", ttype2) | ||
| moving_var = rly.var("moving_var", ttype2) | ||
| moving_mean = rly.var("moving_mean", ttype2) | ||
| y1, y2 = x, x | ||
|
|
||
| for _ in range(nstep): | ||
| y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'), | ||
| gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) | ||
| y1 = rly.nn.dropout(y1) | ||
| y1 = rly.ir_pass.infer_type(y1) | ||
| y1 = simplify_inference(y1) | ||
|
|
||
| y2 = simple_bn(y2 + rly.const(1, 'float32'), | ||
| gamma, beta, moving_mean, moving_var, | ||
| epsilon=eps, axis=axis, shape=ttype1.shape) | ||
| assert rly.ir_pass.graph_equal(y1, y2) | ||
|
|
||
| check(2, 1, 1) | ||
| check(4, 1, 1) | ||
| check(4, 0, 3) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_simplify_batchnorm() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.