diff --git a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm index dd731539c8d75..bfdff7009191e 100644 --- a/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm +++ b/onnxruntime/core/mlas/lib/amd64/SoftmaxKernelAvx.asm @@ -21,7 +21,6 @@ INCLUDE mlasi.inc .list - EXTERN MlasMaskMoveTableAvx:NEAR EXTERN MlasMinimumF32Value:NEAR ;++ @@ -46,6 +45,10 @@ INCLUDE mlasi.inc LEAF_ENTRY MlasReduceMaximumF32KernelAvx, _TEXT vbroadcastss ymm0,DWORD PTR [MlasMinimumF32Value] + test rdx,rdx + jz ExitKernel + cmp rdx,8 + jb ProcessRemainingCountBy1 cmp rdx,32 jb ProcessRemainingCountBy8 vmovaps ymm1,ymm0 @@ -74,23 +77,22 @@ ProcessRemainingCountBy8: jmp ProcessRemainingCountBy8 ProcessRemainingCountLessThan8: - test rdx,rdx - jz ReduceScalar - lea r10,MlasMaskMoveTableAvx+8*4 - neg rdx - vmovups ymm3,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm1,ymm3,YMMWORD PTR [rcx] - vmaxps ymm1,ymm0,ymm1 - vblendvps ymm0,ymm0,ymm1,ymm3 ; ignore masked elements - -ReduceScalar: - vextractf128 xmm1,ymm0,1 + vextractf128 xmm1,ymm0,1 ; reduce to single scalar vmaxps xmm0,xmm0,xmm1 vshufps xmm1,xmm0,xmm0,0EEh vmaxps xmm0,xmm0,xmm1 vshufps xmm1,xmm0,xmm0,055h vmaxss xmm0,xmm0,xmm1 + test rdx,rdx + jz ExitKernel +ProcessRemainingCountBy1: + vmaxss xmm0,xmm0,DWORD PTR [rcx] + add rcx,4 ; advance input by 1 element + dec edx + jnz ProcessRemainingCountBy1 + +ExitKernel: vzeroupper ret @@ -149,12 +151,13 @@ ProcessRemainingCountBy8: ProcessRemainingCountLessThan8: test rdx,rdx jz ExitKernel - lea r10,MlasMaskMoveTableAvx+8*4 - neg rdx - vmovups ymm3,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] - vmulps ymm0,ymm4,ymm0 - vmaskmovps YMMWORD PTR [rcx],ymm3,ymm0 + +ProcessRemainingCountBy1: + vmulss xmm0,xmm4,DWORD PTR [rcx] + vmovss DWORD PTR [rcx],xmm0 + add rcx,4 ; advance output by 1 element + dec edx + jnz ProcessRemainingCountBy1 ExitKernel: vzeroupper @@ -226,13 +229,15 @@ ProcessRemainingCountBy8: ProcessRemainingCountLessThan8: test r8,r8 jz ExitKernel - lea r10,MlasMaskMoveTableAvx+8*4 - neg r8 - vmovups ymm3,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rcx] - vaddps ymm0,ymm4,ymm0 - vsubps ymm0,ymm0,ymm5 ; do as two steps for numeric stability - vmaskmovps YMMWORD PTR [rdx],ymm3,ymm0 + +ProcessRemainingCountBy1: + vaddss xmm0,xmm4,DWORD PTR [rcx] + add rcx,4 ; advance input by 1 element + vsubss xmm0,xmm0,xmm5 + vmovss DWORD PTR [rdx],xmm0 + add rdx,4 ; advance output by 1 element + dec r8d + jnz ProcessRemainingCountBy1 ExitKernel: vzeroupper diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 3566f05be1c82..b81c4b07b86eb 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -289,7 +289,8 @@ Routine Description: This routine computes the exponential function for the supplied vector. - This function handles a narrower range of inputs compaerd to MlasExp + This function handles a narrower range of inputs compared to + MlasComputeExpVector in order to improve efficiency. Arguments: @@ -510,40 +511,45 @@ Return Value: --*/ { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(MlasMinimumF32Value); + float Maximum = MlasMinimumF32Value; - if (N >= 16) { + if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; - MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; + MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(Maximum); - while (N >= 16) { + if (N >= 16) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); - MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4)); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8)); - MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, MlasLoadFloat32x4(Input + 12)); + MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; + MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; + MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; + + while (N >= 16) { + + MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); + MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4)); + MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8)); + MaximumVector3 = MlasMaximumFloat32x4(MaximumVector3, MlasLoadFloat32x4(Input + 12)); - Input += 16; - N -= 16; + Input += 16; + N -= 16; + } + + MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1); + MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3); + MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2); } - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector1); - MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MaximumVector3); - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MaximumVector2); - } + while (N >= 4) { - while (N >= 4) { + MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); + Input += 4; + N -= 4; + } - Input += 4; - N -= 4; + Maximum = MlasReduceMaximumFloat32x4(MaximumVector0); } - float Maximum = MlasReduceMaximumFloat32x4(MaximumVector0); - while (N > 0) { Maximum = (std::max)(Maximum, *Input); diff --git a/onnxruntime/core/mlas/lib/pooling.cpp b/onnxruntime/core/mlas/lib/pooling.cpp index 06a89973a0075..0399b56f76b87 100644 --- a/onnxruntime/core/mlas/lib/pooling.cpp +++ b/onnxruntime/core/mlas/lib/pooling.cpp @@ -21,7 +21,8 @@ Module Name: // threads. // -struct MLAS_WORK_BLOCK { +struct MLAS_POOL_WORK_BLOCK +{ MLAS_POOLING_KIND PoolingKind; size_t InputShape[3]; size_t InputSize; @@ -38,7 +39,7 @@ struct MLAS_WORK_BLOCK { typedef void (MLAS_POOL_KERNEL_ROUTINE)( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -87,22 +88,11 @@ struct MLAS_MAXIMUM_POOLING return MlasMaximumFloat32x4(Reduction, Value); } -#if defined(MLAS_NEON64_INTRINSICS) - - static float ReduceFloat32x4(MLAS_FLOAT32X4 Reduction) + static float Reduce(MLAS_FLOAT32X4 Reduction) { - return vmaxvq_f32(Reduction); + return MlasReduceMaximumFloat32x4(Reduction); } -#elif defined(MLAS_NEON32_INTRINSICS) - - static float32x2_t ReducePairwise(float32x2_t Vector0, float32x2_t Vector1) - { - return vpmax_f32(Vector0, Vector1); - } - -#endif - static float AveragePool(float Reduction, float Size) { MLAS_UNREFERENCED_PARAMETER(Size); @@ -169,25 +159,11 @@ struct MLAS_AVERAGE_POOLING return MlasAddFloat32x4(Reduction, Value); } -#if defined(MLAS_NEON64_INTRINSICS) - - static float ReduceFloat32x4(MLAS_FLOAT32X4 Reduction) + static float Reduce(MLAS_FLOAT32X4 Reduction) { - Reduction = vpaddq_f32(Reduction, Reduction); - Reduction = vpaddq_f32(Reduction, Reduction); - - return vgetq_lane_f32(Reduction, 0); - } - -#elif defined(MLAS_NEON32_INTRINSICS) - - static float32x2_t ReducePairwise(float32x2_t Vector0, float32x2_t Vector1) - { - return vpadd_f32(Vector0, Vector1); + return MlasReduceAddFloat32x4(Reduction); } -#endif - static float AveragePool(float Reduction, float Size) { return Reduction / Size; @@ -272,7 +248,7 @@ struct MLAS_AVERAGE_POOLING template void MlasPool1DKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -342,7 +318,7 @@ Return Value: template void MlasPool2DKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -430,7 +406,7 @@ Return Value: template void MlasPool2DVectorKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -654,7 +630,7 @@ Return Value: template void MlasPool3DKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -759,7 +735,7 @@ Return Value: template void MlasPool3DVectorKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -1027,7 +1003,7 @@ Return Value: template void MlasPoolGlobalKernel( - const MLAS_WORK_BLOCK* WorkBlock, + const MLAS_POOL_WORK_BLOCK* WorkBlock, size_t ChannelCount, const float* Input, float* Output @@ -1081,37 +1057,7 @@ Return Value: // Reduce the vector to a single float value. // -#if defined(MLAS_NEON64_INTRINSICS) - - float ReductionValue = PoolingType::ReduceFloat32x4(Reduction); - -#elif defined(MLAS_NEON32_INTRINSICS) - - float32x2_t ReductionLow = vget_low_f32(Reduction); - float32x2_t ReductionHigh = vget_high_f32(Reduction); - - ReductionLow = PoolingType::ReducePairwise(ReductionLow, ReductionHigh); - ReductionLow = PoolingType::ReducePairwise(ReductionLow, ReductionHigh); - - float ReductionValue = vget_lane_f32(ReductionLow, 0); - -#elif defined(MLAS_SSE2_INTRINSICS) - - Reduction = PoolingType::Reduce(Reduction, _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(3, 2, 3, 2))); - Reduction = PoolingType::Reduce(Reduction, _mm_shuffle_ps(Reduction, Reduction, _MM_SHUFFLE(1, 1, 1, 1))); - - float ReductionValue = _mm_cvtss_f32(Reduction); - -#elif defined(MLAS_VSX_INTRINSICS) - - Reduction = PoolingType::Reduce(Reduction, MLAS_FLOAT32X4(vec_splat((__vector int64_t)Reduction, 1))); - Reduction = PoolingType::Reduce(Reduction, vec_splat(Reduction, 1)); - - float ReductionValue = Reduction[0]; - -#else -#error Unsupported architecture. -#endif + float ReductionValue = PoolingType::Reduce(Reduction); // // Iterate over the remaining input buffer an element at a time. @@ -1228,7 +1174,7 @@ Return Value: --*/ { - MLAS_WORK_BLOCK WorkBlock; + MLAS_POOL_WORK_BLOCK WorkBlock; WorkBlock.PoolingKind = PoolingKind; @@ -1237,10 +1183,8 @@ Return Value: // and output shapes over the batch and channel counts. // - //TODO: use a safeint here and make sure the result value can fit into int32_t size_t TotalChannelCount = size_t(InputShape[0]) * size_t(InputShape[1]); - InputShape += 2; OutputShape += 2; diff --git a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S index 453e1f8ee4e74..7432ff0f92415 100644 --- a/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S +++ b/onnxruntime/core/mlas/lib/x86_64/SoftmaxKernelAvx.S @@ -46,6 +46,10 @@ Return Value: C_UNDERSCORE(MlasReduceMaximumF32KernelAvx): vbroadcastss ymm0,DWORD PTR C_UNDERSCORE(MlasMinimumF32Value)[rip] + test rsi,rsi + jz .LReduceMaximum.ExitKernel + cmp rsi,8 + jb .LReduceMaximum.ProcessRemainingCountBy1 cmp rsi,32 jb .LReduceMaximum.ProcessRemainingCountBy8 vmovaps ymm1,ymm0 @@ -74,23 +78,22 @@ C_UNDERSCORE(MlasReduceMaximumF32KernelAvx): jmp .LReduceMaximum.ProcessRemainingCountBy8 .LReduceMaximum.ProcessRemainingCountLessThan8: - test rsi,rsi - jz .LReduceMaximum.ReduceScalar - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - neg rsi - vmovups ymm3,YMMWORD PTR [r10+rsi*4] - vmaskmovps ymm1,ymm3,YMMWORD PTR [rdi] - vmaxps ymm1,ymm0,ymm1 - vblendvps ymm0,ymm0,ymm1,ymm3 # ignore masked elements - -.LReduceMaximum.ReduceScalar: - vextractf128 xmm1,ymm0,1 + vextractf128 xmm1,ymm0,1 # reduce to single scalar vmaxps xmm0,xmm0,xmm1 vshufps xmm1,xmm0,xmm0,0xEE vmaxps xmm0,xmm0,xmm1 vshufps xmm1,xmm0,xmm0,0x55 vmaxss xmm0,xmm0,xmm1 + test rsi,rsi + jz .LReduceMaximum.ExitKernel +.LReduceMaximum.ProcessRemainingCountBy1: + vmaxss xmm0,xmm0,DWORD PTR [rdi] + add rdi,4 # advance input by 1 element + dec esi + jnz .LReduceMaximum.ProcessRemainingCountBy1 + +.LReduceMaximum.ExitKernel: vzeroupper ret @@ -148,12 +151,13 @@ C_UNDERSCORE(MlasComputeSoftmaxOutputF32KernelAvx): .LComputeSoftmaxOutput.ProcessRemainingCountLessThan8: test rsi,rsi jz .LComputeSoftmaxOutput.ExitKernel - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - neg rsi - vmovups ymm3,YMMWORD PTR [r10+rsi*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] - vmulps ymm0,ymm4,ymm0 - vmaskmovps YMMWORD PTR [rdi],ymm3,ymm0 + +.LComputeSoftmaxOutput.ProcessRemainingCountBy1: + vmulss xmm0,xmm4,DWORD PTR [rdi] + vmovss DWORD PTR [rdi],xmm0 + add rdi,4 # advance output by 1 element + dec esi + jnz .LComputeSoftmaxOutput.ProcessRemainingCountBy1 .LComputeSoftmaxOutput.ExitKernel: vzeroupper @@ -224,13 +228,15 @@ C_UNDERSCORE(MlasComputeLogSoftmaxOutputF32KernelAvx): .LComputeLogSoftmaxOutput.ProcessRemainingCountLessThan8: test rdx,rdx jz .LComputeLogSoftmaxOutput.ExitKernel - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - neg rdx - vmovups ymm3,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm3,YMMWORD PTR [rdi] - vaddps ymm0,ymm4,ymm0 - vsubps ymm0,ymm0,ymm5 # do as two steps for numeric stability - vmaskmovps YMMWORD PTR [rsi],ymm3,ymm0 + +.LComputeLogSoftmaxOutput.ProcessRemainingCountBy1: + vaddss xmm0,xmm4,DWORD PTR [rdi] + add rdi,4 # advance input by 1 element + vsubss xmm0,xmm0,xmm5 + vmovss DWORD PTR [rsi],xmm0 + add rsi,4 # advance output by 1 element + dec edx + jnz .LComputeLogSoftmaxOutput.ProcessRemainingCountBy1 .LComputeLogSoftmaxOutput.ExitKernel: vzeroupper