diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py index 34a9fb9a05e5..9ced0ac7d399 100644 --- a/python/tvm/topi/hexagon/injective.py +++ b/python/tvm/topi/hexagon/injective.py @@ -37,6 +37,12 @@ def schedule_injective(outs): outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs s = tvm.te.create_schedule([x.op for x in outs]) tvm.te.schedule.AutoInlineInjective(s) + + # Fuse axes and vectorize inner 128 elements + for x in outs: + fused = s[x].fuse(*x.op.axis) + _, inner = s[x].split(fused, factor=128) + s[x].vectorize(inner) return s diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index a195c9f05453..7b0081869a27 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -74,8 +74,19 @@ class CodeGenHexagon final : public CodeGenCPU { bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; void InitTarget(llvm::TargetMachine* tm) final; + using CodeGenCPU::VisitStmt_; + llvm::Value* VisitExpr_(const BufferLoadNode* op) override; + llvm::Module* GetModulePtr() const { return module_.get(); } + uint64_t GetTypeSizeInBits(llvm::Type* type) const { +#if TVM_LLVM_VERSION >= 100 + return data_layout_->getTypeSizeInBits(type).getFixedSize(); +#else + return data_layout_->getTypeSizeInBits(type); +#endif + } + protected: void CreatePrintf(const std::string& format, llvm::ArrayRef format_args) final; @@ -86,6 +97,9 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name); llvm::Value* GetContextPtr(llvm::GlobalVariable* gv); + + llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, Array index); + llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); }; void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm, @@ -281,6 +295,139 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::V return 
TypedPointer();
 }
 
