
Conversation

@anijain2305
Contributor

@anijain2305 anijain2305 commented Aug 31, 2018

This PR implements an INT8 conv operator for Intel Skylake and upcoming Intel processors. Currently, it supports input in NCHWc format. Later, an NNVM effort will pick up this kernel for conv and perform the kernel transform using "CorrectLayout". This PR tackles only the schedule for the INT8 conv kernel.

Background

Hardware support
Skylake provides hardware support for performing a dot product of two vectors of 4 int8 values each while keeping the accumulation precision at INT32. On Skylake these instructions are vpmaddubsw and vpmaddwd. This support will be further enhanced by the VNNI instructions. More details can be found at this link (https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training).
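
A minimal NumPy sketch of the dot-product semantics these instructions provide (the function name and the uint8-data/int8-kernel convention are illustrative, not part of this PR):

```python
import numpy as np

def int8_dot_product_16_lanes(data, kernel):
    """data: uint8[4], kernel: int8[16][4] -> int32[16].

    Each of the 16 output lanes is the dot product of the 4 data bytes
    with one 4-byte kernel row, accumulated at int32 precision
    (ignoring vpmaddubsw's intermediate int16 saturation).
    """
    acc = (kernel.astype(np.int32) * data.astype(np.int32)).sum(axis=1)
    return acc.astype(np.int32)

out = int8_dot_product_16_lanes(np.array([1, 2, 3, 4], dtype=np.uint8),
                                np.ones((16, 4), dtype=np.int8))
print(out)  # every lane is 1 + 2 + 3 + 4 = 10
```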

Why new schedule?
These instructions require some modifications to the current FP32 schedule. The current schedule does not perform any reduction across elements of a vector register, whereas these Intel instructions reduce across groups of 4 int8 values. Therefore, a new schedule is required.

Why not just rely on LLVM for codegen?
LLVM codegen is not mature enough to generate these instructions. LLVM has very restrictive pattern matching for lowering to these INT8 operations. We would need a decent effort on both the LLVM and TVM sides to reach an IR from which LLVM can directly generate these instructions. Therefore, I am currently calling the LLVM intrinsics directly from TVM.

Performance Speedup

These are different conv layers from the ResNet network. I don't have the NNVM changes yet to run an end-to-end experiment; I will update these numbers when we have that.

| Workload | Kernel size | FP32 time | INT8 time | Speedup |
| --- | --- | --- | --- | --- |
| Workload#0 | 3x3 | 1.27E-04 | 7.61E-05 | 1.668205764 |
| Workload#1 | 1x1 | 1.45E-05 | 1.62E-05 | 0.8956588432 |
| Workload#2 | 3x3 | 9.77E-05 | 4.52E-05 | 2.163945956 |
| Workload#3 | 1x1 | 9.56E-06 | 8.55E-06 | 1.117662311 |
| Workload#4 | 3x3 | 0.000124318896 | 9.35E-05 | 1.32998721 |
| Workload#5 | 3x3 | 8.53E-05 | 4.26E-05 | 2.00352076 |
| Workload#6 | 1x1 | 9.68E-06 | 8.45E-06 | 1.145552654 |
| Workload#7 | 3x3 | 0.000101081785 | 8.07E-05 | 1.252166183 |
| Workload#8 | 3x3 | 7.25E-05 | 4.82E-05 | 1.502493515 |
| Workload#9 | 1x1 | 9.35E-06 | 7.04E-06 | 1.328638704 |
| Workload#10 | 3x3 | 0.000118993864 | 9.70E-05 | 1.226515769 |
| Workload#11 | 1x1 | 0.000117436952 | 5.97E-05 | 1.966779456 |
| Workload#12 | 1x1 | 0.000121834512 | 6.08E-05 | 2.00419261 |
| Workload#13 | 1x1 | 6.25E-05 | 2.78E-05 | 2.245298107 |
| Workload#14 | 1x1 | 0.000106934483 | 4.96E-05 | 2.157467654 |
| Workload#15 | 1x1 | 0.000280448215 | 9.23E-05 | 3.038109489 |
| Workload#16 | 1x1 | 0.000118953154 | 5.06E-05 | 2.351587926 |
| Workload#17 | 1x1 | 3.67E-05 | 2.20E-05 | 1.666902107 |
| Workload#18 | 1x1 | 4.73E-05 | 3.66E-05 | 1.293318986 |
| Workload#19 | 1x1 | 0.000132912593 | 7.11E-05 | 1.868658243 |
| Workload#20 | 1x1 | 5.77E-05 | 3.91E-05 | 1.474495608 |
| Workload#21 | 1x1 | 3.15E-05 | 1.70E-05 | 1.858185831 |
| Workload#22 | 1x1 | 5.08E-05 | 2.74E-05 | 1.851147766 |
| Workload#23 | 1x1 | 0.000111464776 | 5.45E-05 | 2.044977186 |
| Workload#24 | 1x1 | 5.70E-05 | 3.06E-05 | 1.85923854 |
| Mean | | | | 1.732588287 |

Limitations

  • Current implementation requires input_channels to be a multiple of 4 and output_channels to be a multiple of 16. For other conv layers, we plan to use the FP32 schedule and not use the AVX512BW instructions (a minimal sketch of this dispatch condition follows below). If performance becomes a big concern, we can look into input channel padding.
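
A tiny sketch of that fallback check (hypothetical helper name; the actual dispatch lives in the x86 schedule code):

```python
def can_use_int8_schedule(in_channels, out_channels):
    # The int8 intrinsic reduces over 4 input channels at a time and
    # produces 16 output channels per vector register.
    return in_channels % 4 == 0 and out_channels % 16 == 0
```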

@anijain2305
Contributor Author

@yidawang @yzhliu @zhiics Please review. Also please feel free to add other reviewers who might be interested.

@FrozenGene
Member

@yzhliu Could we start converting the x86 CPU schedules to AutoTVM? I think we can leverage the ARM CPU AutoTVM template. Then, as in this PR, we could avoid adding workloads manually.

@yidawang
Contributor

@FrozenGene We are indeed working on applying AutoTVM to x86 CPUs. This PR is about INT8 quantization using the intrinsics provided by AVX-512 BW, which is potentially applicable to AutoTVM as well, but we still need to set it up manually first.

@yidawang
Contributor

@anijain2305 Can you edit the PR description to include the preliminary performance results?

Contributor

@yidawang yidawang left a comment

In addition, please identify and fix the lint issues by running make lint locally.


target_name = 'llvm -mcpu=skylake-avx512'
avx2_len = 16
ctx = tvm.context(target_name, 0);
Contributor

No need for the semicolon. The same comment applies to other similar lines in the Python code.

_, oc_chunk, oh, ow, oc_block = s[CC].op.axis
ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis

# Sylake and future processors have 16 vector lanes
Contributor

Skylake


ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n)

# Sylake and future processors have 16 vector lanes
Contributor

Skylake

@anijain2305
Contributor Author

Thanks @yidawang for the comments :) I will start working on them

@tqchen
Member

tqchen commented Aug 31, 2018

cc @ajtulloch @eqy @cowanmeg

avx2_len = 16
ctx = tvm.context(target_name, 0);

def getShape(im_height, im_width, in_filter, out_filter, kh, kw, hpad, wpad,
Member

Keep naming consistent: s/getShape/get_shape

def getShape(im_height, im_width, in_filter, out_filter, kh, kw, hpad, wpad,
hstride, wstride, outDtype):
## Find shapes
dataShape = (1, in_filter/avx2_len, im_height, im_width, avx2_len)
Member

Same naming style for variables, s/dataShape/data_shape
It also applies to the other parts.

Contributor

Be careful about / vs. //; in this case we will get a floating-point value.
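
For illustration, the difference the reviewer is pointing at (Python 3 semantics; variable names taken from the quoted test script):

```python
in_filter, avx2_len = 64, 16
print(in_filter / avx2_len)   # 4.0 -> float, not valid as a tensor shape dimension
print(in_filter // avx2_len)  # 4   -> integer division, what the shape tuple needs
```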

s = tvm.create_schedule(out.op);
func = tvm.build(s, [data, kernel, out], target=target_name, name='out')
func(a, b, cOrig)
#print(tvm.lower(s, [data, kernel], simple_mode=True));
Member

remove debugging code?

else:
HSTR, WSTR = stride, stride
assert data.dtype == kernel.dtype, \
assert data.dtype == kernel.dtype or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \
Member

(data.dtype == kernel.dtype)

cSch = tvm.nd.array(np.zeros(oShape, dtype=outDtype), ctx);


with tvm.target.create(target_name):
Member

I understand that quantization currently only works for x86 conv2d, so you want to invoke this "specialized" function directly. In the long term, I think the annotation I am doing could help here if quantization works on more devices. You can annotate the node with the target and call conv2d_nchwc from a layer higher so that the dispatcher can find the correct compute.

Contributor Author

Agreed. This is a specific use case to trigger the x86 conv2d compute/schedule.

indices.push_back(i);
}
return builder_->CreateShuffleVector(v0, v1, indices);
} else if (op->is_intrinsic("broadcast16")){
Member

I think we may want to avoid using string literals directly on both the backend and frontend sides, because it might be error-prone or user-unfriendly as the number of them increases. Instead we can probably create a "mapping" or "enum" to do this. But again, this is fine for now.

def getShape(im_height, im_width, in_filter, out_filter, kh, kw, hpad, wpad,
hstride, wstride, outDtype):
## Find shapes
dataShape = (1, in_filter/avx2_len, im_height, im_width, avx2_len)
Contributor

Be careful about / vs. //; in this case we will get a floating-point value.

## Find shapes
dataShape = (1, in_filter/avx2_len, im_height, im_width, avx2_len)

if outDtype == 'int32':
Contributor

cosmetics: keep CamelCase vs. snake_case consistent

else:
a = tvm.nd.array(np.random.randint(100, size=dataShape).astype(dataDtype));
b = tvm.nd.array(np.random.randint(100, size=kernelShape).astype(kernelDtype));
#a = tvm.nd.array(np.ones(dataShape, dtype='uint8'), ctx);
Contributor

delete comments if they are not useful here

avx2_len = 16
else:
return s
assert(avx2_len != -1)
Contributor

Parentheses are not needed here (lint may complain about this).

"""
This function sets up the compute for INT8 conv 2d
Inputs are in INT8 datatype
Ouptut is in INT32 datatype
Contributor

Output

Contributor

@yidawang yidawang left a comment

LGTM

@ajtulloch
Contributor

a) Would you be able to report achieved GOPS (ideally as a fraction of peak) instead of just time? Additionally, could you compare against MKL-DNN or similar for fp32/int8? (i.e. using benchdnn from MKL-DNN)
b) Do you find padding to be particularly expensive (either spatial or channel padding)? I've noticed that the codegen for tvm_if_then_else seems to be particularly poor, and I wonder if it's worth tackling that at some point.

@anijain2305
Contributor Author

@ajtulloch Both good points.

I will update the numbers sometime next week. I agree GOPS is a much better metric than just time; it tells us how much room is left to optimize.

I did not do anything specific for padding. The kernel is built on top of the current x86 NCHWc kernel, which hides the padding handling from my implementation. But I will look deeper and see if the speedup for padded kernels is worse.
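
For reference, a minimal sketch of the GOPS computation requested above (assuming 2 ops per multiply-accumulate; the helper name and example dimensions are illustrative):

```python
def conv_gops(batch, in_c, out_c, out_h, out_w, kh, kw, seconds):
    """Achieved giga-ops per second for a direct convolution."""
    macs = batch * in_c * out_c * out_h * out_w * kh * kw
    return 2.0 * macs / seconds / 1e9

# e.g. batch 1, 64 -> 64 channels, 56x56 output, 3x3 kernel, measured at 1 ms
print(conv_gops(1, 64, 64, 56, 56, 3, 3, 1e-3))  # ~231 GOPS
```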

strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
offset_factor=1,
strides=[tvm.var('ldw'), 1])
Member

I feel all these strides bindings are unnecessary and can be removed.

Contributor Author

Actually, I didn't have strides earlier. The memory accesses were wrong in that case, so I had to add the strides.

Honestly, I am not fully aware of what these different parameters of tvm.decl_buffer mean. I will look into it in more detail to make sure I understand why the presence of strides makes it work.

Member

That's bizarre. In my understanding, the strides are implicitly inferred (given the input tensor is compact), and var('ldw') is for binding the inferred stride. Actually, if you changed the innermost stride 1 to some other number, I would expect it to fail with a binding mismatch error.
@tqchen Could you help with this? I also don't fully understand the strides for buffers.

Member

#1725 The usage here is correct, thus it does not block merging this PR.
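
For context, a minimal sketch of the stride binding under discussion, using the decl_buffer arguments quoted above (top-level TVM API as of this PR; purely illustrative):

```python
import tvm

kernel = tvm.placeholder((16, 4), dtype='int8', name='kernel')
# 'ldw' is a symbolic stride that gets bound to the actual (inferred) row
# stride when the intrinsic is matched at tensorize time; the innermost
# stride must match the constant 1.
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
                           offset_factor=1,
                           strides=[tvm.var('ldw'), 1])
```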

Contributor

@ajtulloch ajtulloch left a comment

Just some minor nits.

"""
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")

def check_skylake(target):
Contributor

This doesn't really belong in a generic file like nn/conv2d.py, right? Shouldn't this be in some x86-specific directory?

target = tvm.target.current_target(allow_none=False)
for opt in target.options:
if opt == '-mcpu=skylake-avx512':
fp32_vec_len = 16
Contributor

Shouldn't this reuse the check_skylake function?

@yzhliu
Member

yzhliu commented Sep 17, 2018

@ajtulloch Could you take a look again and approve explicitly if it is good? thanks.

Member

@yzhliu yzhliu left a comment

Also @tqchen please review again.

@@ -0,0 +1,107 @@
"""Core kernel of dot product of 4 Int8 operations"""
#pylint: disable=invalid-name
import tvm
Member

Let us rename it to tensor_intrin.py to be consistent with #1707


if __name__ == "__main__":
LOGGER.info("Workload, Kernel_size, FP32_time, INT8_time, Speedup")
SPEEDUP_ARRAY = []
Member

Since this is a unit test, it needs to be written in the form of nose tests and skipped when the target is not supported. Alternatively, move it to topi/recipe for now.

with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})

def _intrin_reduce4int8_1x1(vec_size, num_elements_intel):
Member

Remove the _intrin prefix if it is already in the tensor_intrin.py file. Make it a public function, and document all the arguments and return types.

@tqchen
Member

tqchen commented Sep 17, 2018

I made some comments, mainly on documenting and making the code clear.

@tqchen
Member

tqchen commented Sep 20, 2018

related PR for CUDA #1735

@tqchen
Member

tqchen commented Sep 20, 2018

@anijain2305 please follow up to fix the recent review comments and let us bring this in

Int8 dot product by every 4 elements using AVX2 Skylake instructions

Parameters
-------------
Member

Docstring issue: the underline should be the same length as the heading. See https://docs.tvm.ai/contribute/document.html#document-python

Contributor Author

Thanks a lot for the pointer


def reduce_4int8_1x1(int32_lanes, num_elements_intel):
"""
Int8 dot product by every 4 elements using AVX2 Skylake instructions
Member

Can we give a more detailed example of the semantics here, i.e. what is the input and what is the output? The parameter naming also seems obscure to me.

@tqchen
Member

tqchen commented Sep 21, 2018

Thanks for all the changes. The only complaint I have is that the intrinsic functions' parameter naming seems confusing, and it is hard for me to tell what they do exactly. We should be cautious about how we name the API since it is going to be used by users. Maybe one way to make things clear is to document the behavior of the intrinsic using arrays and pseudo code.

Everyone is also welcome to weigh in on the API @ajtulloch @vinx13 @yizhi

@anijain2305
Contributor Author

@tqchen Thanks for helping out with clear documentation. I have thought more carefully about the API and realized that it doesn't need any parameters, as the tensor intrinsic is specific to the Skylake machine. I have added a small summary with pseudo code. Please review again and let me know if it needs more improvement.

function returns a TensorIntrin that can be used to tensorize
a schedule.

Parameters
Member

If there are no parameters, we do not need a Parameters section.

datatype. Each entry of output array output[i] is equal to dot product
of data[4] and corresponding kernel[i][4]. The pseudo code is as follows

for (int i = 0; i < 16; i++)
Member

We can embed C code in the docstring via a reStructuredText tag. See the example in https://docs.tvm.ai/contribute/document.html#document-python (look for .. code::).

Member

It is helpful to declare the pseudo code as a function, like

void intrin_name(int8 data[4], int8 kernel[16][4], int32 output[16]) {
    body of the code
}

import tvm


def reduce_4int8_common():
Member

What does "common" mean here?

Contributor Author

There are 2 different schedules for Intel x86. The first one is for 1x1 kernels and the second one is for 3x3 kernels. "Common" here means other kernel sizes.

One way to resolve this confusion is to remove "common"; the other one can keep 1x1. Thoughts?

@tqchen
Member

tqchen commented Sep 21, 2018

Thanks for the set of changes. Maybe we could put a bit more thought into the intrinsic naming. I think there are two sensible ways to do so:

  • Use the native name for the intrinsic, e.g. dp4a
  • Use the mathematical meaning of the intrinsic
    • Most intrinsic are doing matrix-vector product or dot product
      • We could use things like dot_8x1x8_int8_int8_int32

Thoughts?

@anijain2305
Contributor Author

I like the second option better as it is more accurate.
For naming, how about we put the vector length and data type together, e.g. dot_4xint8_16x4xint8_16xint32? (This is AVX512, so there are 16 vector lanes.)

@tqchen
Member

tqchen commented Sep 25, 2018

Thanks @anijain2305, @yzhliu this can be merged

@yzhliu yzhliu merged commit 72ad9a3 into apache:master Sep 25, 2018
@yzhliu
Member

yzhliu commented Sep 25, 2018

Thanks for everyone's effort!

@masahi
Member

masahi commented Oct 15, 2018

@anijain2305 what LLVM version do I need to run test_conv_int8_intel.py? I'm getting

AssertionError: llvm.x86.avx512.pmaddubs.w.512 is not an LLVM intrinsic

with LLVM 6.0.

@anijain2305
Contributor Author

This error is due to an older LLVM version; it looks like LLVM 6.0 does not support these AVX512BW intrinsics.
I am using LLVM 8.0 and it works with that.

@masahi
Member

masahi commented Oct 16, 2018

thanks, got it working with llvm trunk.

FrozenGene pushed a commit to FrozenGene/tvm that referenced this pull request Dec 27, 2018
…chines (apache#1680)

* Int8 implementation for convolution operator on Intel Skylake

* Int8 implementation for convolution operator on Intel Skylake

* PR changes

* PR changes

* PR changes

* Fixing an error

* Fixing an error

* Minor typos fix

* Minor typos fix

* Removing the broadcast16 CPP code. Using astype feature instead

* Replacing constant by variable name num_elements_intel

* Name fixes and tensorize update rule updated

* Fixing the bug about checking skylake

* Replacing bitcast with reinterpret

* Isolating INT8 and FP32 schedules to ease out future AutoTVM PR merge

* Putting check_skylake function in the x86 directory

* Added documentation and organizing files to better locations

* Tensor intrin renaming. Avoid code duplication for intrin by kernel reshaping
