-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Tensorize][runtime] Add support for AMX(Advanced Matrix Extensions) through Tensor intrinsics #13642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tensorize][runtime] Add support for AMX(Advanced Matrix Extensions) through Tensor intrinsics #13642
Changes from all commits
3230886
2312242
3e2fc4e
3f19099
c53c394
98b9a23
79d6636
b866673
48fa37e
dd1eb24
73f45ef
b921052
e749360
5718a05
581331a
2bda03e
c2e9f26
4469fd9
f763d52
1f59aff
383d0b2
9422363
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| # Optional Intel AMX support: compile the runtime tile-configuration source | ||
| # and enable Sapphire Rapids code generation when USE_AMX is ON. | ||
| if(USE_AMX) | ||
| file(GLOB AMX_RUNTIME_CONFIG src/runtime/contrib/amx/amx_config.cc) | ||
| # NOTE(review): a src/runtime/ source is appended to COMPILER_SRCS -- | ||
| # confirm it should not be collected into RUNTIME_SRCS instead. | ||
| list(APPEND COMPILER_SRCS ${AMX_RUNTIME_CONFIG}) | ||
| # NOTE(review): -march=sapphirerapids is appended to the global | ||
| # CMAKE_CXX_FLAGS, so every translation unit is built for Sapphire Rapids; | ||
| # confirm this is intended rather than a per-file compile option. | ||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=sapphirerapids") | ||
| message(STATUS "Build with Intel AMX support...") | ||
| endif() |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -14,8 +14,8 @@ | |||||
| # KIND, either express or implied. See the License for the | ||||||
| # specific language governing permissions and limitations | ||||||
| # under the License. | ||||||
| # pylint: disable=invalid-name,too-many-locals,unused-variable | ||||||
| # pylint: disable=no-value-for-parameter | ||||||
| # pylint: disable=invalid-name,too-many-locals,unused-argument | ||||||
| # pylint: disable=no-value-for-parameter,unused-variable | ||||||
| """x86 dense operators""" | ||||||
| from __future__ import absolute_import as _abs | ||||||
|
|
||||||
|
|
@@ -27,7 +27,9 @@ | |||||
| from .. import generic, tag | ||||||
| from ..utils import get_const_tuple, traverse_inline | ||||||
| from .tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake | ||||||
| from .utils import get_simd_32bit_lanes | ||||||
| from .tensor_intrin import dot_32x128x32_u8s8s32_sapphirerapids | ||||||
| from .tensor_intrin import acc_32x32_int32_sapphirerapids | ||||||
| from .utils import get_simd_32bit_lanes, target_has_vnni, target_has_amx | ||||||
|
|
||||||
|
|
||||||
| def _schedule_dense_pack_template(cfg, s, C, O): | ||||||
|
|
@@ -278,11 +280,45 @@ def _callback(op): | |||||
| return s | ||||||
|
|
||||||
|
|
||||||
| def dense_vnni_compute(cfg, X, packed_w, bias=None): | ||||||
| @autotvm.register_topi_compute("dense_int8.x86") | ||||||
| def dense_int8(cfg, data, weight, bias=None, out_dtype=None): | ||||||
| """Compute for uint8 x int8 -> int32 dense""" | ||||||
| # NOTE(review): out_dtype is normalized below but never forwarded to | ||||||
| # dense_int8_compute -- confirm whether the helper should receive it. | ||||||
| if out_dtype is None: | ||||||
| out_dtype = data.dtype | ||||||
| # weight must be pre-packed 4-D with innermost extents 16 and 4 | ||||||
| # (presumably out-dim x reduction-dim sub-tiles for the int8 | ||||||
| # dot-product intrinsics -- TODO confirm the packing layout). | ||||||
| assert len(weight.shape) == 4 | ||||||
| assert data.dtype == "uint8" and weight.dtype == "int8" | ||||||
| _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim | ||||||
| assert n_inner == 16 and k_inner == 4 | ||||||
| return dense_int8_compute(cfg, data, weight, bias) | ||||||
|
|
||||||
|
|
||||||
| @autotvm.register_topi_schedule("dense_int8.x86") | ||||||
| def schedule_dense_int8(cfg, outs): | ||||||
| """Create a schedule for dense_int8""" | ||||||
| s = te.create_schedule([x.op for x in outs]) | ||||||
| # Dispatch on the current target's CPU model string. | ||||||
| mcpu = tvm.target.Target.current().mcpu | ||||||
|
|
||||||
| def _callback(op): | ||||||
| # Prefer the AMX schedule when the target supports it; otherwise | ||||||
| # fall back to the VNNI schedule (no-op if neither is available). | ||||||
| if "dense_int8" in op.tag: | ||||||
| if target_has_amx(mcpu): | ||||||
| dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) | ||||||
| elif target_has_vnni(mcpu): | ||||||
| dense_vnni_schedule(cfg, s, op.output(0), outs[0]) | ||||||
|
|
||||||
| traverse_inline(s, outs[0].op, _callback) | ||||||
| return s | ||||||
|
|
||||||
|
|
||||||
| def dense_int8_compute(cfg, X, packed_w, bias=None): | ||||||
| """Compute for uint8 x int8 -> int32 dense""" | ||||||
| m, k = X.shape | ||||||
| n_o, _, n_i, _ = packed_w.shape | ||||||
| ak = te.reduce_axis((0, k), name="k") | ||||||
| mcpu = tvm.target.Target.current().mcpu | ||||||
| if target_has_vnni(mcpu): | ||||||
| target_attr = {"schedule_rule": "meta_schedule.x86.dense_vnni"} | ||||||
| else: | ||||||
| target_attr = None | ||||||
|
|
||||||
| C = te.compute( | ||||||
| (m, n_o * n_i), | ||||||
|
|
@@ -293,16 +329,13 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): | |||||
| ), | ||||||
| axis=ak, | ||||||
| ), | ||||||
| tag="dense_vnni", | ||||||
| attrs={"schedule_rule": "dense_vnni"}, | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't be removed, it is used here
Since this only affects MetaSchedule, you don't have to provide this value for AMX. So only when
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @masahi Is this case only use the x86 int8 compute method and inject a particular TIR scheduling? Can we just change the attribute
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but it is important that we specify that we use this compute for VNNI. If the schedule rule annotation only says "dense_int8", we don't know which intrinsic to tensorize this compute with.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@masahi Sorry, I may have caused a misunderstanding — what I mean is: can we use the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Of course the test still works, because it was already written for VNNI. The point is that the name So please revert that commit and restore and pass
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @masahi yep, got it, the schedule rule for ms dense_vnni are restored.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suggest checking the target string in the op strategy, and create separate compute for VNNI or AMX (rather than using the same function, |
||||||
| tag="dense_int8", | ||||||
| attrs=target_attr, | ||||||
| ) | ||||||
|
|
||||||
| if bias is not None: | ||||||
| C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST) | ||||||
|
|
||||||
| a_y, _ = C.op.axis | ||||||
| cfg.define_split("tile_y", a_y, num_outputs=2) | ||||||
|
|
||||||
| return C | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -317,6 +350,7 @@ def split_y(out): | |||||
| if cfg.is_fallback: | ||||||
| return s[out].split(a_y, factor=default_y_split_factor) | ||||||
|
|
||||||
| cfg.define_split("tile_y", a_y, num_outputs=2) | ||||||
| return cfg["tile_y"].apply(s, out, a_y) | ||||||
|
|
||||||
| (a_k,) = C.op.reduce_axis | ||||||
|
|
@@ -348,26 +382,111 @@ def split_y(out): | |||||
| return s, fused | ||||||
|
|
||||||
|
|
||||||
| @autotvm.register_topi_compute("dense_vnni.x86") | ||||||
| def dense_vnni(cfg, data, weight, bias=None, out_dtype=None): | ||||||
| """Compute for uint8 x int8 -> int32 dense""" | ||||||
| if out_dtype is None: | ||||||
| out_dtype = data.dtype | ||||||
| assert len(weight.shape) == 4 | ||||||
| assert data.dtype == "uint8" and weight.dtype == "int8" | ||||||
| _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim | ||||||
| assert n_inner == 16 and k_inner == 4 | ||||||
| return dense_vnni_compute(cfg, data, weight, bias) | ||||||
| def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True): | ||||||
| """Schedule dense compute using AMX TMUL instruction""" | ||||||
| # C: The output of GEMM | ||||||
| # O: The output of the fused op | ||||||
| # Both spatial axes and the reduction axis are split into 5 levels; the | ||||||
| # innermost extents (32 for x/y, 128 for k) match the tensorized | ||||||
| # 32x128x32 intrinsic used below. | ||||||
| def split_x(out): | ||||||
| default_x_split_factor1 = 32 | ||||||
| default_x_split_factor2 = 2 | ||||||
| default_x_split_factor3 = 2 | ||||||
| default_x_split_factor4 = 2 | ||||||
| a_x = s[out].op.axis[-2] | ||||||
|
|
||||||
| if cfg.is_fallback: | ||||||
| a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1) | ||||||
| a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2) | ||||||
| a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3) | ||||||
| a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4) | ||||||
| return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi] | ||||||
|
|
||||||
| # Tunable split; the filter pins the innermost x extent to 32. | ||||||
| cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: x.size[-1] == 32) | ||||||
| return cfg["tile_x"].apply(s, out, a_x) | ||||||
|
|
||||||
| def split_y(out): | ||||||
| default_y_split_factor1 = 32 | ||||||
| default_y_split_factor2 = 4 | ||||||
| default_y_split_factor3 = 4 | ||||||
| default_y_split_factor4 = 4 | ||||||
| a_y = s[out].op.axis[-1] | ||||||
|
|
||||||
| if cfg.is_fallback: | ||||||
| a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1) | ||||||
| a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2) | ||||||
| a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3) | ||||||
| a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4) | ||||||
| return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo] | ||||||
|
|
||||||
| # Tunable split; the filter pins the innermost y extent to 32. | ||||||
| cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: y.size[-1] == 32) | ||||||
| return cfg["tile_y"].apply(s, out, a_y) | ||||||
|
|
||||||
| def split_k(out, rd_axis): | ||||||
| default_k_split_factor1 = 128 | ||||||
| default_k_split_factor2 = 2 | ||||||
| default_k_split_factor3 = 2 | ||||||
| default_k_split_factor4 = 2 | ||||||
|
|
||||||
| if cfg.is_fallback: | ||||||
| a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1) | ||||||
| a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2) | ||||||
| a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3) | ||||||
| a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4) | ||||||
| return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki] | ||||||
|
|
||||||
| # Tunable split; the filter pins the innermost k extent to 128. | ||||||
| cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128) | ||||||
| return cfg["tile_k"].apply(s, out, rd_axis) | ||||||
|
|
||||||
| a_x, a_y = C.op.axis | ||||||
| (a_k,) = C.op.reduce_axis | ||||||
| # Accumulate into an "amx.tmm" cache stage (presumably modeling the AMX | ||||||
| # tile registers -- TODO confirm the storage-scope semantics). | ||||||
| CF = s.cache_write(C, "amx.tmm") | ||||||
|
|
||||||
| a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C) | ||||||
| a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C) | ||||||
| s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi) | ||||||
|
|
||||||
| # Compute the cached accumulation inside each 32x32 output tile. | ||||||
| s[CF].compute_at(s[C], a_yo) | ||||||
|
|
||||||
| (a_k_f,) = CF.op.reduce_axis | ||||||
| a_x_f, a_y_f = CF.op.axis | ||||||
|
|
||||||
| a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32) | ||||||
|
|
||||||
| a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32) | ||||||
| a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f) | ||||||
| s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f) | ||||||
|
|
||||||
| # Leading dimensions for the intrinsics are taken from the input shapes. | ||||||
| # NOTE(review): m, c and c_i are unpacked but unused. | ||||||
| (m, k) = CF.op.input_tensors[0].shape | ||||||
| (n, c, n_i, c_i) = CF.op.input_tensors[1].shape | ||||||
| n = n * n_i | ||||||
|
|
||||||
| # Tensorize: inner reduction -> 32x128x32 u8s8s32 TMUL intrinsic (LDA = k); | ||||||
| # inner output tile -> 32x32 int32 accumulate/store intrinsic (LDC = n). | ||||||
| s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k))) | ||||||
| s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n))) | ||||||
|
|
||||||
| # When the GEMM output is also the fused-op output, just fuse for parallelism; | ||||||
| # otherwise tile the fused op the same way and vectorize its innermost axis. | ||||||
| if C == O: | ||||||
| fused = s[O].fuse(a_x3, a_y3) | ||||||
| else: | ||||||
| a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O) | ||||||
| a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O) | ||||||
|
|
||||||
| s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, a_xi) | ||||||
| s[O].vectorize(a_xi) | ||||||
|
|
||||||
| fused = s[O].fuse(a_x3, a_y3) | ||||||
|
|
||||||
| if do_parallel: | ||||||
| s[O].parallel(fused) | ||||||
|
|
||||||
| return s, fused | ||||||
|
|
||||||
|
|
||||||
| @autotvm.register_topi_schedule("dense_vnni.x86") | ||||||
| def schedule_dense_vnni(cfg, outs): | ||||||
| """Create a schedule for dense_vnni""" | ||||||
| @autotvm.register_topi_schedule("dense_amx_int8.x86") | ||||||
| def schedule_dense_amx_int8(cfg, outs): | ||||||
| """Create a schedule for dense_amx_int8""" | ||||||
| s = te.create_schedule([x.op for x in outs]) | ||||||
|
|
||||||
| def _callback(op): | ||||||
| # NOTE(review): the diff view interleaves the removed dense_vnni branch | ||||||
| # with the added dense_amx_int8 branch below; only the amx branch should | ||||||
| # remain in the new file -- confirm against the rendered diff. | ||||||
| if "dense_vnni" in op.tag: | ||||||
| dense_vnni_schedule(cfg, s, op.output(0), outs[0]) | ||||||
| if "dense_amx_int8" in op.tag: | ||||||
| dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) | ||||||
|
|
||||||
| traverse_inline(s, outs[0].op, _callback) | ||||||
| return s | ||||||
Qianshui-Jiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.