11 changes: 7 additions & 4 deletions topi/python/topi/cuda/conv2d_nchw.py
@@ -36,7 +36,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], ic)
s[Filter_S].compute_at(s[Out_L], w)

num_thread1 = 512
num_thread1 = tvm.target.current_target(allow_none=False).max_num_threads
thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

@@ -116,7 +116,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)

num_thread = 512
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

@@ -186,15 +186,18 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag):
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
max_num_thread = tvm.target.current_target(allow_none=False).max_num_threads
if util.get_const_int(Filter_S.shape[1]) == 128:
oic = s[Out_L].fuse(dh, oic)
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)
num_thread = 512
num_thread = max_num_thread
else:
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], dw)
num_thread = 456
if max_num_thread < num_thread:
num_thread = max_num_thread

thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")
@@ -300,7 +303,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L):
s[temp_S].compute_at(s[Out_L], oic)
s[Filter_S].compute_at(s[Out_L], oic)

num_thread = 512
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x")
block_xx = tvm.thread_axis("blockIdx.x")

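The pattern repeated in each hunk above replaces a hard-coded 512-thread launch with a query against the active target. A minimal sketch of how that query behaves, assuming the TVM 0.x API used in this PR (tvm.target.create, tvm.target.current_target, max_num_threads):

import tvm

# Entering a target scope makes the target visible to schedules.
with tvm.target.create("cuda"):
    # allow_none=False raises if no target is in scope, so a schedule can no
    # longer silently fall back to a guessed thread count.
    max_threads = tvm.target.current_target(allow_none=False).max_num_threads
    print(max_threads)  # the target's declared thread limit; value depends on the target definition

# Outside any target scope the same call raises instead of returning None.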
9 changes: 7 additions & 2 deletions topi/python/topi/cuda/depthwise_conv2d.py
@@ -119,6 +119,7 @@ def traverse(OP):
traverse(outs[0].op)
return s

@generic.schedule_depthwise_conv2d_nhwc.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise_conv2d nhwc forward.

@@ -151,8 +152,12 @@ def _schedule(temp, Filter, DepthwiseConv2d):

b, h, w, c = s[Output].op.axis

ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value
xoc, xic = s[Output].split(c, factor=ic_val)
# num_thread here could be 728; it is larger than cuda.max_num_threads
num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
target = tvm.target.current_target()
if target and target.target_name != "cuda":
num_thread = target.max_num_threads
xoc, xic = s[Output].split(c, factor=num_thread)
s[Output].reorder(xoc, b, h, w, xic)
xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2)
fused = s[Output].fuse(yo, xo)
4 changes: 1 addition & 3 deletions topi/python/topi/cuda/injective.py
@@ -6,9 +6,7 @@
def _schedule_injective(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
target = tvm.target.current_target()
target = target if target else tvm.target.cuda()
num_thread = target.max_num_threads
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/pooling.py
@@ -84,7 +84,7 @@ def schedule_pool(outs):
s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Pool):
s[PaddedInput].compute_inline()
num_thread = 512
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
if Pool.op in s.outputs:
Out = Pool
OL = s.cache_write(Pool, "local")
7 changes: 6 additions & 1 deletion topi/python/topi/cuda/reduction.py
@@ -16,12 +16,17 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
if len(sch[data_out].op.axis) > 0:
all_reduce = False
num_thread = 32
target = tvm.target.current_target()
if target and target.target_name == "opencl":
# without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
# don't know why
num_thread = 16
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
else:
all_reduce = True
num_thread = 512
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")

# Fuse and refactor the reduce axis
17 changes: 17 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -53,6 +53,23 @@ def schedule_depthwise_conv2d_nchw(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise nhcw conv2
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d in the format
of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_reduce(outs):
"""Schedule for reduction
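The new generic entry point above works together with the .register(["cuda", "gpu"]) decorator added in topi/python/topi/cuda/depthwise_conv2d.py: the generic_func falls back to _default_schedule until a target-specific implementation is registered and a matching target is in scope. A rough, self-contained illustration of that dispatch, using made-up names (my_schedule, _gpu_schedule) rather than the real TOPI ones:

import tvm

@tvm.target.generic_func
def my_schedule(outs):
    # fallback used when no target-specific version matches
    return "default"

@my_schedule.register(["cuda", "gpu"])
def _gpu_schedule(outs):
    return "gpu"

print(my_schedule(None))           # "default": no target in scope
with tvm.target.create("cuda"):
    print(my_schedule(None))       # "gpu": dispatched on the current target's keys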
9 changes: 5 additions & 4 deletions topi/tests/python/test_topi_depthwise_conv2d.py
@@ -32,7 +32,6 @@ def check_device(device):
s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)

ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
@@ -107,14 +106,16 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
# schedule
s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = schedule_depthwise_conv2d_nhwc(Relu)

def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return

with tvm.target.create(device):
s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
ctx = tvm.context(device, 0)
# build the kernels
f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
3 changes: 2 additions & 1 deletion topi/tests/python/test_topi_relu.py
@@ -8,7 +8,6 @@
def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.relu(A)
s = topi.cuda.schedule_elemwise(B)

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0)
@@ -17,6 +16,8 @@ def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
s = topi.generic.schedule_elemwise(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
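The test-side edits above follow from the schedule changes: once the CUDA schedules call tvm.target.current_target(allow_none=False), building a schedule outside a target scope raises instead of quietly assuming CUDA. A small sketch of the before/after pattern, using a stand-alone relu workload rather than the exact test fixtures:

import tvm
import topi

A = tvm.placeholder((32, 64), name='A')
B = topi.nn.relu(A)

# Old pattern: a CUDA-specific schedule built with no target in scope.
# After this PR it fails inside the schedule when the target is queried.
# s = topi.cuda.schedule_elemwise(B)

# New pattern: pick the schedule through the generic dispatcher, inside the
# target scope, so both dispatch and the max_num_threads lookup see the target.
with tvm.target.create("cuda"):
    s = topi.generic.schedule_elemwise(B)
f = tvm.build(s, [A, B], "cuda")  # needs a CUDA-enabled TVM build to compile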