Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions topi/python/topi/cuda/tensor_intrin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Tensor intrinsics on CUDA."""
#pylint: disable=invalid-name
import tvm


def dp4a(x_scope='local', y_scope='local', z_scope='local'):
"""
Int8 dot product reduced by every 4 elements using __dp4a

Parameters
----------
x_scope : str, optional
The storage scope of buffer for lhs
y_scope : str, optional
The storage scope of buffer for rhs
z_scope : str, optional
The storage scope of buffer for result

Returns
-------
intrin : TensorIntrin
The dp4a TensorIntrin that can be used in tensorizing schedule.
"""

n = 4 # dp4a requires operands packed by 4
x = tvm.placeholder((n,), name='x', dtype='int8')
y = tvm.placeholder((n,), name='y', dtype='int8')

k = tvm.reduce_axis((0, n), name='rc')

z = tvm.compute((1,), lambda i: tvm.sum(
x[k].astype('int32') * y[k].astype('int32'), axis=[k]))

def _intrin_func(ins, outs):
def _instr(index):
xx, yy = ins
zz = outs[0]

if index == 1:
return zz.vstore(0, 0)

ib = tvm.ir_builder.create()

vec_x = xx.vload(0, dtype='int8x4')
vec_y = yy.vload(0, dtype='int8x4')
prev_z = 0 if index == 0 else zz.vload(0)

new_z = tvm.call_pure_extern('int32', '__dp4a', vec_x, vec_y, prev_z)
ib.emit(zz.vstore(0, new_z))

return ib.get()

return _instr(0), _instr(1), _instr(2) # body, reset, update

with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
scopes = {x: x_scope, y: y_scope, z: z_scope}
binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
scope=scopes[t]) for t in [x, y, z]}

return tvm.decl_tensor_intrin(z.op, _intrin_func, binds=binds)
38 changes: 3 additions & 35 deletions topi/recipe/gemm/gemm_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,12 @@
import numpy as np
import tvm
from tvm import autotvm
from topi.cuda.tensor_intrin import dp4a

DO_TUNING = True
PRETUNED_INDEX = 75333

def intrin_dot():
n = 4 # dp4a requires operands packed by 4
x = tvm.placeholder((n,), name='x', dtype='int8')
y = tvm.placeholder((n,), name='y', dtype='int8')
k = tvm.reduce_axis((0, n), name='k')

z = tvm.compute(
(1,), lambda _: tvm.sum(
x[k].astype('int32') * y[k].astype('int32'), axis=k))

def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
ib = tvm.ir_builder.create()

dp4a = zz.vstore(0, tvm.call_pure_extern('int32', '__dp4a',
xx.vload(0, dtype='int8x4'),
yy.vload(0, dtype='int8x4'),
zz.vload(0)))
ib.emit(dp4a)

body = ib.get()
return body, zz.vstore(0, 0), body

with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor,
scope='local') for t in [x, y, z]}
return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)


dot = intrin_dot()

intrin_dp4a = dp4a('local', 'local', 'local')

@autotvm.template
def gemm_int8(n, m, l):
Expand Down Expand Up @@ -70,7 +38,7 @@ def gemm_int8(n, m, l):

ko, kt, ki = cfg['tile_k'].apply(s, CC, k)

s[CC].tensorize(ki, dot)
s[CC].tensorize(ki, intrin_dp4a)

block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
Expand Down