Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommRedude";
case kCommReduce: return "CommReduce";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
Expand Down
30 changes: 30 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ struct Reduce : public ExprNode<Reduce> {
static Expr make(std::string op, Expr src,
Array<IterVar> rdom,
Expr condition = const_true());
/*!
* \brief Get initial value for reduction.
* \param op The operator
* \param type The data type.
* \return The initial value that can be assigned to reduction.
*/
static Expr InitValue(const std::string& op, Type type);
/*!
* \brief Combine two values with given reduction.
* \param op The operator
* \param a The left operand.
* \param b The right operand.
* \return The combined reduction result.
*/
static Expr Combine(const std::string& op, Expr a, Expr b);

void VisitAttrs(AttrVisitor* v) final {
v->Visit("dtype", &type);
Expand Down Expand Up @@ -86,6 +101,10 @@ constexpr const char* thread_extent = "thread_extent";
* \brief Mark launching of a virtual thread.
*/
constexpr const char* virtual_thread = "virtual_thread";
/*!
* \brief Mark the scope as volatile access for certain handle.
*/
constexpr const char* volatile_scope = "volatile_scope";
/*!
* \brief Mark storage scope of buffers
*/
Expand Down Expand Up @@ -164,6 +183,17 @@ constexpr const char* tvm_call_packed = "tvm_call_packed";
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief See pseudo code
*
* Expr tvm_thread_allreduce(std::string op, Expr value, Expr cond,
* Var thread_idx1, thread_idx2...) {
* // constrained so that the other thread_idx values remain the same.
* return reduce(op, value, cond,
* over [thread_idx1, thread_idx2] passed by any caller)
* }
*/
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);

/*!
* \brief Lower cross thread allreduce in the stmt.
* \param stmt The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* \return Transformed stmt.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc stmt, int warp_size);
} // namespace ir
} // namespace tvm

Expand Down
8 changes: 8 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class Stage : public NodeRef {
* \return reference to self.
*/
Stage& compute_root(); // NOLINT(*)
/*!
* \brief Rebase the parent iter var as rebased variable.
*
* \param parent The parent iteration domain.
* \param rebased The variable to be used in rebase.
* \return reference to self.
*/
Stage& rebase(IterVar parent, IterVar rebased);
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def lower(sch,
return fapi



def build(sch,
args=None,
target="llvm",
Expand Down Expand Up @@ -128,6 +127,8 @@ def build(sch,
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
warp_size = 32 if target == "cuda" else 1
fsplits[i] = ir_pass.LowerThreadAllreduce(fsplits[i], warp_size)

if len(fsplits) > 1:
mhost = codegen.build(fsplits[0], target_host)
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,24 @@ def rfactor(self, tensor, axis):
@register_node
class Stage(NodeBase):
"""A Stage represents schedule for one operation."""
def rebase(self, parent, rebased):
"""Rebase parent by an existing thread axis.

Parameters
----------
parent : IterVar
The parent iter var.

rebased : IterVar
The rebased iter var.

Returns
-------
rebased : IterVar
The rebased itervar.
"""
_api_internal._StageRebase(self, parent, rebased)
return rebased

def split(self, parent, factor=None, outer=None):
"""Split the stage either by factor providing outer scope, or both

Expand Down
7 changes: 7 additions & 0 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ TVM_REGISTER_API(_StageSetScope)
.set_scope(args[1]);
});

// Packed-function binding for Stage::rebase.
//   args[0]: the Stage to operate on.
//   args[1]: passed as the `parent` argument of Stage::rebase.
//   args[2]: passed as the `rebased` argument of Stage::rebase.
// Nothing is written to `ret`; the Stage& returned by rebase() is discarded.
TVM_REGISTER_API(_StageRebase)
.set_body([](TVMArgs args, TVMRetValue* ret) {
    args[0].operator Stage()
        .rebase(args[1], args[2]);
  });

TVM_REGISTER_API(_StageSplitByFactor)
.set_body([](TVMArgs args, TVMRetValue* ret) {
IterVar outer, inner;
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS1(NarrowChannelAccess);
} // namespace ir
} // namespace tvm
82 changes: 54 additions & 28 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,26 @@ void CodeGenC::PrintSSAAssign(
}

// Print a reference expression to a buffer.
void CodeGenC::PrintBufferRef(
std::string CodeGenC::GetBufferRef(
const Variable* buffer,
Type t, Expr index,
std::ostream& os) { // NOLINT(*)
Type t, Expr index) {
std::ostringstream os;
std::string vid = GetVarID(buffer);
std::string scope;
if (alloc_storage_scope_.count(buffer)) {
scope = alloc_storage_scope_.at(buffer);
}
bool is_vol = volatile_buf_.count(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t)) {
if (!HandleTypeMatch(buffer, t) || is_vol) {
os << "((";
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)" << vid << ')';
} else {
Expand All @@ -107,17 +119,24 @@ void CodeGenC::PrintBufferRef(
} else {
// Buffer declared as vector type.
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t)) {
if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return;
return os.str();
}
}
os << "((";
if (is_vol) {
os << "volatile ";
}
if (scope.length() != 0) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
Expand All @@ -129,6 +148,7 @@ void CodeGenC::PrintBufferRef(
PrintExpr(index, os);
os << "))[0]";
}
return os.str();
}


Expand Down Expand Up @@ -162,18 +182,17 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
<< " = " << value << ";\n";
}

void CodeGenC::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
PrintBufferRef(buffer, t, base, os);
std::string CodeGenC::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
return GetBufferRef(buffer, t, base);
}

void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
std::string ref = GetBufferRef(buffer, t, base);
this->PrintIndent();
PrintBufferRef(buffer, t, base, stream);
stream << " = " << value << ";\n";
stream << ref << " = " << value << ";\n";
}

void CodeGenC::PrintThreadIndexExpr(
Expand Down Expand Up @@ -483,24 +502,21 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {

void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
std::string svalue = GetUniqueName("_");
// delcare type.
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue;
if (op->type.lanes() == 1) {
stream << " = ";
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, stream);
stream << ";\n";
std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index);
os << ref;
} else {
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
stream << " = ";
this->PrintVecLoad(op->buffer_var.get(), op->type, base, stream);
stream << ";\n";
std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base);
os << ref;
} else {
// Load elements seperately
stream << ";\n";
// load separately.
std::string svalue = GetUniqueName("_");
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue << ";\n";
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
std::string vid = GetVarID(op->buffer_var.get());
Type elem_type = op->type.element_of();
Expand All @@ -518,18 +534,18 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
value_temp << ']';
PrintVecElemStore(svalue, op->type, i, value_temp.str());
}
os << svalue;
}
}
os << svalue;
}

void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index);
this->PrintIndent();
this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
stream << " = " << value << ";\n";
stream << ref << " = " << value << ";\n";
} else {
Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
Expand Down Expand Up @@ -577,7 +593,13 @@ void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*
}

void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Select: not supported ";
os << "(";
PrintExpr(op->condition, os);
os << " ? ";
PrintExpr(op->true_value, os);
os << " : ";
PrintExpr(op->false_value, os);
os << ")";
}

void CodeGenC::VisitStmt_(const LetStmt* op) {
Expand Down Expand Up @@ -649,6 +671,10 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
} else if (op->type_key == ir::attr::volatile_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
volatile_buf_.insert(v);
}
this->PrintStmt(op->body);
}
Expand Down
15 changes: 8 additions & 7 deletions src/codegen/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "./codegen_source_base.h"

namespace tvm {
Expand Down Expand Up @@ -132,9 +133,8 @@ class CodeGenC :
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
virtual std::string GetVecLoad(const Variable* buffer,
Type t, Expr base);
// print vector store
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
Expand All @@ -149,9 +149,8 @@ class CodeGenC :

protected:
// print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
std::ostream& os); // NOLINT(*)
std::string GetBufferRef(const Variable* buffer,
Type t, Expr index);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
Expand All @@ -172,9 +171,11 @@ class CodeGenC :

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
bool print_ssa_form_{false};
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief set of volatile buf access */
std::unordered_set<const Variable*> volatile_buf_;
};

} // namespace codegen
Expand Down
10 changes: 6 additions & 4 deletions src/codegen/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer,
Type t, Expr base) {
std::ostringstream os;
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
os << ")";
return os.str();
}

void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
Expand All @@ -121,7 +122,8 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
}
}

void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
void CodeGenOpenCL::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global";
} else if (scope == "shared") {
Expand Down
Loading