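Hi everyone. I am working on the backward pass of the depthwise convolution, specifically the gradient with respect to the input. The only implementation I can think of so far needs a lot of `tvm.select` calls to work out which output positions each input pixel contributes to. Is there a way to simplify the code?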
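For reference, the selects encode the range of output positions that a given input pixel feeds. As a sketch, assuming the forward pass reads the zero-padded input at row `io * stride_h + fh - pad_h` (the convention implied by the `Filter[i + pad_h - di * stride_h, ...]` index in the code), and writing K_h = `filter_h`, s_h = `stride_h`, p_h = `pad_h`, O_h = `out_h`, M = `channel_multiplier`, the condition 0 <= i + p_h - i_o s_h <= K_h - 1 gives:

```latex
\begin{aligned}
\mathrm{Out}[b, i_o, j_o, cM + m] &= \sum_{f_h=0}^{K_h-1} \sum_{f_w=0}^{K_w-1}
  \mathrm{In}[b,\ i_o s_h + f_h - p_h,\ j_o s_w + f_w - p_w,\ c]\,\mathrm{Filter}[f_h, f_w, c, m] \\
\frac{\partial L}{\partial \mathrm{In}[b, i, j, c]} &= \sum_{i_o=\ell_h(i)}^{u_h(i)} \sum_{j_o=\ell_w(j)}^{u_w(j)} \sum_{m=0}^{M-1}
  \frac{\partial L}{\partial \mathrm{Out}[b, i_o, j_o, cM + m]}\,\mathrm{Filter}[i + p_h - i_o s_h,\ j + p_w - j_o s_w,\ c,\ m] \\
\ell_h(i) &= \max\!\Big(0,\ \Big\lfloor \frac{i + p_h - K_h}{s_h} \Big\rfloor + 1\Big), \qquad
u_h(i) = \min\!\Big(O_h - 1,\ \Big\lfloor \frac{i + p_h}{s_h} \Big\rfloor\Big)
\end{aligned}
```

The width bounds use K_w = `filter_w`, s_w = `stride_w`, p_w = `pad_w`, O_w = `out_w` in the same way. Note that both bounds are inclusive, so the last output row u_h(i) = O_h - 1 must be reachable.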
```python
def trans(b, i, j, c):
    global Out_grad_cond
    # Keep only the Out_grad entries whose receptive field covers input pixel (i, j):
    # max(0, (i + pad_h - filter_h) // stride_h + 1) <= io <= min(out_h - 1, (i + pad_h) // stride_h),
    # and likewise for jo along the width.
    Out_grad_cond = tvm.compute(
        (batch, out_h, out_w, out_c),
        lambda bo, io, jo, co: tvm.select(
            tvm.all(
                io >= tvm.select(0 < (i - filter_h + pad_h + stride_h) / stride_h,
                                 (i - filter_h + pad_h + stride_h) / stride_h, tvm.const(0)),
                io <= tvm.select(0 < (i + pad_h) / stride_h + 1 - out_h,
                                 tvm.const(out_h - 1), (i + pad_h) / stride_h),
                jo >= tvm.select(0 < (j - filter_w + pad_w + stride_w) / stride_w,
                                 (j - filter_w + pad_w + stride_w) / stride_w, tvm.const(0)),
                jo <= tvm.select(0 < (j + pad_w) / stride_w + 1 - out_w,
                                 tvm.const(out_w - 1), (j + pad_w) / stride_w)),
            Out_grad[bo, io, jo, co], tvm.const(0.0)))
    # Sum over all out_h x out_w output positions and all channel multipliers.
    di = tvm.reduce_axis((0, out_h), name='di')
    dj = tvm.reduce_axis((0, out_w), name='dj')
    dc = tvm.reduce_axis((0, channel_multiplier), name='dc')
    return tvm.sum(Out_grad_cond[b, di, dj, c * channel_multiplier + dc] *
                   Filter[i + pad_h - di * stride_h, j + pad_w - dj * stride_w, c, dc],
                   axis=[di, dj, dc])

In_grad = tvm.compute(
    (batch, in_h, in_w, in_c),
    lambda b, i, j, c: trans(b, i, j, c),
    name='In_grad')
```
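One way to avoid most of the selects, sketched here as a plain-Python reference rather than TVM code, is the standard transposed-convolution identity: the input gradient of a strided, padded convolution equals a stride-1 "valid" convolution of the output gradient after (1) dilating it by the stride (inserting `stride - 1` zeros between elements), (2) padding it by `filter - 1 - pad` on the top and left, and (3) flipping the filter spatially. The function names, the nested-list NHWC layout, and the requirement `0 <= pad <= filter - 1` are assumptions of this sketch; `in_grad_scatter` is a direct reference used only to check the result. In TVM, steps (1) and (2) would presumably become separate dilation and padding stages (TOPI has `dilate` and `pad` operators, if your version ships them), leaving a single select-free `tvm.sum`.

```python
from itertools import product


def zeros(*shape):
    """Nested list of 0.0 with the given shape (NHWC here)."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(*shape[1:]) for _ in range(shape[0])]


def in_grad_scatter(out_grad, filt, in_h, in_w, stride, pad):
    """Reference: push every Out_grad element back through the forward taps."""
    batch, out_h, out_w = len(out_grad), len(out_grad[0]), len(out_grad[0][0])
    kh, kw, in_c, mult = len(filt), len(filt[0]), len(filt[0][0]), len(filt[0][0][0])
    (sh, sw), (ph, pw) = stride, pad
    ig = zeros(batch, in_h, in_w, in_c)
    for b, io, jo, c, m, fh, fw in product(range(batch), range(out_h), range(out_w),
                                           range(in_c), range(mult), range(kh), range(kw)):
        i, j = io * sh + fh - ph, jo * sw + fw - pw  # input pixel read by this tap
        if 0 <= i < in_h and 0 <= j < in_w:  # taps landing in the padding are dropped
            ig[b][i][j][c] += out_grad[b][io][jo][c * mult + m] * filt[fh][fw][c][m]
    return ig


def in_grad_dilate_pad_flip(out_grad, filt, in_h, in_w, stride, pad):
    """Select-free gather: dilate by stride, pad, stride-1 conv with flipped filter.

    Assumes 0 <= pad <= filter - 1 in each spatial dimension.
    """
    batch, out_h, out_w = len(out_grad), len(out_grad[0]), len(out_grad[0][0])
    out_c = len(out_grad[0][0][0])
    kh, kw, in_c, mult = len(filt), len(filt[0]), len(filt[0][0]), len(filt[0][0][0])
    (sh, sw), (ph, pw) = stride, pad
    top, left = kh - 1 - ph, kw - 1 - pw
    # Steps 1 and 2: place Out_grad on a zero grid with stride-spaced rows/columns,
    # offset by (top, left); the grid is just large enough for a valid conv.
    padded = zeros(batch, in_h + kh - 1, in_w + kw - 1, out_c)
    for b, io, jo, oc in product(range(batch), range(out_h), range(out_w), range(out_c)):
        padded[b][top + io * sh][left + jo * sw][oc] = out_grad[b][io][jo][oc]
    # Step 3: plain stride-1 correlation with the spatially flipped filter.
    ig = zeros(batch, in_h, in_w, in_c)
    for b, i, j, c in product(range(batch), range(in_h), range(in_w), range(in_c)):
        ig[b][i][j][c] = sum(
            padded[b][i + dh][j + dw][c * mult + m] * filt[kh - 1 - dh][kw - 1 - dw][c][m]
            for dh in range(kh) for dw in range(kw) for m in range(mult))
    return ig
```

The gather in step 3 has fixed reduction bounds (`filter_h x filter_w x channel_multiplier`), which is what removes the per-pixel range arithmetic; the cost is extra multiplications by the zeros introduced in steps 1 and 2.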
Thank you!