
[aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16#17031

Merged
snnn merged 6 commits into microsoft:main from snadampal:sbgemm_aarch64
Jan 22, 2024
Conversation

@snadampal
Contributor

@snadampal snadampal commented Aug 7, 2023

Description

This PR adds an SbgemmKernel for aarch64. It includes the Sbgemm kernel, which implements matrix multiplication with bfloat16 SIMD instructions (bfmmla), and MatMul operator changes to invoke the Sbgemm kernel. To enable the Sbgemm kernel, set the following session option:
"kOrtSessionOptionsGemmFastMathMode"

The PR also adds new test cases for mlas and ort.

Motivation and Context

This is to improve MatMul performance on the aarch64 platform.
I ran the benchmarking script below (bert, roberta and gpt2 model inference) on an AWS Graviton3-based c7g.4xl instance and observed a 1.2x to 1.76x performance improvement compared to the sgemm (fp32) kernel.

cd onnxruntime/python/tools/transformers
python3 benchmark.py

The unit test precision results match the sgemm kernel results. Build command:
./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync
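To get some intuition for why the bf16 fastmath results stay close to the fp32 sgemm results, here is a small NumPy sketch (not part of the PR) that models bfloat16 by truncating fp32 mantissas; this is a simplification, since the hardware conversion may round rather than truncate:

```python
import numpy as np

def to_bfloat16(x):
    # bfloat16 keeps fp32's 8-bit exponent but only 7 mantissa bits;
    # model it here by zeroing the low 16 bits of each fp32 value.
    return (x.astype(np.float32).view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

exact = a @ b                                 # fp32 "sgemm" reference
approx = to_bfloat16(a) @ to_bfloat16(b)      # bf16-input matmul model
rel_err = np.abs(exact - approx).max() / np.abs(exact).max()
print(f"max relative error: {rel_err:.4f}")
```

For typical activations and weights the relative error stays at the percent level or below, which is why the unit tests can compare against the sgemm path with loose tolerances.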

@snadampal
Contributor Author

I'd appreciate it if someone could review this PR.

@snadampal
Contributor Author

Hi @snnn, would you be able to review and provide feedback on this PR? I appreciate your time.

@snadampal
Contributor Author

Hi, I have rebased the PR to resolve the merge conflicts. I'm happy to address any feedback you may have. Thank you!

@milpuz01
Contributor

I have checked out the changes and run performance and accuracy tests, with and without the flag, using onnxruntime_perf_test (with the binary modified to dump outputs for comparison) on AWS Graviton3 instances, and everything looked fine.

@snadampal force-pushed the sbgemm_aarch64 branch 4 times, most recently from eb257ff to 83a6f6e on October 4, 2023 19:29
@snadampal
Contributor Author

Hi @chenfucn, @yufenglee, I have updated the PR (1) to move to the newer gemm interface and (2) to add session-option-based fastmath mode control. Please review and let me know your feedback.

@snadampal
Contributor Author

Hi @chenfucn, @yufenglee, I'd appreciate it if someone could trigger the CI for this PR. I have addressed all the feedback except the Windows testing, for which I'm waiting on the Windows CI results. Thank you!

Contributor

@chenfucn chenfucn left a comment


As we discussed, please add mlas unit tests that call the kernel directly with different shapes and other parameters.

@chenfucn
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline

@chenfucn
Contributor

/azp run ONNX Runtime Web CI Pipeline, Windows ARM64 QNN CI Pipeline, Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 7 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 9 pipeline(s).

@snadampal
Contributor Author

Thanks for the review, I will update the PR to address this and also add unit tests.

@snadampal
Contributor Author

snadampal commented Oct 25, 2023

I have updated the PR to address all the feedback so far, along with learnings from my other qgemm PR:
(1) added the feature for non-Apple targets only
(2) added mlas unit tests
(3) tested the Linux full build (both Release and RelWithDebInfo)
(4) tested the minimal build
(5) tested the Android build with cross-compilation on x86
(6) ran lintrunner and git-clang-format

Next, I will add ort optimizer and provider tests to exercise the fastmath session. Please review and let me know if you have any feedback on this version.

@snadampal
Contributor Author

Thank you, I see your point. bf16 and fp16 are the potential fastmath options, but on aarch64 I have so far seen interest in bf16 fastmath alone. I agree there are unlikely to be many of these across platforms, so I will go ahead with a simple config key:

static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

Added the SbgemmKernel assembly implementation with bfmmla instructions, plus sbgemm utility functions to prepack Matrix B along with conversion to bfloat16. The sbgemm kernel is invoked when fastmath mode is enabled and the hardware supports the bf16 instruction set. It is disabled by default; set the following session option to 1 to enable it:
"kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16"
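The B-matrix prepacking described above can be sketched roughly as follows. This is a minimal NumPy model, not the actual MLAS code: the panel width of 4 and truncation-based fp32-to-bf16 conversion are assumptions for illustration; the real kernel's layout and rounding may differ.

```python
import numpy as np

def to_bf16_bits(x):
    # fp32 -> bf16 storage by dropping the low 16 mantissa bits
    # (a simple model; the real conversion may round to nearest).
    return (x.astype(np.float32).view(np.uint32) >> np.uint32(16)).astype(np.uint16)

def prepack_b(b, panel_width=4):
    """Pad N up to a multiple of the panel width, split B into column
    panels, convert to bf16, and lay the panels out contiguously so the
    kernel can stream them with unit stride."""
    k, n = b.shape
    pad = (-n) % panel_width
    b = np.pad(b, ((0, 0), (0, pad)))
    panels = b.reshape(k, -1, panel_width).transpose(1, 0, 2)  # (npanels, K, pw)
    return to_bf16_bits(np.ascontiguousarray(panels)).ravel()

packed = prepack_b(np.ones((8, 10), dtype=np.float32))
print(packed.size)  # N is padded from 10 to 12, so 8 * 12 values
```

Doing this conversion and reordering once at prepack time keeps it off the hot matmul path, which is the usual motivation for GEMM prepacking.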
@snadampal
Contributor Author

Updated the PR for the session option name and the other points discussed so far, including clang-formatting. Tested:

  1. release, debug and minimal builds on aarch64 Neoverse V1 and N1 platforms
  2. Android build, and Linux cross-compilation for the aarch64 config on an x86 platform

@chenfucn
Contributor

/azp run Linux CPU CI Pipeline, Linux CPU Minimal Build E2E CI Pipeline, Linux GPU CI Pipeline, Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline, Linux QNN CI Pipeline, MacOS CI Pipeline, Windows ARM64 QNN CI Pipeline

@chenfucn
Contributor

/azp run Windows CPU CI Pipeline, Windows GPU CI Pipeline, Windows GPU TensorRT CI Pipeline, Windows x64 QNN CI Pipeline, onnxruntime-binary-size-checks-ci-pipeline, orttraining-linux-ci-pipeline, orttraining-linux-gpu-ci-pipeline, orttraining-ortmodule-distributed

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@azure-pipelines

Azure Pipelines successfully started running 8 pipeline(s).

@snnn snnn merged commit 77da2ef into microsoft:main Jan 22, 2024
@snadampal
Contributor Author

Thanks to @chenfucn , @snnn , @skottmckay and @yufenglee for the great feedback and merging the PR!

YUNQIUGUO pushed a commit that referenced this pull request Jan 23, 2024
…oat16 (#17031)

@snnn
Contributor

snnn commented Jan 24, 2024

@snadampal, thanks for making ONNX Runtime better. You're welcome to bring more changes to us. You have my email; do not hesitate to contact me anytime you need help getting PRs reviewed.

@maajidkhann

Hello @snadampal. This is a great reference PR for any SIMD-based contributions to ARM.
Could you help me understand how the file onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S is generated?
https://github.com/microsoft/onnxruntime/pull/17031/files#diff-6458cefb29cdb4ba0a976ca7ba93e0f3738f6b02e8d6063a51378c4fecfba7c4

My understanding is that we can add the SIMD intrinsics in the .cpp file (onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp) https://github.com/microsoft/onnxruntime/pull/17031/files#diff-a6732e6798dee7a36040e9c388882279bbb70f1e53372a0a751b149429346118, like the NEON code you added there, and then use the gcc/clang compiler to generate the .S or .asm file?

@snadampal
Contributor Author

Hi @maajidkhann, an intrinsics-based approach would be the best one because it scales well to new architectures, but for this PR I hand-wrote the assembly, since the goal was to extract the best performance.

@snnn
Contributor

snnn commented Sep 5, 2025

This PR has been cherry-picked into the rel-1.17.0 branch in PR #19243. Removing the release:1.17.0 label.
