Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}

PrintType(GetType(v), stream);
Expand Down Expand Up @@ -179,7 +178,6 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)" << vid << ')';
} else {
Expand Down Expand Up @@ -213,15 +211,13 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
Expand Down Expand Up @@ -681,7 +677,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
value_temp << ' ';
}
}
PrintType(elem_type, value_temp);
Expand Down Expand Up @@ -731,7 +726,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}
}
PrintType(elem_type, stream);
Expand Down Expand Up @@ -823,10 +817,8 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
stream << ' '<< vid << '['
<< constant_size << "];\n";
stream << ' ' << vid << '[' << constant_size << "];\n";

RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
Expand Down
23 changes: 0 additions & 23 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,29 +257,6 @@ class CodeGenC :
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;

/*!
* \brief A RAII utility class for emitting code in a scoped region.
*/
class EnterScopeRAII {
// The codegen context.
CodeGenC* cg;

// The new scope level.
int scope;

public:
explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
cg->PrintIndent();
cg->stream << "{\n";
scope = cg->BeginScope();
}
~EnterScopeRAII() {
cg->EndScope(scope);
cg->PrintIndent();
cg->stream << "}\n";
}
};

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
Expand Down
10 changes: 1 addition & 9 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,6 @@ void CodeGenCUDA::PrintVecBinaryOp(
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);

// Unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
Expand Down Expand Up @@ -350,7 +348,7 @@ void CodeGenCUDA::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__shared__";
os << "__shared__ ";
}
}

Expand All @@ -370,7 +368,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);
std::string src = SSAGetID(PrintExpr(op->value), from_ty);
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
Expand Down Expand Up @@ -470,8 +467,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
this->PrintType(op->dtype, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);

// Load arguments.
std::vector<std::string> sargs;
for (size_t i = 0; i < op->args.size(); ++i) {
Expand Down Expand Up @@ -541,7 +536,6 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
Expand Down Expand Up @@ -657,8 +651,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
this->PrintType(op->dtype, stream);
stream << ' ' << r_var << ";\n";
{
EnterScopeRAII scope(this);

std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
Expand Down
7 changes: 3 additions & 4 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
Expand Down Expand Up @@ -236,11 +235,11 @@ void CodeGenMetal::PrintVecElemStore(const std::string& vec,
void CodeGenMetal::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "device";
os << "device ";
} else if (scope == "shared") {
os << "threadgroup";
os << "threadgroup ";
} else {
os << "thread";
os << "thread ";
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
Expand Down Expand Up @@ -191,9 +190,9 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
void CodeGenOpenCL::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global";
os << "__global ";
} else if (scope == "shared") {
os << "__local";
os << "__local ";
}
}

Expand Down