Skip to content

Conversation

@crazydemo
Copy link

This PR add conv2d-add-sum-relu pattern, and the corresponding test case is added.

@crazydemo crazydemo force-pushed the upstream-sum_pattern branch from e8d8a67 to b9ed19d Compare July 22, 2022 03:12
@crazydemo crazydemo force-pushed the upstream-sum_pattern branch 3 times, most recently from fe68b77 to 63a44ba Compare July 26, 2022 02:37
if use_dnnl:
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
check_dnnl_used(processed_mod)
print(processed_mod)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@crazydemo crazydemo force-pushed the upstream-sum_pattern branch from 63a44ba to d89db44 Compare July 27, 2022 02:03
Comment on lines 398 to 403
dnnl_patterns.append(
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker))
),
dnnl_patterns.append(
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker))
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dnnl_patterns.append(
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker))
),
dnnl_patterns.append(
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker))
),
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum_relu",
make_conv_bias_sum_relu_pattern("nn.conv2d"),
make_predicate(add_checker),
)
)
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum",
make_conv_bias_sum_relu_pattern("nn.conv2d", False),
make_predicate(add_checker),
)
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

// TODO(@apeskov): Simulation of inplace primitive. just as PoC.
auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
if (op_name.find("_sum") != std::string::npos) {
sum_in_tr = GetInput(nid, node.GetInputs().size()-1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sum_in_tr = GetInput(nid, node.GetInputs().size()-1);
sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

param_lst += ["data1"]
return relay.nn.relu(out), dic, param_lst

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment on lines 1774 to 1775
# tvm.testing.main()
test_conv2d_bias_sum_relu(True) No newline at end of file
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# tvm.testing.main()
test_conv2d_bias_sum_relu(True)
tvm.testing.main()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@crazydemo crazydemo force-pushed the upstream-sum_pattern branch from 654bab6 to b2dca37 Compare July 28, 2022 08:27
@crazydemo crazydemo force-pushed the upstream-sum_pattern branch from b2dca37 to 001898f Compare July 29, 2022 01:35
@crazydemo
Copy link
Author

@masahi Could you please review this PR? This PR adds conv2d-add-sum-relu pattern with required checks, and the corresponding test case is added.

@masahi masahi merged commit c07d77f into apache:main Aug 1, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* add post_sum pattern

* add checkers for sum pattern

* fix lint

* fix error in test_pass_partition_graph

* fix lint error
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants