-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[BYOC-DNNL] add post_sum pattern #12151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
e8d8a67 to
b9ed19d
Compare
fe68b77 to
63a44ba
Compare
tests/python/contrib/test_dnnl.py
Outdated
| if use_dnnl: | ||
| processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) | ||
| check_dnnl_used(processed_mod) | ||
| print(processed_mod) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
63a44ba to
d89db44
Compare
python/tvm/relay/op/contrib/dnnl.py
Outdated
| dnnl_patterns.append( | ||
| ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) | ||
| ), | ||
| dnnl_patterns.append( | ||
| ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) | ||
| ), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| dnnl_patterns.append( | |
| ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) | |
| ), | |
| dnnl_patterns.append( | |
| ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) | |
| ), | |
| dnnl_patterns.append( | |
| ( | |
| "dnnl.conv2d_bias_sum_relu", | |
| make_conv_bias_sum_relu_pattern("nn.conv2d"), | |
| make_predicate(add_checker), | |
| ) | |
| ) | |
| dnnl_patterns.append( | |
| ( | |
| "dnnl.conv2d_bias_sum", | |
| make_conv_bias_sum_relu_pattern("nn.conv2d", False), | |
| make_predicate(add_checker), | |
| ) | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| // TODO(@apeskov): Simulation of inplace primitive. just as PoC. | ||
| auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout); | ||
| if (op_name.find("_sum") != std::string::npos) { | ||
| sum_in_tr = GetInput(nid, node.GetInputs().size()-1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| sum_in_tr = GetInput(nid, node.GetInputs().size()-1); | |
| sum_in_tr = GetInput(nid, node.GetInputs().size() - 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
| param_lst += ["data1"] | ||
| return relay.nn.relu(out), dic, param_lst | ||
|
|
||
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) | |
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( | |
| x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
| config = conv2d_bn_sum_relu, dic, param_lst | ||
| run_and_verify_func(config, run_module=run_module, dtype=dtype) | ||
|
|
||
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) | |
| conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( | |
| x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
| # tvm.testing.main() | ||
| test_conv2d_bias_sum_relu(True) No newline at end of file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # tvm.testing.main() | |
| test_conv2d_bias_sum_relu(True) | |
| tvm.testing.main() | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
654bab6 to
b2dca37
Compare
b2dca37 to
001898f
Compare
|
@masahi Could you please review this PR? This PR adds |
* add post_sum pattern * add checkers for sum pattern * fix lint * fix error in test_pass_partition_graph * fix lint error
This PR add
conv2d-add-sum-relupattern, and the corresponding test case is added.