Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xla/service/gpu/fusions/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1505,10 +1505,12 @@ absl::Status ReductionFusion::ReductionEmitter::EmitKernel(
reduction_codegen_info_.GetIndexGroups();
Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape();

llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
llvm::Value* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_);
llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
llvm::cast<llvm::Instruction>(raw_block_id_y));
raw_block_id_y = builder_->CreateZExtOrTrunc(
raw_block_id_y, builder_->getInt32Ty(), "raw_block_id_y");
for (int i = 0; i < instr_index_groups.size(); ++i) {
TF_RETURN_IF_ERROR(ksl.IfWithStatus(
absl::StrCat("reduce-group-", i),
Expand Down
32 changes: 32 additions & 0 deletions xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,29 @@ llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)});
}

// Helper function to emit call to SPIR shfl_down intrinsic.
llvm::Value* EmitSPIRShflDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* b) {
CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
if (value->getType()->isFloatTy()) {
return EmitDeviceFunctionCall(
"_Z34__spirv_GroupNonUniformShuffleDownffj",
{b->getInt32(3), value, offset}, {U32, F32, U32}, F32,
llvm::AttrBuilder(b->getContext())
.addAttribute(llvm::Attribute::NoUnwind)
.addAttribute(llvm::Attribute::Convergent),
b);
} else {
return EmitDeviceFunctionCall(
"_Z34__spirv_GroupNonUniformShuffleDownjjj",
{b->getInt32(3), value, offset}, {U32, U32, U32}, U32,
llvm::AttrBuilder(b->getContext())
.addAttribute(llvm::Attribute::NoUnwind)
.addAttribute(llvm::Attribute::Convergent),
b);
}
}

llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder) {
int bit_width = value->getType()->getPrimitiveSizeInBits();
Expand All @@ -299,6 +322,8 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
return EmitNVPTXShflDown(value, offset, builder);
} else if (target_triple.getArch() == llvm::Triple::amdgcn) {
return EmitAMDGPUShflDown(value, offset, builder);
} else if (target_triple.isSPIR()) {
return EmitSPIRShflDown(value, offset, builder);
} else {
LOG(FATAL) << "Invalid triple " << target_triple.str();
}
Expand All @@ -320,6 +345,9 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
} else if (target_triple.getArch() == llvm::Triple::amdgcn) {
insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
offset, builder);
} else if (target_triple.isSPIR()) {
insert_val = EmitSPIRShflDown(builder->CreateExtractElement(x, i), offset,
builder);
} else {
LOG(FATAL) << "Invalid triple " << target_triple.str();
}
Expand Down Expand Up @@ -1192,6 +1220,10 @@ bool IsAMDGPU(const llvm::Module* module) {
return llvm::Triple(module->getTargetTriple()).isAMDGPU();
}

bool IsSPIR(const llvm::Module* module) {
return llvm::Triple(module->getTargetTriple()).isSPIR();
}

absl::StatusOr<DenseDataIntermediate> LiteralToXlaFormat(
const Literal& literal) {
PrimitiveType element_type = literal.shape().element_type();
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/ir_emission_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ std::string GetIrNameFromLoc(mlir::Location loc);
// Whether the module's target is an AMD GPU.
bool IsAMDGPU(const llvm::Module* module);

// Whether the module's target is a SPIR.
bool IsSPIR(const llvm::Module* module);

// This class stores either a non-owning reference or owns data that represents
// a dense array in XLA format. It is used for intermediate storage during IR
// constant emission.
Expand Down
18 changes: 18 additions & 0 deletions xla/service/gpu/target_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ inline const char* DataLayout() {

} // namespace amdgpu

namespace spir {
// The triple that represents our target on SPIR backend.
inline const char* TargetTriple() {
static constexpr char kTargetTriple[] = "spir64-unknown-unknown";
return kTargetTriple;
}

// The data layout of the emitted module.
inline const char* DataLayout() {
static constexpr char kDataLayout[] =
"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:"
"32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:"
"128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:"
"1024";
return kDataLayout;
}
} // namespace spir

} // namespace gpu
} // namespace xla

Expand Down
Loading