177 changes: 110 additions & 67 deletions src/coreclr/jit/gentree.cpp
@@ -28185,74 +28185,40 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, var_types simd

#if defined(TARGET_XARCH)

if (varTypeIsFloating(simdBaseType))
{
// For floating-point we first run the horizontal permute+add sequence
// at the full simd width. vpermilps/vpermilpd permute WITHIN each
// 128-bit lane, so this is effectively 2x (V256) or 4x (V512) V128
// reductions running in parallel with no duplicated work.
//
// After that, each 128-bit lane of op1 holds the sum of its elements
// broadcast across the lane. We then reduce the lanes by combining
// upper/lower halves step-by-step down to a single V128. Floating-
// point addition is not associative, so the halve-combine grouping
// below deliberately preserves the prior recursive
// `Sum(lower) + Sum(upper)` shape.
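// As an illustrative sketch (not how the tree is literally shaped, just
// the effective grouping), a V256<float> with elements e0..e7 sums as:
//   lane0 = (e0 + e1) + (e2 + e3)
//   lane1 = (e4 + e5) + (e6 + e7)
//   sum   = lane0 + lane1
// which matches the previous recursive Sum(lower) + Sum(upper) expansion.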

if (simdBaseType == TYP_FLOAT)
{
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
NamedIntrinsic permIntrinsic = (simdSize == 64) ? NI_AVX512_Permute4x32 : NI_AVX_Permute;

if ((simdSize > 16) || compOpportunisticallyDependsOn(InstructionSet_AVX))
{
// Per lane, the permute below gives us [0, 1, 2, 3] -> [1, 0, 3, 2]
op1 = gtNewSimdHWIntrinsicNode(simdType, op1, gtNewIconNode((int)0b10110001, TYP_INT), permIntrinsic,
simdBaseType, simdSize);
// Per lane, the add below now results in [0 + 1, 1 + 0, 2 + 3, 3 + 2]
op1 = gtNewSimdBinOpNode(GT_ADD, simdType, op1, op1Shuffled, simdBaseType, simdSize);
op1Shuffled = fgMakeMultiUse(&op1);
// Per lane, the permute below gives us [0 + 1, 1 + 0, 2 + 3, 3 + 2] -> [2 + 3, 3 + 2, 0 + 1, 1 + 0]
op1 = gtNewSimdHWIntrinsicNode(simdType, op1, gtNewIconNode((int)0b01001110, TYP_INT), permIntrinsic,
simdBaseType, simdSize);
}
else
{
assert(simdSize == 16);
// The shuffle below gives us [0, 1, 2, 3] -> [1, 0, 3, 2]
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op1Shuffled, gtNewIconNode((int)0b10110001, TYP_INT),
NI_X86Base_Shuffle, simdBaseType, simdSize);
@@ -28265,34 +28231,111 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, var_types simd
NI_X86Base_Shuffle, simdBaseType, simdSize);
op1Shuffled = fgMakeMultiUse(&op1Shuffled);
}
// Per lane, adding the results gets us [(0 + 1) + (2 + 3), (1 + 0) + (3 + 2), (2 + 3) + (0 + 1),
// (3 + 2) + (1 + 0)]
op1 = gtNewSimdBinOpNode(GT_ADD, simdType, op1, op1Shuffled, simdBaseType, simdSize);
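// Illustrative example (values are hypothetical): a lane holding [1, 2, 3, 4]
// becomes [3, 3, 7, 7] after the first permute+add and [10, 10, 10, 10] after
// the second, i.e. the lane's sum broadcast to every element.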
}
else
{
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
NamedIntrinsic permIntrinsic = (simdSize == 64) ? NI_AVX512_Permute2x64 : NI_AVX_Permute;

if ((simdSize > 16) || compOpportunisticallyDependsOn(InstructionSet_AVX))
{
// Per lane, the permute below gives us [0, 1] -> [1, 0]
// vpermilpd uses one imm bit per double element (2 for V128, 4 for V256,
// 8 for V512); 0b01010101 swaps within each 128-bit lane at all widths.
op1 = gtNewSimdHWIntrinsicNode(simdType, op1, gtNewIconNode((int)0b01010101, TYP_INT), permIntrinsic,
simdBaseType, simdSize);
}
else
{
assert(simdSize == 16);
// The shuffle below gives us [0, 1] -> [1, 0]
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op1Shuffled, gtNewIconNode((int)0b0001, TYP_INT),
NI_X86Base_Shuffle, simdBaseType, simdSize);
op1Shuffled = fgMakeMultiUse(&op1Shuffled);
}
// Per lane, adding the results gets us [0 + 1, 1 + 0]
op1 = gtNewSimdBinOpNode(GT_ADD, simdType, op1, op1Shuffled, simdBaseType, simdSize);
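// Illustrative example (values are hypothetical): a lane holding [5, 7]
// becomes [12, 12] after the permute+add, i.e. the lane's sum in both elements.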
}

// At this point every 128-bit lane of op1 contains that lane's reduced
// sum broadcast across the lane. Combine the lanes into a single V128
// by reducing upper/lower halves step-by-step. Floating-point addition
// is not associative, so the grouping used here deliberately matches
// the prior recursive shape:
// V512: Sum = Sum(v512.GetLower()) + Sum(v512.GetUpper())
// V256: Sum = (v256.GetLower() + v256.GetUpper()).ToScalar()
// V128: Sum = v128.ToScalar()
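// Illustrative example of why the grouping matters for floats:
// 16777216.0f + (1.0f + 1.0f) == 16777218.0f, but
// (16777216.0f + 1.0f) + 1.0f == 16777216.0f, so regrouping the adds can
// change the observable result.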

if (simdSize == 64)
{
// Extract each of the four 128-bit lanes directly from the V512
// using GetLower128 (lane 0) and AVX512 ExtractVector128 (lanes
// 1-3), then combine as `(s0 + s1) + (s2 + s3)` to preserve the
// prior recursive `Sum(lower256) + Sum(upper256)` grouping.
GenTree* op1Lane1 = fgMakeMultiUse(&op1);
GenTree* op1Lane2 = fgMakeMultiUse(&op1);
GenTree* op1Lane3 = fgMakeMultiUse(&op1);

GenTree* op1Lane0 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, NI_Vector512_GetLower128, simdBaseType, 64);
op1Lane1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1Lane1, gtNewIconNode(1), NI_AVX512_ExtractVector128,
simdBaseType, 64);
op1Lane2 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1Lane2, gtNewIconNode(2), NI_AVX512_ExtractVector128,
simdBaseType, 64);
op1Lane3 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1Lane3, gtNewIconNode(3), NI_AVX512_ExtractVector128,
simdBaseType, 64);

GenTree* lowerSum = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1Lane0, op1Lane1, simdBaseType, 16);
GenTree* upperSum = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1Lane2, op1Lane3, simdBaseType, 16);

simdSize = 16;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, lowerSum, upperSum, simdBaseType, 16);
}
else if (simdSize == 32)
{
GenTree* op1Dup = fgMakeMultiUse(&op1);

op1 = gtNewSimdGetLowerNode(TYP_SIMD16, op1, simdBaseType, 32);
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD16, op1Dup, simdBaseType, 32);

simdSize = 16;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseType, 16);
}

assert(simdSize == 16);
return gtNewSimdToScalarNode(type, op1, simdBaseType, 16);
}

// Integer: integer addition is associative, so we can safely reduce the
// upper/lower halves element-wise down to a single V128 before running
// the V128 reduction.
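// Illustrative example (values are hypothetical): a V256<int> holding
// [1, 2, 3, 4, 5, 6, 7, 8] reduces to [1+5, 2+6, 3+7, 4+8] = [6, 8, 10, 12]
// and then to the scalar 36, the same total any grouping would produce.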

if (simdSize == 64)
{
GenTree* op1Dup = fgMakeMultiUse(&op1);

op1 = gtNewSimdGetLowerNode(TYP_SIMD32, op1, simdBaseType, simdSize);
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD32, op1Dup, simdBaseType, simdSize);

simdSize = 32;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseType, 32);
}

if (simdSize == 32)
{
GenTree* op1Dup = fgMakeMultiUse(&op1);

op1 = gtNewSimdGetLowerNode(TYP_SIMD16, op1, simdBaseType, simdSize);
op1Dup = gtNewSimdGetUpperNode(TYP_SIMD16, op1Dup, simdBaseType, simdSize);

simdSize = 16;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseType, 16);
}

assert(simdSize == 16);

unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int shiftCount = genLog2(vectorLength);
int typeSize = genTypeSize(simdBaseType);