Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
INCLUDE mlasi.inc
.list

EXTERN MlasMaskMoveTableAvx:NEAR
EXTERN MlasMinimumF32Value:NEAR

;++
Expand All @@ -46,6 +45,10 @@ INCLUDE mlasi.inc
LEAF_ENTRY MlasReduceMaximumF32KernelAvx, _TEXT

vbroadcastss ymm0,DWORD PTR [MlasMinimumF32Value]
test rdx,rdx
jz ExitKernel
cmp rdx,8
jb ProcessRemainingCountBy1
cmp rdx,32
jb ProcessRemainingCountBy8
vmovaps ymm1,ymm0
Expand Down Expand Up @@ -74,23 +77,22 @@ ProcessRemainingCountBy8:
jmp ProcessRemainingCountBy8

ProcessRemainingCountLessThan8:
test rdx,rdx
jz ReduceScalar
lea r10,MlasMaskMoveTableAvx+8*4
neg rdx
vmovups ymm3,YMMWORD PTR [r10+rdx*4]
vmaskmovps ymm1,ymm3,YMMWORD PTR [rcx]
vmaxps ymm1,ymm0,ymm1
vblendvps ymm0,ymm0,ymm1,ymm3 ; ignore masked elements

ReduceScalar:
vextractf128 xmm1,ymm0,1
vextractf128 xmm1,ymm0,1 ; reduce to single scalar
vmaxps xmm0,xmm0,xmm1
vshufps xmm1,xmm0,xmm0,0EEh
vmaxps xmm0,xmm0,xmm1
vshufps xmm1,xmm0,xmm0,055h
vmaxss xmm0,xmm0,xmm1
test rdx,rdx
jz ExitKernel

ProcessRemainingCountBy1:
vmaxss xmm0,xmm0,DWORD PTR [rcx]
add rcx,4 ; advance input by 1 element
dec edx
jnz ProcessRemainingCountBy1

ExitKernel:
vzeroupper
ret

Expand Down Expand Up @@ -149,12 +151,13 @@ ProcessRemainingCountBy8:
ProcessRemainingCountLessThan8:
test rdx,rdx
jz ExitKernel
lea r10,MlasMaskMoveTableAvx+8*4
neg rdx
vmovups ymm3,YMMWORD PTR [r10+rdx*4]
vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx]
vmulps ymm0,ymm4,ymm0
vmaskmovps YMMWORD PTR [rcx],ymm3,ymm0

ProcessRemainingCountBy1:
vmulss xmm0,xmm4,DWORD PTR [rcx]
vmovss DWORD PTR [rcx],xmm0
add rcx,4 ; advance output by 1 element
dec edx
jnz ProcessRemainingCountBy1

ExitKernel:
vzeroupper
Expand Down Expand Up @@ -226,13 +229,15 @@ ProcessRemainingCountBy8:
ProcessRemainingCountLessThan8:
test r8,r8
jz ExitKernel
lea r10,MlasMaskMoveTableAvx+8*4
neg r8
vmovups ymm3,YMMWORD PTR [r10+r8*4]
vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx]
vaddps ymm0,ymm4,ymm0
vsubps ymm0,ymm0,ymm5 ; do as two steps for numeric stability
vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0

ProcessRemainingCountBy1:
vaddss xmm0,xmm4,DWORD PTR [rcx]
add rcx,4 ; advance input by 1 element
vsubss xmm0,xmm0,xmm5
vmovss DWORD PTR [rdx],xmm0
add rdx,4 ; advance output by 1 element
dec r8d
jnz ProcessRemainingCountBy1

ExitKernel:
vzeroupper
Expand Down
52 changes: 29 additions & 23 deletions onnxruntime/core/mlas/lib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ Routine Description:

This routine computes the exponential function for the supplied vector.

This function handles a narrower range of inputs compaerd to MlasExp
This function handles a narrower range of inputs compared to
MlasComputeExpVector in order to improve efficiency.

Arguments:

Expand Down Expand Up @@ -510,40 +511,45 @@ Return Value:

--*/
{
MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(MlasMinimumF32Value);
float Maximum = MlasMinimumF32Value;

if (N >= 16) {
if (N >= 4) {

MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0;
MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0;
MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0;
MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(Maximum);

while (N >= 16) {
if (N >= 16) {

MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input));
MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4));
MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8));
MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, MlasLoadFloat32x4(Input + 12));
MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0;
MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0;
MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0;

while (N >= 16) {

MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input));
MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4));
MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8));
MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, MlasLoadFloat32x4(Input + 12));

Input += 16;
N -= 16;
Input += 16;
N -= 16;
}

MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1);
MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3);
MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2);
}

MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1);
MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3);
MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2);
}
while (N >= 4) {

while (N >= 4) {
MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input));

MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input));
Input += 4;
N -= 4;
}

Input += 4;
N -= 4;
Maximum = MlasReduceMaximumFloat32x4(MaximumVector0);
}

float Maximum = MlasReduceMaximumFloat32x4(MaximumVector0);

