4 changes: 4 additions & 0 deletions tests/cpp/build_module_test.cc
@@ -102,7 +102,11 @@ TEST(BuildModule, Heterogeneous) {
        return copy[i] - C[i];
      }, "elemwise_sub");
 
+  const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope");
+  (*enter_target_scope_func)(target_cuda);
   auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
+
+  (*enter_target_scope_func)(target_llvm);
   auto s2 = create_schedule({elemwise_sub->op});
 
   auto config = BuildConfig::Create();
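Note: `schedule_injective_from_existing` (added below in `topi/include/topi/cuda/injective.h`) reads the target from the ambient scope rather than from a parameter, so the test must enter a scope via the `_EnterTargetScope` packed function before scheduling. A rough Python equivalent of what this test exercises, as a sketch assuming the 0.x-era `tvm.target`/`topi` APIs:

```python
import tvm
import topi

# Hypothetical elementwise op standing in for elemwise_add above.
A = tvm.placeholder((1024,), name="A")
B = topi.add(A, 1)

# Entering the target scope is the Python analogue of calling the
# _EnterTargetScope packed function on the C++ side.
with tvm.target.create("cuda"):
    s = topi.generic.schedule_injective([B])
```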
6 changes: 4 additions & 2 deletions tests/python/unittest/test_runtime_heterogeneous.py
@@ -174,7 +174,8 @@ def check_device(device, target_device):

     dev_tar = {"cuda": "cuda", "opencl": "opencl"}
     for device, target in dev_tar.items():
-        check_device(device, target)
+        with tvm.target.create(device):
+            check_device(device, target)
 
 
 def get_duplex_graph(host_dev_type, device_dev_type):
@@ -394,7 +395,8 @@ def check_load_module():

     dev_tar = {"cuda": "cuda", "opencl": "opencl"}
     for device, target in dev_tar.items():
-        check_device(device, target)
+        with tvm.target.create(device):
+            check_device(device, target)
 
 if __name__ == "__main__":
     test_simplex_data_transferring()
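The `with tvm.target.create(device):` wrapper matters because the per-target schedules are now dispatched on whichever target is current; without a scope there is nothing to dispatch on. A minimal sketch of the scope behavior (assuming the 0.x `tvm.target` API):

```python
import tvm

print(tvm.target.current_target(allow_none=True))  # None: no scope entered
with tvm.target.create("cuda"):
    print(tvm.target.current_target())             # the cuda target
print(tvm.target.current_target(allow_none=True))  # None again after exit
```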
86 changes: 0 additions & 86 deletions topi/include/topi/cuda/extern.h

This file was deleted.

29 changes: 16 additions & 13 deletions topi/include/topi/cuda/injective.h
@@ -33,21 +33,24 @@ namespace topi {
 using namespace tvm;
 
 namespace cuda {
+
 /*!
- * \brief Schedule a given injective operation.
+ * \brief Updates an existing schedule for the given injective ops.
  *
- * \param target The target to generate a schedule for.
- * \param op The operation representing the injective operation.
- * \param s The schedule to apply this scheduling to
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
  */
-inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
-  auto x = op.output(0);
-  auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+  auto target = Target::Current(false);
   auto num_thread = target->max_num_threads;
   IterVar bx, tx;
-  s[x].split(fused, num_thread, &bx, &tx);
-  s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
-  s[x].bind(tx, thread_axis(Range(), "threadIdx.x"));
+  sch[out].split(fused, num_thread, &bx, &tx);
+  sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
+  sch[out].bind(tx, thread_axis(Range(), "threadIdx.x"));
+  return sch;
 }
 
 /*!
@@ -66,7 +69,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   auto s = create_schedule(out_ops);
   tvm::schedule::AutoInlineInjective(s);
   for (auto out : outs) {
-    ScheduleInjectiveOp(target, out->op, s);
+    schedule_injective_from_existing(s, out);
   }
   return s;
 }
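The CUDA variant fuses every axis of the injective output, splits by the current target's `max_num_threads`, and binds the pieces to CUDA block/thread indices; `Target::Current(false)` appears to require that a target scope is active, which is why the C++ test above enters one first. A Python mirror of the same pattern, as a sketch (the function name is made up; this is not the PR's Python code):

```python
import tvm

def cuda_injective_from_existing(sch, out):
    """Sketch: fuse all axes, split by max_num_threads, bind to CUDA axes."""
    fused = sch[out].fuse(*sch[out].op.axis)
    num_thread = tvm.target.current_target().max_num_threads
    bx, tx = sch[out].split(fused, factor=num_thread)
    sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch
```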
10 changes: 10 additions & 0 deletions topi/include/topi/generic/extern.h
@@ -28,6 +28,7 @@
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "injective.h"

namespace topi {
using namespace tvm;
@@ -47,6 +48,15 @@ inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
     out_ops.push_back(t->op);
   }
   auto s = create_schedule(out_ops);
+
+  tvm::schedule::AutoInlineInjective(s);
+  for (auto out : outs) {
+    if (out->op->derived_from<ExternOpNode>()) {
+      continue;
+    }
+    tvm::GenericFunc::Get("schedule_injective_from_existing")(s, out);
+  }
+
   return s;
 }
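`tvm::GenericFunc::Get` looks the function up by name and dispatches to whichever per-target override matches the target in scope at call time, so extern schedules pick up each backend's injective logic without including backend headers. The same machinery from Python, as a sketch with made-up function names:

```python
import tvm

@tvm.target.generic_func
def demo_schedule(outs):
    # Fallback used when no override matches the current target.
    return "generic"

@demo_schedule.register(["cuda"])
def demo_schedule_cuda(outs):
    return "cuda"

print(demo_schedule(None))         # "generic": no target in scope
with tvm.target.create("cuda"):
    print(demo_schedule(None))     # "cuda": dispatched on the scope
```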

15 changes: 14 additions & 1 deletion topi/include/topi/generic/injective.h
@@ -34,6 +34,19 @@ using namespace tvm;

 namespace generic {
 
+/*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
+  return sch;
+}
+
 /*!
  * \brief Create a generic schedule for the given injective ops.
  *
@@ -50,7 +63,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   auto s = create_schedule(out_ops);
   tvm::schedule::AutoInlineInjective(s);
   auto x = outs[0];
-  detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
+  schedule_injective_from_existing(s, x);
 
   return s;
 }
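The point of the `*_from_existing` split is that a caller already holding a schedule can fold an injective op into it instead of creating a fresh schedule (exactly what `conv2d_int8.py` does below for its layout transforms). A hedged usage sketch via the Python dispatcher this PR introduces (0.x APIs assumed):

```python
import tvm
import topi

A = tvm.placeholder((16, 16), name="A")
B = topi.add(A, 1)       # injective stage to fold into...
C = topi.multiply(B, 2)  # ...the schedule we already have for C
s = tvm.create_schedule(C.op)

with tvm.target.create("llvm"):
    # Updates s in place for B and returns it.
    s = topi.generic.schedule_injective_from_existing(s, B)
```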
32 changes: 23 additions & 9 deletions topi/include/topi/x86/injective.h
@@ -33,6 +33,28 @@ namespace topi {
 using namespace tvm;
 
 namespace x86 {
+
+/*!
+ * \brief Updates an existing schedule for the given injective ops.
+ *
+ * \param sch The schedule to update.
+ * \param out The tensor representing the injective op.
+ *
+ * \return The updated schedule.
+ */
+inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
+  auto axis = sch[out]->op.as<ComputeOpNode>()->axis;
+  if (axis.size() == 4) {
+    auto n = axis[0];
+    auto c = axis[1];
+    auto fused = detail::Fuse(sch[out], { n, c });  // for nhwc layout, fuse n and h
+    sch[out].parallel(fused);
+  } else {
+    sch[out].parallel(axis[0]);
+  }
+  return sch;
+}
+
 /*!
  * \brief Create an x86 schedule for the given injective ops.
  *
@@ -50,15 +72,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
   tvm::schedule::AutoInlineInjective(s);
 
   auto x = outs[0];
-  auto axis = s[x]->op.as<ComputeOpNode>()->axis;
-  if (axis.size() == 4) {
-    auto n = axis[0];
-    auto c = axis[1];
-    auto fused = detail::Fuse(s[x], { n, c });  // for nhwc layout, fuse n and h
-    s[x].parallel(fused);
-  } else {
-    s[x].parallel(axis[0]);
-  }
+  schedule_injective_from_existing(s, x);
 
   return s;
 }
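For a 4-D (e.g. NHWC) output the x86 rule fuses the two outermost loops so the parallel loop has enough iterations; note that `axis[1]` is `h` under NHWC, matching the comment, even though the local is named `c`. The same decision in Python, sketched with an assumed name:

```python
def x86_injective_from_existing(sch, out):
    """Sketch: parallelize the outer loop(s) of an injective op on x86."""
    axes = sch[out].op.axis
    if len(axes) == 4:
        fused = sch[out].fuse(axes[0], axes[1])  # NHWC: fuse n and h
        sch[out].parallel(fused)
    else:
        sch[out].parallel(axes[0])
    return sch
```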
35 changes: 27 additions & 8 deletions topi/python/topi/arm_cpu/injective.py
@@ -19,6 +19,32 @@
 import tvm
 from .. import generic
 
+@generic.schedule_injective_from_existing.register(["arm_cpu"])
+def schedule_injective_from_existing(sch, out):
+    """Schedule for injective op from existing schedule.
+
+    Parameters
+    ----------
+    sch: Schedule
+        The schedule to update.
+    out: Tensor
+        The tensor representing the injective op.
+
+    Returns
+    -------
+    sch: Schedule
+        The updated schedule.
+    """
+    if len(sch[out].op.axis) >= 4:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 3:
+        fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
+        sch[out].parallel(fused)
+    elif len(sch[out].op.axis) >= 2:
+        sch[out].parallel(sch[out].op.axis[0])
+    return sch
+
 @generic.schedule_injective.register(["arm_cpu"])
 def schedule_injective(outs):
     """ARM CPU schedule for injective op.
@@ -42,14 +68,7 @@ def schedule_injective(outs):
     (io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
     s[x].vectorize(ii)
     tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 4:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 2:
-        s[x].parallel(s[x].op.axis[0])
+    schedule_injective_from_existing(s, x)
     return s
 
 @generic.schedule_concatenate.register(["arm_cpu"])
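A quick sketch of exercising the override above (0.x API assumed; `tvm.target.arm_cpu()` builds an llvm target carrying the `arm_cpu` key, so this registration wins over the generic default):

```python
import tvm
import topi

A = tvm.placeholder((1, 8, 8, 4), name="A")
B = topi.add(A, 1)
s = tvm.create_schedule(B.op)

with tvm.target.arm_cpu():
    # Four axes, so the three outermost are fused and parallelized.
    s = topi.generic.schedule_injective_from_existing(s, B)
```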
1 change: 0 additions & 1 deletion topi/python/topi/cuda/__init__.py
@@ -13,7 +13,6 @@
 from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
 from .dense import schedule_dense
 from .pooling import schedule_pool, schedule_adaptive_pool
-from .extern import schedule_extern
 from .nn import schedule_lrn, schedule_l2_normalize
 from .batch_matmul import schedule_batch_matmul
 from .vision import *
6 changes: 3 additions & 3 deletions topi/python/topi/cuda/conv2d_int8.py
@@ -19,7 +19,7 @@
 import tvm
 from tvm import autotvm
 
-from .injective import _schedule_injective
+from .injective import schedule_injective_from_existing
 from .tensor_intrin import dp4a
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
@@ -172,8 +172,8 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output):
     if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
             packed_kernel.name == 'packed_kernel':
         # data and kernel are not pre-computed, schedule layout transform here
-        _schedule_injective(packed_data.op, s)
-        _schedule_injective(packed_kernel.op, s)
+        schedule_injective_from_existing(s, packed_data)
+        schedule_injective_from_existing(s, packed_kernel)
 
     if pad_data != packed_data:
         s[pad_data].compute_inline()
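Note the argument order flips here: the removed private helper took `(op, schedule)`, while the new public function takes `(schedule, tensor)` and returns the schedule, matching the dispatcher's signature. In sketch form, with hypothetical stand-ins for the packed data (0.x APIs assumed):

```python
import tvm
import topi
from topi.cuda.injective import schedule_injective_from_existing

data = tvm.placeholder((1, 16, 14, 14), name="data")
packed_data = topi.transpose(data, (0, 2, 3, 1))  # injective layout transform
s = tvm.create_schedule(packed_data.op)

with tvm.target.create("cuda"):
    s = schedule_injective_from_existing(s, packed_data)
```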
48 changes: 0 additions & 48 deletions topi/python/topi/cuda/extern.py

This file was deleted.
