Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ def get_configs(self, target: Target) -> Config:
inner_x=False,
)
elif target.kind.name == "opencl" and (
("android" in str(target.host)) or ("windows" in str(target.host))
("android" in str(target.host)) or ("adreno" in str(target.attrs))
):
return Matmul.Config(
block_size_x=32,
Expand Down Expand Up @@ -991,7 +991,10 @@ def is_inner_reduction(block_stmt, iter_infos):
end_it = block_stmt.reads[-1].region[-1].min
return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R"

if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos):
if (
target.kind.name == "opencl"
and (("android" in str(target.host)) or ("adreno" in str(target.attrs)))
) and not is_inner_reduction(block_stmt, iter_infos):
ret = self.sch_outer_reduction(sch, config, main_block, blocks)
if ret is not None:
return ret
Expand Down Expand Up @@ -1122,6 +1125,16 @@ def sch_outer_reduction(
reduction_block: tir.schedule.BlockRV,
blocks: List[tir.schedule.BlockRV],
) -> Optional[tir.Schedule]:

"""Get vectorization factor"""

def get_max_factor(n, factors):
factors = sorted(factors, reverse=True)
for factor in factors:
if n % factor == 0:
return factor
return 1

reduction_loops = sch.get_loops(reduction_block)
if not len(reduction_loops) == 4:
return None
Expand All @@ -1140,13 +1153,17 @@ def sch_outer_reduction(
config.vector_size,
config.unroll,
)

is_dequant_block = len(blocks) > 1
if is_dequant_block:
compute_block, dequant_block, matmul_block = blocks
sch.compute_inline(compute_block)
else:
(matmul_block,) = blocks
VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
dequant_block = None
matmul_block = reduction_block
epilogue_block = None
if blocks[-1] is not matmul_block:
epilogue_block = blocks[-1]
for blk in blocks[:-1]:
if "dequantize" in sch.get(blk).name_hint:
dequant_block = blk
elif blk is not matmul_block:
sch.compute_inline(blk)

m = sch.fuse(mb, ms)

Expand All @@ -1162,20 +1179,21 @@ def sch_outer_reduction(
sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)

sch.compute_at(rmat_block, k0)
if is_dequant_block:
if dequant_block is not None:
sch.compute_at(dequant_block, k3)
sch.reverse_compute_at(wmat_block, mi)
sch.set_scope(rmat_block, 0, "shared")
sch.set_scope(matmul_block, 0, "local")
if is_dequant_block:

if dequant_block is not None:
sch.set_scope(dequant_block, 0, "local")

sch.bind(mo, "blockIdx.y")
sch.bind(no, "blockIdx.x")
sch.bind(mi, "threadIdx.y")
sch.bind(ni, "threadIdx.x")
sch.vectorize(sch.get_loops(matmul_block)[-1])
if is_dequant_block:
if dequant_block is not None:
sch.vectorize(sch.get_loops(dequant_block)[-1])

# Co-operative Memory Fetch
Expand All @@ -1187,7 +1205,7 @@ def sch_outer_reduction(
sch.vectorize(wv)

# Scale and Quant Cache
if is_dequant_block:
if dequant_block is not None:
qb = sch.cache_read(dequant_block, 0, "local")
sb = sch.cache_read(dequant_block, 1, "local")
sch.compute_at(sb, k1)
Expand All @@ -1197,5 +1215,11 @@ def sch_outer_reduction(
sch.vectorize(sch.get_loops(qb)[-1])
sch.vectorize(sch.get_loops(sb)[-1])

if epilogue_block is not None:
sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
sch.set_scope(wmat_block, 0, "local")
sch.compute_inline(wmat_block)
sch.vectorize(sch.get_loops(epilogue_block)[-1])

sch.decompose_reduction(matmul_block, k0)
return sch
89 changes: 48 additions & 41 deletions tests/python/dlight/test_gpu_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,47 +685,54 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
# fmt: off
@T.prim_func
def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
seq_len = T.int64()
rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
# with T.block("root"):
compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16")
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(lv840[v_i0 // T.int64(8), v_i1])
T.reads(lv452[v_i0 // T.int64(8), v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
with T.block("dequantize"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)):
with T.block("matmul"):
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
with T.init():
matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
with T.block("T_add"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]

@T.prim_func
def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
seq_len = T.int64()
rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
# with T.block("root"):
dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local")
lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
Expand All @@ -743,57 +750,57 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T
for ax0 in range(T.int64(4)):
for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
for ax1_1 in T.vectorized(T.int64(8)):
with T.block("rms_norm260_pad"):
with T.block("rms_norm130_pad"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
T.reads(rms_norm260[v0, v1, v2])
T.writes(rms_norm260_pad_shared[v0, v1, v2])
rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
T.reads(rms_norm130[v0, v1, v2])
T.writes(rms_norm130_pad_shared[v0, v1, v2])
rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
for k_1 in range(T.int64(8)):
for ax0 in T.vectorized(T.int64(8)):
with T.block("lv841_local"):
with T.block("lv453_local"):
v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1)
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
T.reads(lv841[v0, v1])
T.writes(lv841_local[v0, v1])
lv841_local[v0, v1] = lv841[v0, v1]
T.reads(lv453[v0, v1])
T.writes(lv453_local[v0, v1])
lv453_local[v0, v1] = lv453[v0, v1]
for k_2 in range(T.int64(4)):
for ax0 in T.vectorized(T.int64(8)):
with T.block("lv840_local"):
with T.block("lv452_local"):
v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
T.reads(lv840[v0, v1])
T.writes(lv840_local[v0, v1])
lv840_local[v0, v1] = lv840[v0, v1]
T.reads(lv452[v0, v1])
T.writes(lv452_local[v0, v1])
lv452_local[v0, v1] = lv452[v0, v1]
for k_3 in range(T.int64(8)):
for ax0 in T.vectorized(T.int64(8)):
with T.block("dequantize"):
v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
for i0_i1_fused_2 in range(T.int64(4)):
for i2_2 in T.vectorized(T.int64(8)):
with T.block("matmul_update"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
for ax0 in range(T.int64(4)):
for ax1 in T.vectorized(T.int64(8)):
with T.block("matmul_intermediate_pad"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
T.reads(matmul_intermediate_pad_local[v0, v1, v2])
T.writes(matmul_intermediate[v0, v1, v2])
matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2]
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
for ax2 in T.vectorized(T.int64(8)):
with T.block("T_add"):
v_ax0 = T.axis.spatial(T.int64(1), ax0)
v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1)
v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2)
T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len)
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
# fmt: on


Expand Down