+/*!
+ * \brief Emit a call to an LLVM intrinsic, bitcasting vector arguments when needed.
+ *
+ * A vector argument whose type differs from the declared parameter type is
+ * bitcast to the parameter type when both have the same bit width and that
+ * width is one or two native vectors. Every other argument is passed through
+ * unchanged.
+ *
+ * \param IntID The id of the intrinsic to call.
+ * \param args The call arguments; the count must match the intrinsic's signature.
+ * \return The value produced by the call instruction.
+ */
+llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID,
+                                       llvm::ArrayRef<llvm::Value*> args) {
+  llvm::Function* intf = llvm::Intrinsic::getDeclaration(module_.get(), IntID);
+#if TVM_LLVM_VERSION >= 90
+  auto intf_callee = llvm::FunctionCallee(intf);
+#else
+  auto intf_callee = intf;
+#endif
+  llvm::FunctionType* intf_type = intf->getFunctionType();
+  ICHECK(args.size() == intf_type->getNumParams());
+
+  std::vector<llvm::Value*> conv_args;
+  conv_args.reserve(args.size());
+  for (unsigned i = 0, e = args.size(); i != e; ++i) {
+    llvm::Value* arg = args[i];
+    auto* need_type = llvm::dyn_cast<llvm::VectorType>(intf_type->getParamType(i));
+    auto* have_type = llvm::dyn_cast<llvm::VectorType>(arg->getType());
+    if (need_type != nullptr && have_type != nullptr && need_type != have_type) {
+      int need_width = GetTypeSizeInBits(need_type);
+      int have_width = GetTypeSizeInBits(have_type);
+      if (need_width == have_width) {
+        // Same-size vectors only differ in element layout, so a bitcast is enough.
+        if (need_width == native_vector_bits_ || need_width == 2 * native_vector_bits_) {
+          arg = builder_->CreateBitCast(arg, need_type);
+        }
+      }
+      // TODO(joshherr-quic): add handling of v128i1 <-> v1024i1
+    }
+    conv_args.push_back(arg);
+  }
+  return builder_->CreateCall(intf_callee, conv_args);
+}
+
+/*!
+ * \brief Generate code for a buffer load, trying a vector table lookup first.
+ *
+ * Loads whose first index is not a Ramp are offered to VectorLookupLoad; when
+ * that declines (returns nullptr) the generic CodeGenLLVM lowering is used.
+ */
+llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) {
+  if (!op->buffer.same_as(op->buffer->data)) {
+    // Check if we can generate a vector lookup.
+    if (!op->indices[0].as<RampNode>()) {
+      if (auto* vlut = VectorLookupLoad(op->buffer, op->dtype, op->indices)) {
+        return vlut;
+      }
+    }
+  }
+  return CodeGenLLVM::VisitExpr_(op);
+}
+
+/*!
+ * \brief Try to lower a vector-indexed load from a small byte table to VLUT32.
+ *
+ * \param buffer The buffer acting as the lookup table.
+ * \param buffer_type The data type of the load being lowered.
+ * \param indices The load indices; only a single vector index is handled.
+ * \return The looked-up vector, or nullptr when the load does not match the
+ *         supported pattern and the caller must use the generic lowering.
+ */
+llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type,
+                                              Array<PrimExpr> indices) {
+  // The sequence below uses the *_128B intrinsics only.
+  constexpr int kHvxVectorBits = 128 * 8;
+  // Number of table entries covered by one VLUT32.
+  constexpr int kSubTableElems = 32;
+
+  // Every check that can reject the pattern runs before any IR is emitted, so
+  // that declining leaves no dead instructions behind.
+  if (native_vector_bits_ != kHvxVectorBits) return nullptr;
+  if (indices.size() != 1 || buffer->shape.size() != 1) return nullptr;
+
+  PrimExpr index = indices[0];
+  if (!index.dtype().is_vector()) return nullptr;
+  if (buffer_type.bits() != 8) return nullptr;
+
+  const int native_vector_bytes = native_vector_bits_ / 8;
+  // A wider index vector would need more than one native result vector.
+  if (index.dtype().lanes() > native_vector_bytes) return nullptr;
+
+  // The table size must be a compile-time constant in (0, 256].
+  const PrimExpr table_extent = arith::Analyzer().Simplify(buffer->shape[0]);
+  const auto* table_extent_imm = table_extent.as<IntImmNode>();
+  if (table_extent_imm == nullptr) return nullptr;
+  if (table_extent_imm->value <= 0 || table_extent_imm->value > 256) return nullptr;
+  const int table_elem_count = static_cast<int>(table_extent_imm->value);
+
+  DataType table_type = buffer_type.with_lanes(table_elem_count);
+  // The number of value vectors should be a power of 2.
+  const int table_vec_count =
+      static_cast<int>(llvm::PowerOf2Ceil(GetVectorBytes(table_type) / native_vector_bytes));
+  const int table_vec_length = native_vector_bytes / buffer_type.bytes();
+  const int table_iters = table_elem_count / kSubTableElems;
+  const int sub_tables_per_vec = native_vector_bytes / kSubTableElems;
+
+  // Decline table sizes this sequence cannot cover: a partial trailing
+  // sub-table would be dropped, and a table smaller than one native vector
+  // (or one needing more vectors than were counted) would index past the end
+  // of the containers below.
+  // TODO(review): rounding the two counts up may allow more sizes; this needs
+  // checking against CreateVecSlice's bounds handling first.
+  if (table_elem_count % kSubTableElems != 0) return nullptr;
+  if ((table_iters + sub_tables_per_vec - 1) / sub_tables_per_vec > table_vec_count) {
+    return nullptr;
+  }
+
+  auto int32 = DataType::Int(32);
+
+  // Indexes
+  llvm::Value* trunc = MakeValue(Cast(index.dtype().with_bits(8), index));
+  llvm::Value* index_pad = CreateVecPad(trunc, native_vector_bytes);
+
+  // Values
+  auto table_all =
+      MakeValue(BufferLoad(buffer, {
+                                       Ramp(IntImm(int32, 0), IntImm(int32, 1), table_elem_count),
+                                   }));
+
+  std::vector<llvm::Value*> vloads;
+  vloads.reserve(table_vec_count);
+  for (int i = 0; i != table_vec_count; ++i) {
+    // CreateVecSlice will generate undefs for elements outside the source vector.
+    vloads.push_back(CreateVecSlice(table_all, i * table_vec_length, table_vec_length));
+  }
+
+  auto vshuff = [this](llvm::Value* v) {
+    return Intrinsic(llvm::Intrinsic::hexagon_V6_vshuffb_128B, {v});
+  };
+  auto vxor = [this](llvm::Value* a, llvm::Value* b) {
+    return Intrinsic(llvm::Intrinsic::hexagon_V6_vxor_128B, {a, b});
+  };
+  auto vlut32 = [this](llvm::Value* idx, llvm::Value* tbl, llvm::Value* sel) {
+    return Intrinsic(llvm::Intrinsic::hexagon_V6_vlutvvbi_128B, {idx, tbl, sel});
+  };
+
+  // Shuffle table bytes:
+  //  127, 63, 126, 62,........68, 4, 67, 3, 66, 2, 65, 1, 64, 0
+  std::vector<llvm::Value*> table;
+  table.reserve(vloads.size());
+  for (llvm::Value* vload : vloads) table.push_back(vshuff(vload));
+
+  // Get each 32 byte sub-table's output
+  std::vector<llvm::Value*> results;
+  results.reserve(table_iters);
+  for (int i = 0; i < table_iters; ++i) {
+    results.push_back(vlut32(index_pad, table[i / sub_tables_per_vec], ConstInt32(i % 8)));
+  }
+
+  // Combine outputs. NOTE(review): presumably each lookup yields zero in lanes
+  // whose index belongs to another sub-table, which is what makes XOR a valid
+  // merge -- confirm against the HVX vlut32 specification.
+  llvm::Value* result = results[0];
+  for (int i = 1; i < table_iters; ++i) result = vxor(result, results[i]);
+
+  llvm::Type* res_type = result->getType();
+  llvm::Type* ret_type = DTypeToLLVMType(buffer_type);
+  if (res_type == ret_type) {
+    return result;
+  }
+
+  int res_bits = GetTypeSizeInBits(res_type);
+  int ret_bits = GetTypeSizeInBits(ret_type);
+  ICHECK_GE(res_bits, ret_bits);
+  if (ret_bits < res_bits) {
+    // Drop the padding lanes: view the result as bytes and keep the leading part.
+#if TVM_LLVM_VERSION >= 110
+    llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8, /*Scalable*/ false);
+#else
+    llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8);
+#endif
+    result = CreateVecSlice(builder_->CreateBitCast(result, res_byte_type), 0, ret_bits / 8);
+  }
+  if (result->getType() != ret_type) {
+    return builder_->CreateBitCast(result, ret_type);
+  }
+  return result;
+}
+
 namespace {
 DMLC_ATTRIBUTE_UNUSED std::ostream& operator<<(std::ostream& os, const llvm::Module& m) {
   std::string ms;