include/tvm/schedule_pass.h (+10 −0)
@@ -40,6 +40,16 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
 */
void AutoInlineElemWise(Schedule sch);

/*!
 * \brief Automatically inline operations with injective writes
 * (i.e. writes without reduction or sequential loops). Note that
 * in this case, guarantees about contiguity, transpose, stride,
 * alignment and memory footprint in general do not hold.
 *
 * \param sch The schedule to be inlined.
 */
void AutoInlineInjective(Schedule sch);

} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
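For context, a minimal sketch of how the new pass slots into a C++ lowering flow, assuming the 2017-era API (tvm::create_schedule and friends); the function and tensor names are illustrative and not part of this diff:

#include <tvm/schedule_pass.h>
#include <tvm/tvm.h>

tvm::Schedule MakeSchedule(const tvm::Tensor& out) {
  tvm::Schedule sch = tvm::create_schedule({out->op});
  // Inline every unscheduled, non-output stage whose op has no reduce axis.
  tvm::schedule::AutoInlineInjective(sch);
  return sch;
}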
src/schedule/auto_inline_elem_wise.cc (+33 −0)
@@ -60,5 +60,38 @@ void AutoInlineElemWise(Schedule sch) {
  }
}

bool IsBroadcast(const Operation& op) {
  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
    if (compute->reduce_axis.size()) {
      return false;
    }
    // TODO(nicolasvasilache): Implement Me
  }
  return false;
}

void AutoInlineBroadcast(Schedule sch) {
  for (Stage s : sch->stages) {
    if (!s.is_scheduled() && IsBroadcast(s->op) && !s->is_output) {
      s.compute_inline();
    }
  }
}

bool IsInjective(const Operation& op) {
  if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) {
    return compute->reduce_axis.size() == 0;
  }
  return false;
}

void AutoInlineInjective(Schedule sch) {
  for (Stage s : sch->stages) {
    if (!s.is_scheduled() && IsInjective(s->op) && !s->is_output) {
      s.compute_inline();
    }
  }
}

} // namespace schedule
} // namespace tvm
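To make the IsInjective predicate concrete, a hedged sketch with the C++ compute API (shapes and names invented for illustration): B below is classified injective because its ComputeOpNode has no reduce_axis, while the reduction C is not, so AutoInlineInjective would inline B but leave C alone.

// Illustrative only, assuming the 2017-era tvm C++ API.
tvm::Var n("n");
tvm::Tensor A = tvm::placeholder({n}, tvm::Float(32), "A");
// Elementwise: no reduce_axis, so IsInjective(B->op) == true.
tvm::Tensor B = tvm::compute({n}, [&](tvm::Var i) { return A(i) + 1.0f; });
// Reduction: IsInjective(C->op) == false, never auto-inlined.
tvm::IterVar k = tvm::reduce_axis(tvm::Range(0, n), "k");
tvm::Tensor C = tvm::compute({1}, [&](tvm::Var i) { return tvm::sum(A(k), {k}); });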
topi/include/topi/broadcast.h (+57 −0)
@@ -0,0 +1,57 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \brief Broadcast op constructions
 * \file broadcast.h
 */
#ifndef TOPI_BROADCAST_H_
#define TOPI_BROADCAST_H_

#include <topi/detail/broadcast.h>

namespace topi {

inline tvm::Tensor broadcast_to(const tvm::Tensor& I,
                                const tvm::Array<tvm::Expr>& output_shape) {
  CHECK_GE(output_shape.size(), I->shape.size())
      << "Not a broadcast, output dimensionality smaller than input.\noutput: "
      << output_shape << "\nvs\ninput: " << I;
  auto bh = detail::BroadcastShape(output_shape, I->shape);
  CHECK_EQ(output_shape.size(), bh.common_shape.size());
  for (size_t i = 0; i < output_shape.size(); ++i) {
    CHECK(tvm::ir::Equal(output_shape[i], bh.common_shape[i]));
  }
  auto l = [&](tvm::Array<tvm::Var> ovars) {
    return I(detail::InputIndexFromBroadcast(ovars, I, bh.vars2, bh.all_vars));
  };
  return tvm::compute(
      tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()), l);
}

inline tvm::Tensor broadcast_add(const tvm::Tensor& A, const tvm::Tensor& B) {
  auto l = [&](tvm::Expr a, tvm::Expr b) { return a + b; };
  return detail::WithBroadcast(l, A, B);
}

inline tvm::Tensor broadcast_sub(const tvm::Tensor& A, const tvm::Tensor& B) {
  auto l = [&](tvm::Expr a, tvm::Expr b) { return a - b; };
  return detail::WithBroadcast(l, A, B);
}

inline tvm::Tensor broadcast_mul(const tvm::Tensor& A, const tvm::Tensor& B) {
  auto l = [&](tvm::Expr a, tvm::Expr b) { return a * b; };
  return detail::WithBroadcast(l, A, B);
}

inline tvm::Tensor broadcast_div(const tvm::Tensor& A, const tvm::Tensor& B) {
  auto l = [&](tvm::Expr a, tvm::Expr b) { return a / b; };
  return detail::WithBroadcast(l, A, B);
}

inline tvm::Tensor broadcast_mod(const tvm::Tensor& A, const tvm::Tensor& B) {
  auto l = [&](tvm::Expr a, tvm::Expr b) { return a % b; };
  return detail::WithBroadcast(l, A, B);
}

} // namespace topi

#endif // TOPI_BROADCAST_H_
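A usage sketch (shapes illustrative, not part of the diff): trailing dimensions are aligned NumPy-style, so a dimension of 1 stretches to match its counterpart.

tvm::Tensor A = tvm::placeholder({4, 1, 3}, tvm::Float(32), "A");
tvm::Tensor B = tvm::placeholder({2, 3}, tvm::Float(32), "B");
tvm::Tensor C = topi::broadcast_add(A, B);         // shape (4, 2, 3)
tvm::Tensor D = topi::broadcast_to(B, {4, 2, 3});  // explicit stretch of B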
topi/include/topi/detail/broadcast.h (+107 −0)
@@ -0,0 +1,107 @@
/*!
 * Copyright (c) 2017 by Contributors
 * \brief Detail broadcast.
 * \file broadcast.h
 */
#ifndef TOPI_DETAIL_BROADCAST_H_
#define TOPI_DETAIL_BROADCAST_H_

#include <algorithm>
#include <deque>

#include "tvm/ir_pass.h"
#include "tvm/tvm.h"

namespace topi {
namespace detail {

struct BroadcastHelper {
  std::deque<tvm::Expr> common_shape;
  std::deque<tvm::Var> all_vars;
  std::deque<tvm::Var> vars1;
  std::deque<tvm::Var> vars2;
};

inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
                                      const tvm::Array<tvm::Expr>& shape2) {
  BroadcastHelper bh;
  int s1_size = shape1.size();
  int s2_size = shape2.size();
  tvm::Expr one(1);
  int i;
  for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
    bh.all_vars.push_front(tvm::Var());
    if (tvm::ir::Equal(shape1[s1_size - i], shape2[s2_size - i])) {
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (tvm::ir::Equal(one, shape1[s1_size - i])) {
      CHECK(!tvm::ir::Equal(one, shape2[s2_size - i]));
      bh.common_shape.push_front(shape2[s2_size - i]);
      bh.vars2.push_front(bh.all_vars[0]);
    } else if (tvm::ir::Equal(one, shape2[s2_size - i])) {
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
    } else {
      CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
                   << " and " << shape2[s2_size - i] << " in: "
                   << tvm::Array<tvm::Expr>(shape1.begin(), shape1.end())
                   << " and "
                   << tvm::Array<tvm::Expr>(shape2.begin(), shape2.end());
    }
  }
  // Remaining dimensions, whether on shape1 or shape2, can always be completed.
  auto max_size = std::max(s1_size, s2_size);
  auto& shape = (s1_size > s2_size) ? shape1 : shape2;
  auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
  for (; i <= max_size; ++i) {
    bh.all_vars.push_front(tvm::Var());
    bh.common_shape.push_front(shape[max_size - i]);
    vars.push_front(bh.all_vars[0]);
  }
  return bh;
}

inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
    const tvm::Array<tvm::Var>& ovars, const tvm::Tensor& T,
    const std::deque<tvm::Var>& my_vars, const std::deque<tvm::Var>& all_vars) {
  tvm::Array<tvm::Expr> ivars;
  CHECK_EQ(ovars.size(), all_vars.size());
  // N^2, could use a map but NBD.
  size_t expected_dims = T->shape.size();
  for (size_t i = 0; i < ovars.size(); ++i) {
    bool found = false;
    for (size_t j = 0; j < my_vars.size(); ++j) {
      if (all_vars[i].same_as(my_vars[j])) {
        ivars.push_back(ovars[i]);
        found = true;
        break;
      }
    }
    // Only inject a 0 here once we are within the rank of T
    // (i.e. this dimension of T must be a broadcast 1).
    if (!found && (ovars.size() - i) <= expected_dims) {
      ivars.push_back(tvm::make_zero(ovars[i].type()));
    }
  }
  CHECK_EQ(expected_dims, ivars.size());
  return ivars;
}

template <typename FBinaryExpr>
inline tvm::Tensor WithBroadcast(FBinaryExpr op, const tvm::Tensor& A,
                                 const tvm::Tensor& B) {
  auto bh = BroadcastShape(A->shape, B->shape);
  auto l = [&](tvm::Array<tvm::Var> ovars) {
    return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
              B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars)));
  };
  return tvm::compute(
      tvm::Array<tvm::Expr>(bh.common_shape.begin(), bh.common_shape.end()), l);
}

} // namespace detail
} // namespace topi

#endif // TOPI_DETAIL_BROADCAST_H_
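A worked trace of BroadcastShape on concrete shapes, followed by WithBroadcast with a custom binary expression. Illustrative only: it assumes a tvm::max(Expr, Expr) intrinsic is in scope, which may differ by TVM version.

// BroadcastShape((4, 1, 3), (2, 3)) scans trailing dims right-to-left:
//   i = 1: 3 vs 3 -> common dim 3; the fresh var goes to vars1 and vars2
//   i = 2: 1 vs 2 -> common dim 2; the var goes only to vars2
//   i = 3: only shape1 remains -> common dim 4; the var goes only to vars1
// yielding common_shape = (4, 2, 3). InputIndexFromBroadcast then remaps the
// output vars to each input's rank, substituting 0 for broadcast-1 dims.
auto fmax = [](tvm::Expr a, tvm::Expr b) { return tvm::max(a, b); };
tvm::Tensor A = tvm::placeholder({4, 1, 3});
tvm::Tensor B = tvm::placeholder({2, 3});
tvm::Tensor C = topi::detail::WithBroadcast(fmax, A, B);  // shape (4, 2, 3)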
topi/include/topi/ewise.h (+7 −6)
@@ -1,6 +1,6 @@
/*!
 * Copyright (c) 2017 by Contributors
- * \file topi.h
+ * \file ewise.h
 * \brief Elementwise op constructions
 */
#ifndef TOPI_EWISE_H_
@@ -12,16 +12,17 @@ namespace topi {
using namespace tvm;

// Unary intrinsic operators
-#define TOPI_DECLARE_UNARY_OP(OpName)                   \
-  inline Tensor OpName(const Tensor& x) {               \
-    return compute(x->shape, [&](const Array<Var>& i) { \
-      return ::tvm::OpName(x(i));                       \
-    });                                                 \
+#define TOPI_DECLARE_UNARY_OP(OpName)                   \
+  inline Tensor OpName(const Tensor& x) {               \
+    return compute(x->shape, [&](const Array<Var>& i) { \
+      return ::tvm::OpName(x(i));                       \
+    }, "tensor", "ewise");                              \
}

TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);

} // namespace topi
#endif // TOPI_EWISE_H_
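A quick usage sketch of the generated wrappers (placeholder shape illustrative). The new "tensor" and "ewise" arguments are the compute op's name and tag; tagging the stage as elementwise lets downstream schedulers and inlining passes recognize it.

tvm::Var n("n");
tvm::Tensor x = tvm::placeholder({n}, tvm::Float(32), "x");
tvm::Tensor y = topi::exp(x);      // one compute stage, tagged "ewise"
tvm::Tensor z = topi::sigmoid(y);  // chains as a second elementwise stage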