while (N > 0) {

Maximum = (std::max)(Maximum, *Input);
Expand Down
86 changes: 15 additions & 71 deletions onnxruntime/core/mlas/lib/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ Module Name:
// threads.
//

struct MLAS_WORK_BLOCK {
struct MLAS_POOL_WORK_BLOCK
{
MLAS_POOLING_KIND PoolingKind;
size_t InputShape[3];
size_t InputSize;
Expand All @@ -38,7 +39,7 @@ struct MLAS_WORK_BLOCK {
typedef
void
(MLAS_POOL_KERNEL_ROUTINE)(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -87,22 +88,11 @@ struct MLAS_MAXIMUM_POOLING
return MlasMaximumFloat32x4(Reduction, Value);
}

#if defined(MLAS_NEON64_INTRINSICS)

static float ReduceFloat32x4(MLAS_FLOAT32X4 Reduction)
static float Reduce(MLAS_FLOAT32X4 Reduction)
{
return vmaxvq_f32(Reduction);
return MlasReduceMaximumFloat32x4(Reduction);
}

#elif defined(MLAS_NEON32_INTRINSICS)

static float32x2_t ReducePairwise(float32x2_t Vector0, float32x2_t Vector1)
{
return vpmax_f32(Vector0, Vector1);
}

#endif

static float AveragePool(float Reduction, float Size)
{
MLAS_UNREFERENCED_PARAMETER(Size);
Expand Down Expand Up @@ -169,25 +159,11 @@ struct MLAS_AVERAGE_POOLING
return MlasAddFloat32x4(Reduction, Value);
}

#if defined(MLAS_NEON64_INTRINSICS)

static float ReduceFloat32x4(MLAS_FLOAT32X4 Reduction)
static float Reduce(MLAS_FLOAT32X4 Reduction)
{
Reduction = vpaddq_f32(Reduction, Reduction);
Reduction = vpaddq_f32(Reduction, Reduction);

return vgetq_lane_f32(Reduction, 0);
}

#elif defined(MLAS_NEON32_INTRINSICS)

static float32x2_t ReducePairwise(float32x2_t Vector0, float32x2_t Vector1)
{
return vpadd_f32(Vector0, Vector1);
return MlasReduceAddFloat32x4(Reduction);
}

#endif

static float AveragePool(float Reduction, float Size)
{
return Reduction / Size;
Expand Down Expand Up @@ -272,7 +248,7 @@ struct MLAS_AVERAGE_POOLING
template<typename PoolingType>
void
MlasPool1DKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -342,7 +318,7 @@ Return Value:
template<typename PoolingType>
void
MlasPool2DKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -430,7 +406,7 @@ Return Value:
template<typename PoolingType>
void
MlasPool2DVectorKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -654,7 +630,7 @@ Return Value:
template<typename PoolingType>
void
MlasPool3DKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -759,7 +735,7 @@ Return Value:
template<typename PoolingType>
void
MlasPool3DVectorKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -1027,7 +1003,7 @@ Return Value:
template<typename PoolingType>
void
MlasPoolGlobalKernel(
const MLAS_WORK_BLOCK* WorkBlock,
const MLAS_POOL_WORK_BLOCK* WorkBlock,
size_t ChannelCount,
const float* Input,
float* Output
Expand Down Expand Up @@ -1081,37 +1057,7 @@ Return Value:
// Reduce the vector to a single float value.
//

#if defined(MLAS_NEON64_INTRINSICS)

float ReductionValue = PoolingType::ReduceFloat32x4(Reduction);

#elif defined(MLAS_NEON32_INTRINSICS)

float32x2_t ReductionLow = vget_low_f32(Reduction);
float32x2_t ReductionHigh = vget_high_f32(Reduction);

ReductionLow = PoolingType::ReducePairwise(ReductionLow, ReductionHigh);
ReductionLow = PoolingType::ReducePairwise(ReductionLow, ReductionHigh);

float ReductionValue = vget_lane_f32(ReductionLow, 0);

#elif defined(MLAS_SSE2_INTRINSICS)

Reduction = PoolingType::Reduce(Reduction, _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(3, 2, 3, 2)));
Reduction = PoolingType::Reduce(Reduction, _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(1, 1, 1, 1)));

float ReductionValue = _mm_cvtss_f32(Reduction);

#elif defined(MLAS_VSX_INTRINSICS)

Reduction = PoolingType::Reduce(Reduction, MLAS_FLOAT32X4(vec_splat((__vector int64_t)Reduction, 1)));
Reduction = PoolingType::Reduce(Reduction, vec_splat(Reduction, 1));

float ReductionValue = Reduction[0];

#else
#error Unsupported architecture.
#endif
float ReductionValue = PoolingType::Reduce(Reduction);

//
// Iterate over the remaining input buffer an element at a time.
Expand Down Expand Up @@ -1228,7 +1174,7 @@ Return Value:

--*/
{
MLAS_WORK_BLOCK WorkBlock;
MLAS_POOL_WORK_BLOCK WorkBlock;

WorkBlock.PoolingKind = PoolingKind;

Expand All @@ -1237,10 +1183,8 @@ Return Value:
// and output shapes over the batch and channel counts.
//

//TODO: use a safeint here and make sure the result value can fit into int32_t
size_t TotalChannelCount = size_t(InputShape[0]) * size_t(InputShape[1]);


InputShape += 2;
OutputShape += 2;

Expand Down
Loading