Merged
5 changes: 4 additions & 1 deletion topi/python/topi/arm_cpu/bitserial_conv2d.py
@@ -327,14 +327,16 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'spatial_bitserial_conv_nhwc' in op.tag:
@@ -360,6 +362,7 @@ def traverse(op):

_schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
kernel, kernel_q, kernel_vec, conv_out, output, outs[0])
scheduled_ops.append(op)

traverse(outs[0].op)
return s
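
The hunks in this PR all apply the same pattern: the schedule function keeps a scheduled_ops list, skips producers that were already visited, and records each op after it has been handled, so an op shared by several consumers is scheduled only once. Below is a minimal, TVM-free sketch of why that matters, under illustrative names (the Op class, the graph, and the printed output are made up for this example, not taken from the PR):

# Minimal illustration (not TVM code): in a diamond-shaped graph an op is
# reachable through several paths, so a naive recursive traversal visits --
# and would schedule -- it once per path. The scheduled_ops guard avoids that.
class Op:
    def __init__(self, name, inputs=()):
        self.name = name
        self.input_ops = list(inputs)

conv = Op("conv2d")
relu = Op("relu", [conv])
sigm = Op("sigmoid", [conv])
out = Op("add", [relu, sigm])          # conv2d is shared by relu and sigmoid

def traverse_naive(op, visits):
    visits.append(op.name)
    for parent in op.input_ops:
        traverse_naive(parent, visits)

def traverse_guarded(op, visits, scheduled_ops):
    visits.append(op.name)
    for parent in op.input_ops:
        if parent not in scheduled_ops:     # the guard added in this PR
            traverse_guarded(parent, visits, scheduled_ops)
    scheduled_ops.append(op)                # record the op once it is handled

naive, guarded = [], []
traverse_naive(out, naive)
traverse_guarded(out, guarded, [])
print(naive)    # ['add', 'relu', 'conv2d', 'sigmoid', 'conv2d'] -- conv2d visited twice
print(guarded)  # ['add', 'relu', 'conv2d', 'sigmoid']           -- conv2d visited once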
5 changes: 4 additions & 1 deletion topi/python/topi/arm_cpu/conv2d.py
@@ -39,10 +39,11 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
def schedule_conv2d_nchw_arm_cpu(cfg, outs):
"""TOPI schedule callback"""
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
Member:

You should handle this logic in the utility function traverse_inline.
Currently the traverse/inline logic is duplicated as redundant code across TOPI. I recommend that developers switch to this utility function.

@masahi (Member Author), Aug 6, 2018:

Yeah, it was a pain to add all that scheduled_ops stuff to all the backends.

But it is not a bug, right?

I was not in the mood to refactor all the traverse logic to use traverse_inline.

@masahi (Member Author):

When I looked at your traverse_inline function, I thought I would need to pass scheduled_ops around to it. I didn't want to make that change, so I chose a more straightforward approach.

Member:

Here is my local version. Possibly you can update it for arm_cpu only in your NNVM fusion PR.

def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Parameters
    ----------
    s: schedule
        The schedule
    final_op: Operation
        The final output operator.
    callback: callable
        The callback function on each op
    """
    visited = set()

    def _traverse(op):
        if op in visited:
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    traverse_inline(s, tensor.op, callback)
        callback(op)

    _traverse(final_op)

@masahi (Member Author), Aug 6, 2018:

It should be

for tensor in op.input_tensors:
    if tensor.op.input_tensors:
        _traverse(tensor.op)

no?

@masahi (Member Author):

I updated #1548 according to your comment.

Member:

Aha, yes..
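
For reference, here is the helper with the recursion fix the thread converges on, i.e. recursing through the inner _traverse instead of calling the outer function again. This is a sketch assembled from the reviewer's snippet and the correction above, not necessarily the exact code that landed in #1548; the import of tag is an assumption mirroring how the files in this diff use TOPI's tag module.

from topi import tag   # assumed import; the TOPI sources use a relative "from .. import tag"

def traverse_inline(s, final_op, callback):
    """Traverse computation graph and do auto inline

    Parameters
    ----------
    s: schedule
        The schedule
    final_op: Operation
        The final output operator.
    callback: callable
        The callback function on each op
    """
    visited = set()

    def _traverse(op):
        if op in visited:          # each op is processed at most once
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    _traverse(tensor.op)   # fixed: recurse via _traverse
        callback(op)

    _traverse(final_op)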


def _callback(op):
# schedule conv2d
if 'spatial_conv_output' in op.tag:
if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
output = op.output(0)
conv = op.input_tensors[0]

@@ -64,6 +65,8 @@ def _callback(op):
output = op.output(0)
_schedule_winograd(cfg, s, output, outs[0])

scheduled_ops.append(op)

traverse_inline(s, outs[0].op, _callback)
return s

6 changes: 5 additions & 1 deletion topi/python/topi/arm_cpu/depthwise_conv2d.py
@@ -79,8 +79,10 @@ def _schedule(cfg, s, data, data_pad, kernel, output):

return s

scheduled_ops = []

def _callback(op):
if op.tag == 'depthwise_conv2d_nchw':
if op.tag == 'depthwise_conv2d_nchw' and op not in scheduled_ops:
output = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
@@ -90,5 +92,7 @@ def _callback(op):
data = data_pad.op.input_tensors[0]
_schedule(cfg, s, data, data_pad, kernel, output)

scheduled_ops.append(op)

traverse_inline(s, outs[0].op, _callback)
return s
6 changes: 5 additions & 1 deletion topi/python/topi/cuda/conv2d_hwcn.py
@@ -99,13 +99,15 @@ def schedule(Apad, W, B):
sch[WW].bind(tx, thread_x)
sch[WW].vectorize(fi)

scheduled_ops = []

def traverse(operator):
"""Traverse operators from computation graph"""
if tag.is_broadcast(operator.tag):
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0]
@@ -117,5 +119,7 @@ def traverse(operator):
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)

scheduled_ops.append(operator)

traverse(outs[0].op)
return sch
6 changes: 5 additions & 1 deletion topi/python/topi/cuda/conv2d_nchw.py
@@ -492,14 +492,16 @@ def schedule(temp, Filter, Output):
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)

scheduled_ops = []

def traverse(OP):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d
if 'conv2d_nchw' in OP.tag:
@@ -510,6 +512,8 @@ def traverse(OP):
Output = OP.output(0)
schedule(temp, Filter, Output)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s

6 changes: 5 additions & 1 deletion topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -73,14 +73,16 @@ def schedule(temp, Filter, Output):
else:
conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d_transpose_nchw
if 'conv2d_transpose_nchw' in OP.tag:
@@ -91,6 +93,8 @@ def traverse(OP):
Output = OP.output(0)
schedule(temp, Filter, Output)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s

6 changes: 5 additions & 1 deletion topi/python/topi/cuda/dense.py
@@ -86,14 +86,16 @@ def _schedule(Dense):
s[Dense].set_store_predicate(thread_x.var.equal(0))
s[Out].set_store_predicate(thread_x.var.equal(0))

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
@@ -102,5 +104,7 @@ def traverse(OP):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s
12 changes: 10 additions & 2 deletions topi/python/topi/cuda/depthwise_conv2d.py
@@ -101,14 +101,16 @@ def _schedule(PaddedInput, Filter, DepthwiseConv2d):
s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x)

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nchw':
@@ -119,6 +121,8 @@ def traverse(OP):
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s

@@ -180,14 +184,16 @@ def _schedule(temp, Filter, DepthwiseConv2d):
fused = s[FS].fuse(fi, ci)
s[FS].bind(fused, thread_x)

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
@@ -198,6 +204,8 @@ def traverse(OP):
DepthwiseConv2d = OP.output(0)
_schedule(PaddedInput, Filter, DepthwiseConv2d)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s

12 changes: 10 additions & 2 deletions topi/python/topi/cuda/pooling.py
@@ -45,14 +45,16 @@ def _schedule(Pool):
else:
s[Pool].compute_at(s[Out], tx)

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
@@ -61,6 +63,8 @@ def traverse(OP):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s

@@ -101,14 +105,16 @@ def _schedule(PaddedInput, Pool):
else:
s[Pool].compute_at(s[Out], tx)

scheduled_ops = []

def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
@@ -118,5 +124,7 @@ def traverse(OP):
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)

scheduled_ops.append(OP)

traverse(outs[0].op)
return s
17 changes: 13 additions & 4 deletions topi/python/topi/cuda/reduction.py
@@ -88,6 +88,7 @@ def schedule_reduce(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
sch = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse_before_reduce(operator):
"""Internal travserse function"""
@@ -96,24 +97,32 @@ def traverse_before_reduce(operator):
elif tag.is_injective(operator.tag):
sch[operator].compute_inline()
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)

scheduled_ops.append(operator)

def traverse_after_reduce(operator):
"""Internal travserse function"""
if tag.is_broadcast(operator.tag):
raise RuntimeError("Not yet support ewise after reduce")
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
elif operator.tag == 'comm_reduce_idx':
_schedule_reduce(operator, sch, is_idx_reduce=True)
for tensor in operator.input_tensors[0].op.input_tensors:
traverse_before_reduce(tensor.op)
input_tensors = operator.input_tensors[0].op.input_tensors
for tensor in input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)

scheduled_ops.append(operator)

traverse_after_reduce(outs[0].op)
return sch
6 changes: 5 additions & 1 deletion topi/python/topi/cuda/vision.py
@@ -11,6 +11,8 @@ def _default_schedule(outs):
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if "nms" in op.tag:
@@ -32,9 +34,11 @@ def traverse(op):
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)

scheduled_ops.append(op)

traverse(outs[0].op)
return s

10 changes: 8 additions & 2 deletions topi/python/topi/intel_graphics/conv2d.py
@@ -113,19 +113,22 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag:
_schedule_cl_spatialpack_NCHWc(s, op)

scheduled_ops.append(op)

traverse(outs[0].op)

return s
@@ -360,19 +363,22 @@ def schedule_conv2d_nchw(outs):
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []

def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
or "1_16" in op.tag:
_schedule_cl_spatialpack(s, op)

scheduled_ops.append(op)

traverse(outs[0].op)
return s
