3 changes: 2 additions & 1 deletion python/tvm/relay/qnn/op/legalizations.py
@@ -21,6 +21,7 @@
import tvm
from tvm import relay
from .. import op as reg
from ....topi.x86.utils import target_has_sse42

#################################################
# Register the functions for different operators.
@@ -318,7 +319,7 @@ def _shift(data, zero_point, out_dtype):
def is_fast_int8_on_intel():
"""Checks whether the hardware has support for fast Int8 arithmetic operations."""
target = tvm.target.Target.current(allow_none=False)
return target.mcpu in {"skylake-avx512", "cascadelake"}
return target_has_sse42(target.mcpu)


def is_fast_int8_on_arm():
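For context, a minimal sketch of how the relaxed check behaves. This is illustration only, not code from the PR; it assumes the new helper is importable from `tvm.topi.x86.utils`, which is the module the relative import in the hunk above resolves to.

```python
import tvm
from tvm.topi.x86.utils import target_has_sse42

with tvm.target.Target("llvm -mcpu=core-avx2"):
    target = tvm.target.Target.current(allow_none=False)
    print(target_has_sse42(target.mcpu))  # True: AVX2 implies SSE4.2

# The helper also works on bare -mcpu strings:
for mcpu in ["nehalem", "skylake-avx512", "cascadelake"]:
    print(mcpu, target_has_sse42(mcpu))   # all True
print(target_has_sse42("core2"))          # False: no SSE4.2, stays on the default path
```

The practical effect is that QNN legalization keeps the fast int8 path on any x86 CPU with at least SSE4.2, instead of only `skylake-avx512` and `cascadelake`.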
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/conv2d_avx_1x1.py
@@ -26,11 +26,11 @@
from ..generic import conv2d as conv2d_generic
from ..utils import get_const_tuple, simplify
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
@@ -157,7 +157,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
kernel_vec,
conv_out,
last,
int32_lanes=16,
int32_lanes=get_simd_32bit_lanes(),
intrin=dot_16x1x16_uint8_int8_int32(),
)

6 changes: 3 additions & 3 deletions python/tvm/topi/x86/conv2d_avx_common.py
@@ -22,11 +22,11 @@
from ..generic import conv2d as conv2d_generic
from ..utils import get_const_tuple
from .tensor_intrin import dot_16x1x16_uint8_int8_int32
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
@@ -174,6 +174,6 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last):
kernel_vec,
conv_out,
last,
int32_lanes=16,
int32_lanes=get_simd_32bit_lanes(),
intrin=dot_16x1x16_uint8_int8_int32(),
)
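
In both conv2d schedules above, the hard-coded `int32_lanes=16` was only right for AVX-512; deriving it from `get_simd_32bit_lanes()` keeps the tensorized intrinsic's lane count in step with the actual register width. The relationship it relies on (my arithmetic, not code from the PR):

```python
# 32-bit lanes per x86 SIMD register = register width in bits / 32
for isa, bits in [("SSE4.2 (xmm)", 128), ("AVX2 (ymm)", 256), ("AVX-512 (zmm)", 512)]:
    print(isa, bits // 32)  # -> 4, 8, 16, the values get_simd_32bit_lanes() returns
```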
5 changes: 2 additions & 3 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -30,6 +30,7 @@
from ..utils import get_const_tuple, traverse_inline
from .. import nn
from . import conv2d_avx_1x1, conv2d_avx_common
from .utils import target_has_sse42


def _get_default_config_int8(
@@ -73,9 +74,7 @@ def is_int8_hw_support(data_dtype, kernel_dtype):

# 3) Check target
mcpu = tvm.target.Target.current().mcpu
is_target_support = False
if mcpu in ("skylake-avx512", "cascadelake"):
is_target_support = True
is_target_support = target_has_sse42(mcpu)

return is_dtype_support and is_llvm_support and is_target_support

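A hedged usage sketch of the updated helper (illustrative, not part of the diff). The first call previously returned False because `core-avx2` was not in the hard-coded set; the actual result also depends on the LLVM build satisfying the intrinsic check.

```python
import tvm
from tvm.topi.x86.conv2d_int8 import is_int8_hw_support

# is_int8_hw_support reads the current target's mcpu, so it must run in a target scope.
with tvm.target.Target("llvm -mcpu=core-avx2"):
    print(is_int8_hw_support("uint8", "int8"))   # now follows target_has_sse42
with tvm.target.Target("llvm -mcpu=cascadelake"):
    print(is_int8_hw_support("uint8", "int8"))   # unchanged: still supported
```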
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/conv3d.py
@@ -26,7 +26,7 @@
from ..nn.utils import get_pad_tuple3d, infer_pad3d
from ..nn.pad import pad
from ..utils import get_const_tuple, simplify, get_const_int
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes

Workload3D = namedtuple(
"Workload",
@@ -520,7 +520,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, out_dtype, data_layout="


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad
DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
6 changes: 3 additions & 3 deletions python/tvm/topi/x86/dense.py
@@ -26,7 +26,7 @@
from tvm.contrib import mkl
from tvm.contrib import mkldnn

from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes
from .. import generic, tag
from ..utils import traverse_inline, get_const_tuple

@@ -107,7 +107,7 @@ def _default_dense_pack_config(cfg, M, N, K):
if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
K = 16

vec_width = get_fp32_len()
vec_width = get_simd_32bit_lanes()
tilex_ii = 1
for bn in range(vec_width * 2, 0, -1):
if N % bn == 0:
@@ -145,7 +145,7 @@ def _default_dense_nopack_config(cfg, M, N, K):
if isinstance(K, (tvm.tir.Var, tvm.tir.Any)):
K = 16

vec_width = get_fp32_len()
vec_width = get_simd_32bit_lanes()
tilek_bn = 1
for bn in range(vec_width * 2, 0, -1):
if K % bn == 0:
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/depthwise_conv2d.py
@@ -27,7 +27,7 @@
from ..nn.depthwise_conv2d import _get_workload, depthwise_conv2d_infer_layout
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..utils import traverse_inline
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def _fallback_schedule(cfg, wkl):
@@ -40,7 +40,7 @@ def _fallback_schedule(cfg, wkl):
wkl : topi.nn.depthwise_conv2d.Workload
Convolution workload
"""
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()

pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/group_conv2d.py
@@ -23,7 +23,7 @@
from tvm import te
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity

from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes
from ..utils import get_const_tuple
from ..nn.pad import pad
from .. import tag
@@ -62,7 +62,7 @@ def _get_default_config(


def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
pad_left, pad_right = wkl.padl, wkl.padr
stride_w = wkl.stride_w
out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/sparse.py
@@ -19,15 +19,15 @@
from tvm import te

from ..utils import traverse_inline, get_const_int
from .utils import get_fp32_len
from .utils import get_simd_32bit_lanes


def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
simd_width = get_fp32_len()
simd_width = get_simd_32bit_lanes()
if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm":
(y_o, y_i) = s[op].split(s[op].op.axis[1], 2)
fused = s[op].fuse(s[op].op.axis[0], y_o)
60 changes: 41 additions & 19 deletions python/tvm/topi/x86/tensor_intrin.py
@@ -19,20 +19,19 @@
import tvm
from tvm import te
import tvm.target.codegen
from .utils import target_has_sse42, target_has_vnni, get_simd_32bit_lanes


def dot_16x1x16_uint8_int8_int32():
"""Dispatch the most optimized intrin depending on the target"""
mcpu = tvm.target.Target.current().mcpu

assert mcpu in (
"skylake-avx512",
"cascadelake",
), "An old Intel machine that does not have fast Int8 support."
if mcpu == "skylake-avx512":
return dot_16x1x16_uint8_int8_int32_skylake()
# cascadelake
return dot_16x1x16_uint8_int8_int32_cascadelake()
assert target_has_sse42(mcpu), "An old Intel machine that does not have fast Int8 support."
if target_has_vnni(mcpu):
# VNNI capable platform
return dot_16x1x16_uint8_int8_int32_cascadelake()
# vpmaddubsw/vpmaddwd fallback
return dot_16x1x16_uint8_int8_int32_skylake()


def dot_16x1x16_uint8_int8_int32_skylake():
@@ -64,7 +63,7 @@ def dot_16x1x16_uint8_int8_int32_skylake():
The Skylake int8 TensorIntrin that can be used in tensorizing schedule
"""

int32_lanes = 16 # 16 int32 lanes in AVX512
int32_lanes = get_simd_32bit_lanes()
num_int8_elements = 4 # 4 int8 elements in int32
data = te.placeholder((num_int8_elements,), dtype="uint8", name="data")
kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel")
@@ -84,35 +83,58 @@

def _intrin_func(ins, outs):
def _instr(index):
# int_lx32 - output datatype after pmaddubs - 16 bits to number of lanes
# int_8xl - input datatype to pmaddubs - 8 bits to number of lanes
# int_32xl - output datatype after pmaddw - 32 bits per number of lanes

if int32_lanes == 4:
int_lx32 = "int16x8"
int_8xl = "int8x16"
int_32xl = "int32x4"
pmaddubs = "llvm.x86.ssse3.pmadd.ub.sw.128"
pmaddw = "llvm.x86.sse2.pmadd.wd"
elif int32_lanes == 8:
int_lx32 = "int16x16"
int_8xl = "int8x32"
int_32xl = "int32x8"
pmaddubs = "llvm.x86.avx2.pmadd.ub.sw"
pmaddw = "llvm.x86.avx2.pmadd.wd"
elif int32_lanes == 16:
int_lx32 = "int16x32"
int_8xl = "int8x64"
int_32xl = "int32x16"
pmaddubs = "llvm.x86.avx512.pmaddubs.w.512"
pmaddw = "llvm.x86.avx512.pmaddw.d.512"

ib = tvm.tir.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.tir.const(0, "int32x16")))
ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
return ib.get()

a_int8 = ins[0].vload([0], "uint8x4")
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
vec_ai32 = re_int32.astype("int32x16")
vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32)
vec_b = ins[1].vload([0, 0], "int8x64")
vec_one = tvm.tir.const(1, "int16x32")
vec_ai32 = re_int32.astype(int_32xl)
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
vec_b = ins[1].vload([0, 0], int_8xl)
vec_one = tvm.tir.const(1, int_lx32)
pair_reduction = tvm.tir.call_llvm_pure_intrin(
"int16x32",
"llvm.x86.avx512.pmaddubs.w.512",
int_lx32,
pmaddubs,
tvm.tir.const(0, "uint32"),
vec_a,
vec_b,
)
quad_reduction = tvm.tir.call_llvm_pure_intrin(
"int32x16",
"llvm.x86.avx512.pmaddw.d.512",
int_32xl,
pmaddw,
tvm.tir.const(0, "uint32"),
pair_reduction,
vec_one,
)
if index == 0:
ib.emit(outs[0].vstore(0, quad_reduction))
else:
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16")))
ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], int_32xl)))
return ib.get()

# body, reset, update
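For readers unfamiliar with the two-instruction reduction behind these intrinsics, here is a scalar model (my own illustration, not code from the PR) of one output lane of the 4-element uint8 × int8 dot product: `pmaddubs*` multiplies unsigned bytes by signed bytes and adds adjacent pairs into 16-bit intermediates (with saturation, ignored below), then `pmaddw*` against a vector of ones adds those pairs into a 32-bit lane. The same pattern holds at 4, 8, or 16 lanes; only the vector widths and intrinsic names selected above change.

```python
def dot4_u8s8(a_u8, b_s8):
    """Scalar model of pmaddubsw followed by pmaddwd(..., ones) for one int32 lane.
    a_u8: four values in 0..255, b_s8: four values in -128..127."""
    # pmaddubsw: pairwise multiply-add into 16-bit intermediates
    pair0 = a_u8[0] * b_s8[0] + a_u8[1] * b_s8[1]
    pair1 = a_u8[2] * b_s8[2] + a_u8[3] * b_s8[3]
    # pmaddwd with a vector of ones: add the two pairs into one 32-bit result
    return pair0 * 1 + pair1 * 1

print(dot4_u8s8([1, 2, 3, 4], [5, -6, 7, -8]))  # 1*5 + 2*(-6) + 3*7 + 4*(-8) = -18
```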
92 changes: 89 additions & 3 deletions python/tvm/topi/x86/utils.py
@@ -18,9 +18,95 @@
import tvm


def get_fp32_len():
def target_has_sse42(target):
return (
target_has_avx(target)
or target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"silvermont",
"slm",
"goldmont",
"goldmont-plus",
"tremont",
"nehalem",
"corei7",
"westmere",
"bdver1",
"bdver2",
"bdver3",
"x86-64-v2",
}
)


def target_has_avx(target):
return (
target_has_avx2(target)
or target_has_avx512(target)
or target_has_vnni(target)
or target in {"sandybridge", "corei7-avx", "ivybridge", "core-avx-i"}
)


def target_has_avx2(target):
return (
target_has_avx512(target)
or target_has_vnni(target)
or target
in {
"haswell",
"core-avx2",
"broadwell",
"skylake",
"bdver4",
"znver1",
"znver2",
"znver3",
"x86-64-v3",
}
)


def target_has_avx512(target):
return target in {
"skylake-avx512",
"skx",
"knl",
"knm",
"x86-64-v4",
"cannonlake",
# explicit enumeration of VNNI capable due to collision with alderlake
"cascadelake",
"icelake-client",
"rocketlake",
"icelake",
"tigerlake",
"cooperlake",
"sapphirerapids",
}


def target_has_vnni(target):
return target in {
"cascadelake",
"icelake-client",
"rocketlake",
"icelake",
"tigerlake",
"cooperlake",
"sapphirerapids",
"alderlake",
}


def get_simd_32bit_lanes():
mcpu = tvm.target.Target.current().mcpu
fp32_vec_len = 8
if mcpu in ("skylake-avx512", "cascadelake"):
fp32_vec_len = 4
if target_has_avx512(mcpu):
fp32_vec_len = 16
elif target_has_avx2(mcpu):
fp32_vec_len = 8
return fp32_vec_len
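
A quick sketch of what the new helpers report for a few representative `-mcpu` values (illustrative only; the expected lane counts follow from the `get_simd_32bit_lanes` logic above):

```python
import tvm
from tvm.topi.x86.utils import get_simd_32bit_lanes, target_has_avx2, target_has_avx512

for mcpu, expected_lanes in [("nehalem", 4), ("core-avx2", 8), ("skylake-avx512", 16)]:
    with tvm.target.Target(f"llvm -mcpu={mcpu}"):
        lanes = get_simd_32bit_lanes()
        print(mcpu, lanes, target_has_avx2(mcpu), target_has_avx512(mcpu))
        assert lanes == expected_lanes
```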