[OpenCL] Add vectorization to cuda conv2d_nhwc schedule #8636
Changes from all commits: b66ac7f, 66135c9, d8cf61f, 8014e80, 828cdfa
```diff
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Schedule for conv2d operator"""
+from tvm import te, autotvm
+
+from .. import nn
+from ..utils import traverse_inline
+from .conv2d_nhwc import schedule_conv2d_nhwc_direct
+
+
+@autotvm.register_topi_compute("conv2d_nhwc.gpu")
+def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
+    """Compute conv2d with NHWC layout"""
+    return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc.gpu")
+def schedule_conv2d_nhwc(cfg, outs):
+    """Create the schedule for conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "conv2d_nhwc":
+            schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
```
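For context, here is a minimal sketch (not part of the diff) of how the compute and schedule registered above can be exercised; the shapes, the `opencl` target, and the `tvm.topi.gpu` import path are illustrative assumptions.

```python
# Hypothetical usage sketch of the registrations above (not part of this PR).
# Assumptions: the new file is importable as tvm.topi.gpu, an OpenCL target,
# and arbitrary illustrative shapes.
import tvm
from tvm import te, topi

data = te.placeholder((1, 224, 224, 16), name="data")    # NHWC input
kernel = te.placeholder((3, 3, 16, 32), name="kernel")   # HWIO weights

with tvm.target.Target("opencl"):
    # Going through the registered compute attaches the AutoTVM workload,
    # so the schedule below can look up a (fallback) config.
    conv = topi.gpu.conv2d_nhwc(data, kernel, 1, 1, 1)   # stride, padding, dilation
    s = topi.gpu.schedule_conv2d_nhwc([conv])
    print(tvm.lower(s, [data, kernel, conv], simple_mode=True))
```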
```diff
@@ -54,12 +54,13 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
     cfg.define_knob("vthread_c", [1, 2])
     cfg.define_knob("step", [16, 3, 32, 64])
+    cfg.define_knob("vectorize", [1, 2, 4, 8])

     # fallback support
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, "conv2d_nhwc.cuda"
+            target.kind.name, target.model, "conv2d_nhwc.gpu"
         )
         cfg.fallback_with_reference_log(ref_log)

@@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     vthread_n = cfg["vthread_n"].val
     vthread_c = cfg["vthread_c"].val
     step = cfg["step"].val
+    vec_factor = cfg["vectorize"].val
     block_factor_c = tile_c * num_thread_c * vthread_c

     offset = 8
@@ -85,15 +87,17 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")

     # Schedule for output
-    ni, hi, wi, fi = s[output].op.axis
-    bx = s[output].fuse(hi, wi)
+    ni, _, wi, fi = s[output].op.axis
+    bx = wi
+    fi, vec = s[output].split(fi, factor=vec_factor)
+    s[output].vectorize(vec)
```
masahi (Member):

Is this supposed to vectorize the conv2d inner loop? Based on the generated code, I think it only vectorizes the last stage, which can be the copy from local to global memory or a fused activation computation. I wonder where the 6-7x perf improvement comes from? Here is an example of generated code where …
echuraev (Contributor, Author):

Sorry for the late reply, I was on vacation. Thank you for your question @masahi!

```diff
__kernel void my_conv_kernel0(__global float* restrict inp, __global float* restrict w, __global float* restrict Conv2dOutput) {
float Conv2dOutput_local[4];
__local float PaddedInput_shared[24];
__local float w_shared[256];
float PaddedInput_shared_local[1];
float w_shared_local[4];
for (int yy = 0; yy < 298; ++yy) {
- for (int ff_c_init = 0; ff_c_init < 4; ++ff_c_init) {
- Conv2dOutput_local[(ff_c_init)] = 0.000000e+00f;
- }
+ vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, Conv2dOutput_local + 0);
for (int rc_outer = 0; rc_outer < 2; ++rc_outer) {
for (int ry = 0; ry < 3; ++ry) {
for (int rx = 0; rx < 3; ++rx) {
barrier(CLK_LOCAL_MEM_FENCE);
PaddedInput_shared[(((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))))] = inp[((((((((yy * 9600) + (ry * 9600)) + (((int)get_group_id(2)) * 32)) + (rx * 32)) + (rc_outer * 16)) + (((int)get_local_id(1)) * 4)) + ((int)get_local_id(0))))];
for (int ax2_ax3_fused_outer_outer_outer = 0; ax2_ax3_fused_outer_outer_outer < 8; ++ax2_ax3_fused_outer_outer_outer) {
vstore2(vload2(0, w + (((((((ry * 3072) + (rx * 1024)) + (rc_outer * 512)) + (ax2_ax3_fused_outer_outer_outer * 64)) + ((((((int)get_local_id(1)) * 8) + (((int)get_local_id(0)) * 2)) >> 4) * 32)) + (((int)get_group_id(0)) * 16)) + (((((int)get_local_id(1)) * 8) + (((int)get_local_id(0)) * 2)) & 15))), 0, w_shared + (((ax2_ax3_fused_outer_outer_outer * 32) + (((int)get_local_id(1)) * 8)) + (((int)get_local_id(0)) * 2)));
}
barrier(CLK_LOCAL_MEM_FENCE);
for (int rc_inner = 0; rc_inner < 16; ++rc_inner) {
if (((int)get_local_id(1)) < 1) {
PaddedInput_shared_local[(0)] = PaddedInput_shared[(((((int)get_local_id(1)) * 24) + rc_inner))];
}
for (int ax3 = 0; ax3 < 4; ++ax3) {
w_shared_local[(ax3)] = w_shared[((((rc_inner * 16) + (((int)get_local_id(0)) * 4)) + ax3))];
}
- for (int ff_c = 0; ff_c < 4; ++ff_c) {
- if (((int)get_local_id(1)) < 1) {
- Conv2dOutput_local[(ff_c)] = (Conv2dOutput_local[(ff_c)] + (PaddedInput_shared_local[(0)] * w_shared_local[(ff_c)]));
- }
+ if (((int)get_local_id(1)) < 1) {
+ vstore4((vload4(0, Conv2dOutput_local + 0) + (((float4)(PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)])) * vload4(0, w_shared_local + 0))), 0, Conv2dOutput_local + 0);
}
}
}
}
}
for (int ff_outer_inner = 0; ff_outer_inner < 2; ++ff_outer_inner) {
if (((int)get_local_id(1)) < 1) {
vstore2(vload2(0, Conv2dOutput_local + (ff_outer_inner * 2)), 0, Conv2dOutput + ((((((((int)get_local_id(1)) * 2841728) + (yy * 9536)) + (((int)get_group_id(2)) * 32)) + (((int)get_group_id(0)) * 16)) + (((int)get_local_id(0)) * 4)) + (ff_outer_inner * 2)));
}
}
}
}
```

With this vectorization, the execution time did not change compared with the previous code generation. As for the performance boost: first, let me share my performance numbers (average execution time over 10 runs), which I got today on a Samsung Galaxy A71:

Bug fix. In the first commit, I also fixed one accuracy problem in OpenCL. Here, after fusing …

To answer this question, let's compare the generated OpenCL code for the version with the bug fix and for the latest code:

```diff
__kernel void my_conv_kernel0(__global float* restrict inp, __global float* restrict w, __global float* restrict Conv2dOutput) {
- float Conv2dOutput_local[2];
+ float Conv2dOutput_local[4];
__local float PaddedInput_shared[24];
__local float w_shared[256];
float PaddedInput_shared_local[1];
- float w_shared_local[2];
+ float w_shared_local[4];
for (int yy = 0; yy < 298; ++yy) {
- for (int ff_c_init = 0; ff_c_init < 2; ++ff_c_init) {
- Conv2dOutput_local[(ff_c_init)] = 0.000000e+00f;
- }
+ vstore4(((float4)(0.000000e+00f, 0.000000e+00f, 0.000000e+00f, 0.000000e+00f)), 0, Conv2dOutput_local + 0);
for (int rc_outer = 0; rc_outer < 2; ++rc_outer) {
for (int ry = 0; ry < 3; ++ry) {
for (int rx = 0; rx < 3; ++rx) {
barrier(CLK_LOCAL_MEM_FENCE);
PaddedInput_shared[(((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))))] = inp[((((((((yy * 9600) + (ry * 9600)) + (((int)get_group_id(2)) * 32)) + (rx * 32)) + (rc_outer * 16)) + (((int)get_local_id(1)) * 4)) + ((int)get_local_id(0))))];
- for (int ax2_ax3_fused_outer_outer = 0; ax2_ax3_fused_outer_outer < 8; ++ax2_ax3_fused_outer_outer) {
- w_shared[((((ax2_ax3_fused_outer_outer * 32) + ((((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))) >> 3) * 16)) + (((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))) & 7)))] = w[((((((((ry * 3072) + (rx * 1024)) + (rc_outer * 512)) + (ax2_ax3_fused_outer_outer * 64)) + ((((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))) >> 3) * 32)) + (((int)get_group_id(0)) * 8)) + (((((int)get_local_id(1)) * 4) + ((int)get_local_id(0))) & 7)))];
+ for (int ax2_ax3_fused_outer_outer_outer = 0; ax2_ax3_fused_outer_outer_outer < 8; ++ax2_ax3_fused_outer_outer_outer) {
+ vstore2(vload2(0, w + (((((((ry * 3072) + (rx * 1024)) + (rc_outer * 512)) + (ax2_ax3_fused_outer_outer_outer * 64)) + ((((((int)get_local_id(1)) * 8) + (((int)get_local_id(0)) * 2)) >> 4) * 32)) + (((int)get_group_id(0)) * 16)) + (((((int)get_local_id(1)) * 8) + (((int)get_local_id(0)) * 2)) & 15))), 0, w_shared + (((ax2_ax3_fused_outer_outer_outer * 32) + (((int)get_local_id(1)) * 8)) + (((int)get_local_id(0)) * 2)));
}
barrier(CLK_LOCAL_MEM_FENCE);
for (int rc_inner = 0; rc_inner < 16; ++rc_inner) {
if (((int)get_local_id(1)) < 1) {
PaddedInput_shared_local[(0)] = PaddedInput_shared[(((((int)get_local_id(1)) * 24) + rc_inner))];
}
- for (int ax3 = 0; ax3 < 2; ++ax3) {
- w_shared_local[(ax3)] = w_shared[((((rc_inner * 16) + (((int)get_local_id(0)) * 2)) + ax3))];
+ for (int ax3 = 0; ax3 < 4; ++ax3) {
+ w_shared_local[(ax3)] = w_shared[((((rc_inner * 16) + (((int)get_local_id(0)) * 4)) + ax3))];
}
- for (int ff_c = 0; ff_c < 2; ++ff_c) {
- if (((int)get_local_id(1)) < 1) {
- Conv2dOutput_local[(ff_c)] = (Conv2dOutput_local[(ff_c)] + (PaddedInput_shared_local[(0)] * w_shared_local[(ff_c)]));
- }
+ if (((int)get_local_id(1)) < 1) {
+ vstore4((vload4(0, Conv2dOutput_local + 0) + (((float4)(PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)], PaddedInput_shared_local[(0)])) * vload4(0, w_shared_local + 0))), 0, Conv2dOutput_local + 0);
}
}
}
}
}
- for (int ff_inner = 0; ff_inner < 2; ++ff_inner) {
+ for (int ff_outer_inner = 0; ff_outer_inner < 2; ++ff_outer_inner) {
if (((int)get_local_id(1)) < 1) {
- Conv2dOutput[(((((((((int)get_local_id(1)) * 2841728) + (yy * 9536)) + (((int)get_group_id(2)) * 32)) + (((int)get_group_id(0)) * 8)) + (((int)get_local_id(0)) * 2)) + ff_inner))] = Conv2dOutput_local[(ff_inner)];
+ vstore2(vload2(0, Conv2dOutput_local + (ff_outer_inner * 2)), 0, Conv2dOutput + ((((((((int)get_local_id(1)) * 2841728) + (yy * 9536)) + (((int)get_group_id(2)) * 32)) + (((int)get_group_id(0)) * 16)) + (((int)get_local_id(0)) * 4)) + (ff_outer_inner * 2)));
}
}
}
}
```

I suppose that the performance boost comes from decreased memory latency: we read more data per execution unit and store it in vector data types.
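For readers who want to reproduce this kind of comparison, here is a minimal sketch (not from this thread) of dumping the generated OpenCL source from TVM; the shapes are illustrative (chosen to roughly match the kernels above), and a TVM build with OpenCL codegen plus the `tvm.topi.gpu` import path are assumptions.

```python
# Hypothetical sketch for dumping generated OpenCL source such as the kernels
# compared above.  Shapes are illustrative; requires TVM built with OpenCL.
import tvm
from tvm import te, topi

inp = te.placeholder((1, 300, 300, 32), name="inp")   # NHWC input
w = te.placeholder((3, 3, 32, 32), name="w")          # HWIO weights

with tvm.target.Target("opencl"):
    conv = topi.gpu.conv2d_nhwc(inp, w, 1, 0, 1)      # stride=1, pad=0, dilation=1
    s = topi.gpu.schedule_conv2d_nhwc([conv])
    mod = tvm.build(s, [inp, w, conv], target="opencl")
    # Host code lives in `mod`; the OpenCL C kernel text is in the imported module.
    print(mod.imported_modules[0].get_source())
```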
Contributor:

Thanks @echuraev for providing details on the performance gains here, quite a drastic improvement.
```diff
     tx, fi = s[output].split(fi, factor=tile_c)
     txz, tx = s[output].split(tx, factor=num_thread_c)
     bz, txz = s[output].split(txz, factor=vthread_c)
     ty, ni = s[output].split(ni, factor=tile_n)
     tyz, ty = s[output].split(ty, factor=num_thread_n)
     by, tyz = s[output].split(tyz, factor=vthread_n)
-    s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi)
+    s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi, vec)
     s[output].bind(bz, block_z)
     s[output].bind(by, block_y)
     s[output].bind(bx, block_x)
@@ -106,6 +110,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     ni, yi, xi, fi = s[OL].op.axis
     ry, rx, rc = s[OL].op.reduce_axis
     rco, rci = s[OL].split(rc, factor=step)
+    s[OL].vectorize(fi)
     s[OL].reorder(rco, ry, rx, rci, ni, fi)

     s[AA].compute_at(s[OL], rx)
@@ -125,6 +130,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     _, _, ic, o = s[WW].op.axis
     t = s[WW].fuse(ic, o)
     s[WW].storage_align(ic, W_align - 1, W_align)
+    t, vec = s[WW].split(t, factor=vec_factor)
+    s[WW].vectorize(vec)
     ty, tx = s[WW].split(t, factor=num_thread_c)
     _, ty = s[WW].split(ty, factor=num_thread_n)
     s[WW].bind(tx, thread_x)

```
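As a standalone illustration of the split-and-vectorize pattern that the hunks above apply to the output, OL, and WW stages, here is a minimal TE sketch (names and shapes are illustrative, not from the PR).

```python
# Minimal, self-contained illustration of the split + vectorize pattern used in
# the schedule above (names and shapes are illustrative, not from the PR).
import tvm
from tvm import te

n = 64
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")

s = te.create_schedule(B.op)
outer, vec = s[B].split(B.op.axis[0], factor=4)  # like splitting `fi` by vec_factor
s[B].vectorize(vec)  # the inner axis is emitted as vector (e.g. float4) accesses
print(tvm.lower(s, [A, B], simple_mode=True))
```

On the OpenCL backend, vectorized inner loops like this are what show up as the vloadN/vstoreN operations in the generated kernels quoted in the review thread above.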