diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 187bdc74fe29..0ff0531b5c20 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -372,6 +372,12 @@ void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base
   stream << ref << " = " << value << ";\n";
 }
 
+// Default vector constructor: print the vector type name itself, so that a
+// caller appending "(e0, e1, ...)" produces `type(e0, e1, ...)`.
+void CodeGenC::PrintVecConstructor(DataType t, std::ostream& os) {  // NOLINT(*)
+  PrintType(t, os);
+}
+
 std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
   if (from == target) return value;
   std::ostringstream os;
@@ -869,8 +875,58 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
   os << ")";
 }
 
-void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
-  LOG(FATAL) << "Shuffle: not supported ";
+void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {  // NOLINT(*)
+  // Shuffle support
+  // vec = concat(vectors)
+  // result = (vec[indices[0]], vec[indices[1]], ...)
+  //
+  // print shuffle as:
+  // target_dtype(e0, e1, e2, .. en)
+
+  // construct the concat
+  std::vector<std::string> concat_vec;
+  // NOTE: it is important to print each expr first,
+  // in case an expr has its own nested expressions;
+  // then print each element
+  for (const PrimExpr& vec : op->vectors) {
+    std::string vec_value = this->PrintExpr(vec);
+    if (vec.dtype().lanes() == 1) {
+      concat_vec.push_back(vec_value);
+    } else {
+      // print out each element
+      for (int i = 0; i < vec.dtype().lanes(); ++i) {
+        // access i-th element of each vector
+        // NOTE(review): assumes the target accepts `vec[i]` on vector values -- confirm for
+        // every backend that inherits this visitor.
+        std::ostringstream vec_elem_strm;
+        vec_elem_strm << vec_value << "[" << i << "]";
+        concat_vec.push_back(vec_elem_strm.str());
+      }
+    }
+  }
+  // Look up the printed element for a shuffle index, failing loudly on an
+  // out-of-range index instead of reading past the end of concat_vec.
+  auto elem_at = [&concat_vec](const PrimExpr& index) -> const std::string& {
+    int64_t idx = Downcast<IntImm>(index)->value;
+    ICHECK(idx >= 0 && idx < static_cast<int64_t>(concat_vec.size()))
+        << "Shuffle index " << idx << " is out of range for " << concat_vec.size()
+        << " concatenated elements";
+    return concat_vec[idx];
+  };
+  if (op->indices.size() == 1) {
+    // This is an extract element
+    os << elem_at(op->indices[0]);
+  } else {
+    // Print the shuffle as vector constructor
+    // vec(e0, e1, e2, .. en)
+    PrintVecConstructor(op->dtype, os);
+    os << '(';
+    for (size_t i = 0; i < op->indices.size(); ++i) {
+      if (i != 0) os << ", ";
+      os << elem_at(op->indices[i]);
+    }
+    os << ')';
+  }
 }
 
 void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 2921a56ef3a1..9a20566d5b3e 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -221,6 +221,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
   // print store of single element.
   virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
                                  const std::string& value);
+  // print vector constructor prefix (the text emitted before "(e0, e1, ...)")
+  virtual void PrintVecConstructor(DataType t, std::ostream& os);
   // Get a cast type from to
   virtual std::string CastFromTo(std::string value, DataType from, DataType target);
   // Get load of single element with expression
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7639ce606563..ef69b7a7d167 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -437,6 +437,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
 }
 
+void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {
+  os << "make_";
+  PrintType(t, os);
+}
+
 void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                    std::ostream& os) {  // NOLINT(*)
   // Delcare the result.
@@ -1156,15 +1161,14 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
 
 void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
-  os << "(make_";
-  PrintType(op->dtype, os);
+  PrintVecConstructor(op->dtype, os);
   os << "(";
   for (int i = 0; i < op->lanes; i++) {
     os << "(" << PrintExpr(op->base) << ")"
        << "+(" << PrintExpr(op->stride) << "*" << i << ")";
     if (i != op->lanes - 1) os << ", ";
   }
-  os << "))";
+  os << ")";
 }
 
 void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
@@ -1184,8 +1188,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
 
   if (op->dtype.is_float16()) {
     std::string v = PrintExpr(op->value);
-    os << "make_";
-    PrintType(op->dtype, os);
+    PrintVecConstructor(op->dtype, os);
     os << '(';
     for (int i = 0; i < op->lanes / 2; ++i) {
       if (i != 0) os << ", ";
@@ -1197,8 +1200,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
 
   if (op->dtype.is_bfloat16()) {
     std::string v = PrintExpr(op->value);
-    os << "make_";
-    PrintType(op->dtype, os);
+    PrintVecConstructor(op->dtype, os);
     os << '(';
     for (int i = 0; i < op->lanes / 2; ++i) {
       if (i != 0) os << ", ";
@@ -1230,8 +1232,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
           os << "(int)" << v;
         }
       } else if (op->lanes == 16 || op->lanes == 32) {
-        os << "make_";
-        PrintType(op->dtype, os);
+        PrintVecConstructor(op->dtype, os);
         os << '(';
         for (int i = 0; i < op->lanes / 8; ++i) {
           if (i != 0) os << ", ";
@@ -1253,8 +1254,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
   }
 
   std::string v = PrintExpr(op->value);
-  os << "make_";
-  PrintType(op->dtype, os);
+  PrintVecConstructor(op->dtype, os);
   os << '(';
   for (int i = 0; i < op->lanes; ++i) {
     if (i != 0) os << ", ";
@@ -1263,24 +1263,6 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
   os << ')';
 }
 
-void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
-  std::vector<std::string> to_shuffle(op->vectors.size());
-  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
-    ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
-    to_shuffle[i] = PrintExpr(op->vectors[i]);
-  }
-  os << "make_";
-  PrintType(op->dtype, os);
-  os << '(';
-  for (int i = 0, e = op->indices.size(); i < e; ++i) {
-    const int64_t* val = as_const_int(op->indices[i]);
-    ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
-    if (i != 0) os << ", ";
-    os << to_shuffle[*val];
-  }
-  os << ')';
-}
-
 void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
   // Non-vector cases.
   if (!op->dtype.is_vector()) {
@@ -1459,8 +1441,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
 
   if (t.is_float16()) {
     if (i == 0) {
-      os << "make_";
-      PrintType(t, os);
+      PrintVecConstructor(t, os);
       os << '(';
     }
     if (i % 2 == 0) {
@@ -1478,8 +1459,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
 
   if (t.is_bfloat16()) {
     if (i == 0) {
-      os << "make_";
-      PrintType(t, os);
+      PrintVecConstructor(t, os);
       os << '(';
     }
     if (i % 2 == 0) {
@@ -1496,8 +1476,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
   }
 
   if (i == 0) {
-    os << "make_";
-    PrintType(t, os);
+    PrintVecConstructor(t, os);
     os << "(";
   }
   os << value;
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index bc7b34b500d8..7fe818b6b4fb 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -54,6 +54,7 @@ class CodeGenCUDA final : public CodeGenC {
   void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                         std::ostream& os) final;       // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
+  void PrintVecConstructor(DataType t, std::ostream& os) final;
   void PrintVecElemLoad(const std::string& vec, DataType t, int i,
                         std::ostream& os) final;  // NOLINT(*)
   void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
@@ -62,7 +63,6 @@ class CodeGenCUDA final : public CodeGenC {
   std::string CastFromTo(std::string value, DataType from, DataType target) final;
   // overload visitor
   void VisitExpr_(const RampNode* op, std::ostream& os) final;       // NOLINT(*)
-  void VisitExpr_(const ShuffleNode* op, std::ostream& os) final;    // NOLINT(*)
   void VisitExpr_(const SelectNode* op, std::ostream& os) final;     // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index ebb7566489d6..ddd7d25f3b5f 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -22,9 +22,12 @@
  */
 #include "codegen_metal.h"
 
+#include <tvm/tir/transform.h>
+
 #include <algorithm>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "../../runtime/metal/metal_module.h"
@@ -327,6 +330,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NO
 runtime::Module BuildMetal(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
+  mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
 
   std::ostringstream source_maker;
   std::unordered_map<std::string, std::string> smap;