diff --git a/.github/actions/webgpu-validate-shader-key/action.yml b/.github/actions/webgpu-validate-shader-key/action.yml index 7b341d38ea906..86406a2e91877 100644 --- a/.github/actions/webgpu-validate-shader-key/action.yml +++ b/.github/actions/webgpu-validate-shader-key/action.yml @@ -22,7 +22,7 @@ runs: working-directory: ${{ github.action_path }} - name: Validate shader keys (native log) - if: ${{ !inputs.is_chromium_log != 'true' }} + if: ${{ inputs.is_chromium_log != 'true' }} shell: cmd run: | node validate-shader-key.js < "${{ inputs.log_file_path }}" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8d966d358de01..21baab0fd191c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,7 +26,7 @@ jobs: level: info filter_mode: diff_context - name: shellcheck # Static check shell scripts - uses: reviewdog/action-shellcheck@v1 + uses: reviewdog/action-shellcheck@v1.30.0 with: github_token: ${{ secrets.github_token }} reporter: github-pr-check diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index ce0e5167eb0a0..57f687d8502ff 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -200,7 +200,6 @@ jobs: - name: Validate shader keys - WebGPU EP if: ${{ inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} - continue-on-error: true uses: ./.github/actions/webgpu-validate-shader-key with: log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index e8cea1c5805a3..8b3b8a2fcde54 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -126,7 +126,6 @@ jobs: dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log - name: Validate shader keys - continue-on-error: true uses: ./.github/actions/webgpu-validate-shader-key with: 
log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index b64a5c3e5a4a2..77c35aac65b92 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -3,6 +3,7 @@ using System; using System.Runtime.InteropServices; +using static Microsoft.ML.OnnxRuntime.NativeMethods; namespace Microsoft.ML.OnnxRuntime { @@ -325,6 +326,16 @@ public struct OrtApi public IntPtr CreateLoraAdapterFromArray; public IntPtr ReleaseLoraAdapter; public IntPtr RunOptionsAddActiveLoraAdapter; + public IntPtr SetEpDynamicOptions; + public IntPtr ReleaseValueInfo; + public IntPtr ReleaseNode; + public IntPtr ReleaseGraph; + public IntPtr ReleaseModel; + public IntPtr GetValueInfoName; + public IntPtr GetValueInfoTypeInfo; + public IntPtr GetModelEditorApi; + public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; + public IntPtr SessionOptionsSetLoadCancellationFlag; } internal static class NativeMethods @@ -404,6 +415,7 @@ static NativeMethods() OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions)); OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions)); OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode)); + OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag)); OrtSetOptimizedModelFilePath = 
(DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath)); OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, typeof(DOrtEnableProfiling)); OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling)); @@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options, + bool value); + public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath); public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index bd450451a1265..9b0f183f03681 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode } private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL; + /// + /// Sets the load cancellation flag for the session. Default is set to false. + /// Provides an opportunity for the user to cancel model loading. 
+ /// + /// true to request cancellation, false to proceed + public void SetLoadCancellationFlag(bool value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); + } + #endregion #region Private Methods diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f582abca34706..0308b5c79c508 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point. + If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. + 5. For uint8 data, the `gather_axis` must be 0. #### Version @@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(int4), tensor(uint4)
+
T1 : tensor(int4), tensor(uint4), tensor(uint8)
Constrain quantized types.
T2 : tensor(float), tensor(float16), tensor(bfloat16)
Constrain dequantized types.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 60d9e8e747eeb..a20333e2340c4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -515,7 +515,7 @@ Do not modify directly.* |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedGemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| -|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)
**T2** = tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)| +|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)| |GatherND|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 0822eba950f50..10f658f52e0d9 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch abort(); \ } while (false) +#define ORT_THROW_FROM_STATUS(status) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException( \ + ORT_WHERE_WITH_STACK, status.ToString()) \ + .what()); \ + abort(); \ + } while (false) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) \ + .what()); \ + abort(); \ + } while (false) + #else #define ORT_TRY try @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch #define ORT_THROW_EX(ex, ...) \ throw ex(__VA_ARGS__) +#define ORT_THROW_FROM_STATUS(status) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \ + static_cast<::onnxruntime::common::StatusCode>(status.Code())) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) + #endif #define ORT_MAKE_STATUS(category, code, ...) 
\ @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch auto _status = (expr); \ if ((!_status.IsOK())) { \ ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \ - ORT_THROW(_status); \ + ORT_THROW_FROM_STATUS(_status); \ } \ } while (0) diff --git a/include/onnxruntime/core/common/exceptions.h b/include/onnxruntime/core/common/exceptions.h index 494a770b8db98..6d0f6edd6e7c4 100644 --- a/include/onnxruntime/core/common/exceptions.h +++ b/include/onnxruntime/core/common/exceptions.h @@ -11,6 +11,7 @@ #include #include "core/common/common.h" +#include "core/common/status.h" #include "core/common/code_location.h" namespace onnxruntime { @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception { /** Create a new exception that captures the location it was thrown from. @param location Location in the source code the exception is being thrown from + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + + OnnxRuntimeException(const CodeLocation& location, + const std::string& message, + common::StatusCategory category, + common::StatusCode code) noexcept + : OnnxRuntimeException(location, nullptr, message, category, code) { + } + + /** + Create a new exception that captures the location it was thrown from. + The instance will be created with ONNXRUNTIME category and FAIL code. + @param location Location in the source code the exception is being thrown from @param failed_condition Optional string containing the condition that failed. e.g. "tensor.Size() == input.Size()". May be nullptr. @param msg Message containing additional information about the exception cause. 
*/ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept + : OnnxRuntimeException(location, failed_condition, msg, + common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg, + common::StatusCategory category, + common::StatusCode code) + : location_{location}, category_(category), code_(code) { std::ostringstream ss; ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception { what_ = ss.str(); } + common::StatusCategory Category() const noexcept { + return category_; + } + + common::StatusCode Code() const noexcept { + return code_; + } + const char* what() const noexcept override { return what_.c_str(); } @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception { const CodeLocation location_; const std::vector stacktrace_; std::string what_; + common::StatusCategory category_; + common::StatusCode code_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index 8f171daabbb1e..b222e411d7804 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -43,7 +43,8 @@ enum 
StatusCode { MODEL_LOADED = 8, NOT_IMPLEMENTED = 9, INVALID_GRAPH = 10, - EP_FAIL = 11 + EP_FAIL = 11, + MODEL_LOAD_CANCELED = 12, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -72,6 +73,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "INVALID_GRAPH"; case StatusCode::EP_FAIL: return "EP_FAIL"; + case StatusCode::MODEL_LOAD_CANCELED: + return "MODEL_LOAD_CANCELED"; default: return "GENERAL ERROR"; } @@ -104,6 +107,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); case StatusCode::EP_FAIL: return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::MODEL_LOAD_CANCELED: + return HRESULT_FROM_WIN32(ERROR_CANCELLED); default: return E_FAIL; } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6d4cc8a1f2fa9..3bf0d5e19c525 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -255,6 +255,7 @@ typedef enum OrtErrorCode { ORT_NOT_IMPLEMENTED, ORT_INVALID_GRAPH, ORT_EP_FAIL, + ORT_MODEL_LOAD_CANCELED, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -4898,6 +4899,24 @@ struct OrtApi { _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** \brief sets load cancellation flag to abort session loading process. + * + * \param[in] options instance that was passed to the session at creation time. + * \param[in] cancel setting this to true after model loading process was initiated will + * attempt to cancel the loading process. If cancellation is successful, CreateSession() + * CreateSessionFromArray() or any other session creation API that take session options as an + * argument will return an OrtStatus indicating that session loading was canceled at user request, + * error code ORT_MODEL_LOAD_CANCELED. 
+ * The APIs above would not return any valid Session instance. This is the best case effort and the result + * is not guaranteed. The session may have already been created and initialized + * before the cancellation request was issued. + * + * \snippet{doc} snippets.dox OrtStatus + * + */ + ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool cancel); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 979b478e2fbb4..ce7dc1c45b05e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48c5e52e33c53..524e3ecc92936 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -747,6 +747,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionM return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLoadCancellationFlag(bool value) { + ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); diff --git 
a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index 9e4730a407d57..708e458748b3a 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -12,7 +12,7 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.4" + "vite": "^6.2.5" } }, "node_modules/@babel/helper-string-parser": { @@ -1069,9 +1069,9 @@ } }, "node_modules/vite": { - "version": "6.2.4", - "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.4.tgz", - "integrity": "sha512-veHMSew8CcRzhL5o8ONjy8gkfmFJAd5Ac16oxBUjlwgX3Gq2Wqr+qNC3TjPIpy7TPV/KporLga5GT9HqdrCizw==", + "version": "6.2.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.5.tgz", + "integrity": "sha512-j023J/hCAa4pRIUH6J9HemwYfjB5llR2Ps0CWeikOtdR8+pAURAk0DoJC5/mm9kd+UgdnIy7d6HE4EAvlYhPhA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json b/js/web/test/e2e/exports/testcases/vite-default/package.json index e06733f917e3f..904db7a41de9c 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package.json @@ -13,6 +13,6 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.4" + "vite": "^6.2.5" } } diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 345b5e793a764..1a737f3a9d251 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized); @@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc index 5935663f114a3..b83164d806ffc 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc @@ -16,6 +16,21 @@ namespace onnxruntime { namespace contrib { +namespace { +template +int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) { + return static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1))); +} + +template <> +int32_t GetDataElement(const uint8_t* data_ptr, int64_t data_idx) { + const uint8_t data_val_u8 = data_ptr[data_idx >> 1]; + // Weights are stored as (nibble2)(nibble1) in uint8_t. + auto data_val = static_cast((data_idx & 1) ? 
((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F)); + return data_val; +} +} // namespace + template class GatherBlockQuantized : public OpKernel { public: @@ -98,6 +113,12 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex for (int64_t i = p.gather_axis + 1; i < static_cast(data_rank); ++i) shape.push_back(data_shape[narrow(i)]); + // When data is stored as uint8_t, each element has two int4 values. + // The shape in the onnx model reflects that by having the last dimension be half the number of values. + // Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536. + // However the outputs still need to be of size 2000x3072. Therefore we x2 the last dimension here. + uint32_t components = (std::is_same_v) ? 2 : 1; + shape[shape.size() - 1] = shape.back() * components; p.output_tensor = context->Output(0, TensorShape(std::move(shape))); // validate quantization parameters @@ -106,7 +127,7 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex "data and scales must have the same rank."); for (size_t i = 0; i < data_shape.NumDimensions(); ++i) { ORT_RETURN_IF_NOT(i == static_cast(p.quantize_axis) - ? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i] + ? 
(data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i] : data_shape[i] == scales_shape[i], "data and scales do not match shapes."); } @@ -165,16 +186,22 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, int64_t output_idx = output_idx_base; int64_t data_idx = data_idx_base; for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) { - auto data_val = static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1))); + auto data_val = GetDataElement(data_ptr, data_idx); int64_t x = data_idx / quantize_full_block; int64_t y = data_idx % quantize_full_block / quantize_N; int64_t z = data_idx % quantize_N; int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z; auto scale_val = static_cast(scales_ptr[scale_idx]); - auto zp_val = static_cast(zero_points_ptr - ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1)) - : 0); + int32_t zp_val; + if constexpr (std::is_same_v) { + // The default zero point for uint8 weights as stored by MatMulNBits op is 8. + zp_val = 8; + } else { + zp_val = static_cast(zero_points_ptr + ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1)) + : 0); + } output_ptr[output_idx] = static_cast(static_cast(data_val - zp_val) * scale_val); } @@ -205,7 +232,7 @@ template Status GatherBlockQuantized::Compute(OpKernelContext* context) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); - + auto components = (std::is_same_v) ? 
2 : 1; const auto& data_shape = p.data_tensor->Shape(); // re-shape the data tensor to [gather_M, gather_axis_dim, gather_block] // re-shape the indices tensor to [gather_N] @@ -215,7 +242,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { // 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N], // 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :], // 4> pick the element from the block: value_i = data_blk[blk_ele_i] - const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1); + const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1) * components; const int64_t gather_axis_dim = data_shape[narrow(p.gather_axis)]; const int64_t gather_M = data_shape.SizeToDimension(narrow(p.gather_axis)); const int64_t gather_N = p.indices_tensor->Shape().Size(); @@ -229,7 +256,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { // data_i % (quantize_axis_dim * quantize_N) / quantize_N, // data_i % quantize_N) // 4> get scale index: (x, y / block_size_, z) - const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)]; + const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)] * components; const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt(p.quantize_axis) + 1); concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); @@ -273,6 +300,8 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherBlockQuantized); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t); REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t); REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t); REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 
0d4afc8c13f4b..6e7919f281fb6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -99,23 +99,24 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - << "let m = workgroup_id.y * TILE_SIZE;\n" - << "let n = workgroup_id.x * TILE_SIZE;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n" + << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n" + << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n"; } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" "for (var w: u32 = 0u; w < uniforms.K; w += 
TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" " }\n" " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" @@ -123,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" - << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" + << " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" @@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " workgroupBarrier();\n" << "}\n"; - shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" - << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" - << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n" + << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? 
"value.x + value.y" : "value")) << ";\n"; shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; @@ -181,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; - program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile) .SetWorkgroupSize(tile_size, tile_size) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(parameters.kv_sequence_length_)}, 
{static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_ ? 1 : 0)}}) + {static_cast(parameters.is_first_prompt_ ? 1 : 0)}, + {num_total_seq_length_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); @@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; - shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let sequence_length = uniforms.sequence_length;\n" + shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n" + << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? 
"past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" @@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) .CacheHint(work_group_size) - .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, @@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var tileQ: array;\n" << "var tileK: array;\n"; - shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n" + << "let head_idx = batch_head_idx % uniforms.num_heads;\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n" + << "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n" + << "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; InitVarStub(oss, 
seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + shader.MainFunctionBody() << "let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n"; if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n"; } shader.MainFunctionBody() << "var value = output_value_t(0);\n" @@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" - << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" + << " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n" @@ -400,7 +404,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 
2 : 1); constexpr int tile_size = 12; int tile_n_size = tile_size * components; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_value) { @@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components}); } - program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? 
total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_)}}) + {static_cast(parameters.is_first_prompt_)}, + {num_head_size_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 164ea72b07d9d..7c0cb40cc7f93 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", 
ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -60,7 +62,6 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; - int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; @@ -90,8 +91,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -106,7 +107,9 @@ class VxAttentionScoreProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", 
ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_head_size_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -114,7 +117,6 @@ class VxAttentionScoreProgram final : public Program { bool feed_past_value_; bool has_present_value_; int tile_size_; - int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; diff --git a/onnxruntime/contrib_ops/webgpu/fused_conv.cc b/onnxruntime/contrib_ops/webgpu/fused_conv.cc new file mode 100644 index 0000000000000..e6b7ac3ec24d4 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/fused_conv.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/conv.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { +using onnxruntime::webgpu::Conv; +template +class FusedConv final : public Conv { + public: + FusedConv(const OpKernelInfo& info) : Conv(info) { + ORT_ENFORCE(GetFusedActivationAttr(info, Conv::activation_).IsOK()); + } +}; + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", onnxruntime::webgpu::WebGpuSupportedFloatTypes()), + FusedConv); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 6e63ba3a0caa4..4136477a1d88c 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -40,7 +40,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index 703d183ea5c87..b42c6a9ba3e10 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c #ifndef ORT_NO_EXCEPTIONS #define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (const onnxruntime::NotImplementedException& ex) { \ - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ - } \ - catch 
(const std::exception& ex) { \ - return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ - } \ - catch (...) { \ - return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ +#define API_IMPL_END \ + } \ + catch (const onnxruntime::OnnxRuntimeException& ex) { \ + return OrtApis::CreateStatus(static_cast(ex.Code()), ex.what()); \ + } \ + catch (const onnxruntime::NotImplementedException& ex) { \ + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ + } \ + catch (const std::exception& ex) { \ + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ + } \ + catch (...) { \ + return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ } #else diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index ff4d300f665b1..50f14104cfd7a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -56,6 +56,7 @@ namespace { // contains some common parameters used by the partitioning helper functions struct PartitionParams { std::reference_wrapper graph; + std::reference_wrapper check_load_cancellation_fn; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::reference_wrapper func_mgr; std::reference_wrapper fused_kernel_registry; @@ -143,6 +144,7 @@ struct GetCapabilityForEPParams { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; std::reference_wrapper graph_optimizer_registry; + std::reference_wrapper check_load_cancellation_fn; }; auto get_capabilities = [](const IExecutionProvider& ep, @@ -188,7 +190,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, 
params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Graph partitioning was canceled by user request"); + } if (capabilities.empty()) { return Status::OK(); @@ -209,6 +216,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l // Perform layout transformation on the specific EP assigned graph bool modified = false; ORT_RETURN_IF_ERROR(params.transform_layout(graph, modified, current_ep, params.debug_graph_fn)); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes // which are reconstructed to update domain or completely new nodes which are necessary for layout transformation. @@ -226,7 +237,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -405,6 +421,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, + const CheckLoadCancellationFn& 
check_load_cancellation_fn, const logging::Logger& logger, IResourceAccountant* resource_accountant, const GraphOptimizerRegistry& graph_optimizer_registry) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -420,7 +437,10 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); + transform_layout_fn, debug_graph_fn, + check_load_cancellation_fn, + logger, resource_accountant, + graph_optimizer_registry)); } } @@ -445,7 +465,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn), resource_accountant, - std::ref(graph_optimizer_registry)}; + std::ref(graph_optimizer_registry), + std::cref(check_load_cancellation_fn)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -532,6 +553,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, } ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); + ORT_RETURN_IF(check_load_cancellation_fn(), + "Graph partitioning is canceled due to user request."); if (node_compute_funcs.size() != nodes_to_compile.size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); @@ -633,6 +656,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide Graph& graph, const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, + const CheckLoadCancellationFn& check_load_cancellation_fn, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case 
where optimizers or constant lifting results in graph with no nodes. @@ -650,6 +674,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide *subgraph, graph_optimizer_registry, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); } @@ -673,8 +698,13 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, + graph_optimizer_registry, logger, capabilities)); + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } + for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -707,6 +737,9 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); } } + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } } return Status::OK(); @@ -846,6 +879,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; + const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn; do { // process full graph with each EP @@ -861,6 +895,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, 
fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, + check_load_cancellation_fn, logger, resource_accountant, graph_optimizer_registry)); } @@ -915,7 +950,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) nullptr, - std::ref(graph_optimizer_registry) + std::ref(graph_optimizer_registry), + partition_params.check_load_cancellation_fn }; // clang-format on @@ -972,6 +1008,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::vector single_node_compute_func; ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}}, single_node_compute_func)); + if (partition_params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning is canceled due to user request."); + } ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element."); auto& func_mgr = partition_params.func_mgr.get(); @@ -1032,6 +1071,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, return Status::OK(); } + auto check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + auto& graph = model.MainGraph(); InlinedHashSet not_inlined; do { @@ -1041,13 +1082,13 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, graph, *graph_optimizer_registry_, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); if (inlined_count == 0) { break; } - ORT_RETURN_IF_ERROR(graph.Resolve()); } while (true); @@ -1082,6 +1123,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } + CheckLoadCancellationFn check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + #if 
!defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // fused_kernel_registry is preparing the kernels created on the fly for fused sub graph. // It is only visible for current session. @@ -1092,6 +1135,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), std::ref(func_mgr), std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), @@ -1105,6 +1149,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_UNUSED_PARAMETER(debug_graph_fn); PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), }; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index b9d4022cb5a14..87edc7a64c6b5 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -33,6 +33,16 @@ class GraphPartitioner { graph_optimizer_registry_(std::move(graph_optimizer_registry)) { } + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, + const ExecutionProviders& providers, + std::unique_ptr graph_optimizer_registry, + CheckLoadCancellationFn check_load_cancellation_fn) + : kernel_registry_mgr_(kernel_registry_mgr), + providers_(providers), + graph_optimizer_registry_(std::move(graph_optimizer_registry)), + check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) { + } + // Run partitioning. 
Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, @@ -41,6 +51,10 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; + bool IsLoadCancellationFlagSet() const { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #ifndef ORT_MINIMAL_BUILD /// // Ahead of Time Function inlining. The main purpose of the function is to inline as many @@ -69,6 +83,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; std::unique_ptr graph_optimizer_registry_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8d4db36106f28..ef323b99b006c 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" @@ -66,6 +67,8 @@ struct FreeDimensionOverride { int64_t dim_value; }; +using CheckLoadCancellationFn = std::function; + /** * Configuration information for a session. */ @@ -184,6 +187,18 @@ struct SessionOptions { // User specified logging func and param OrtLoggingFunction user_logging_function = nullptr; void* user_logging_param = nullptr; + + void SetLoadCancellationFlag(bool value) noexcept { + *load_cancellation_flag = value; + } + + bool IsLoadCancellationFlagSet() const noexcept { + return *load_cancellation_flag; + } + + // Load cancellation flag is necessary to be within shared memory as session_options are + // copied internally and the flag needs to be accessible across all copies. 
+ std::shared_ptr load_cancellation_flag = std::make_shared(false); }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index d174d6cc72ead..6362a3169f3a3 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -422,6 +422,10 @@ Status SessionState::PrepackConstantInitializedTensors( auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { for (auto& node : GetGraphViewer().Nodes()) { + if (sess_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Weight pre-packing was canceled due to user request."); + } auto kernel = GetMutableKernel(node.Index()); int input_idx = 0; for (auto& input_def : node.InputDefs()) { @@ -1541,6 +1545,11 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringname(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 7b4a45ce8aa0f..d87688a62040c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point. 
+ If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. + 5. For uint8 data, the `gather_axis` must be 0. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized) @@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h .Input(2, "scales", "quantization scale", "T2") .Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional) .Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2") - .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.") + .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.") .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { @@ -3637,14 +3638,19 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h gather_axis = (gather_axis + r) % r; quantize_axis = (quantize_axis + r) % r; + if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) { + fail_shape_inference("gather_axis must be 0, for uint8 data"); + } + if (scales_shape.dim_size() != r) { fail_shape_inference("scales must have the same rank as data"); } + uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 
2 : 1; for (int i = 0; i < r; ++i) { if (!data_shape.dim(i).has_dim_value() || !scales_shape.dim(i).has_dim_value() || - (i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || + (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) { fail_shape_inference("data shape and scales shape do not match"); } @@ -3652,6 +3658,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h // validate zero point shape if (ctx.hasInput(3)) { + if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) { + fail_type_inference("zero_points are not supported for uint8_t data type"); + } + if (!hasInputShape(ctx, 3)) { fail_shape_inference("zero_points shape must be known"); } @@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); } for (int i = 0; i < out_rank; ++i) { + // For uint8_t data type the last dimension needs to be expanded back to actual dimension, + // because the data 2 int4s are stored packed in a single uint8_t. + auto last_dimension_components = (i == out_rank - 1) ? components : 1; *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() = (i < gather_axis) ? data_shape.dim(i) : (i >= gather_axis && i < gather_axis + q) ? 
indices_shape.dim(i - gather_axis) - : data_shape.dim(i - q + 1); + : data_shape.dim(i - q + 1) * last_dimension_components; } }); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 39ffc6a5b0cee..334ecb3887d14 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model, #endif } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + // Remove constant nodes as they're replaced with initializers above. const gsl::not_null*> graph_mutable_nodes{graph_proto_->mutable_node()}; graph_mutable_nodes->erase( @@ -1365,6 +1369,10 @@ Graph::Graph(const Model& owning_model, } } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + for (auto& graph_output : graph_proto_->output()) { if (utils::HasName(graph_output) && utils::HasType(graph_output)) { auto& name = graph_output.name(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 7629e40c1b5fe..436af7115eb1a 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name, const std::vector& model_local_functions, const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); model_proto_.mutable_graph()->set_name(graph_name); model_metadata_ = model_metadata; @@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path, Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, 
const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { if (!utils::HasGraph(model_proto)) { ORT_THROW("ModelProto does not have a graph."); } @@ -435,6 +435,11 @@ Status Model::Load(const ModelProto& model_proto, ORT_TRY { model = std::make_unique(model_proto, model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -474,6 +479,11 @@ Status Model::Load(ModelProto&& model_proto, ORT_TRY { model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -509,6 +519,11 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { ORT_TRY { status = loader(fd); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, FAIL, ex.what()); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6fd94c60d6b99..70f82bcfb160b 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -11,6 +11,7 @@ #include "core/common/flatbuffers.h" +#include "core/framework/session_options.h" #include "core/graph/graph_viewer.h" #include 
"core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" @@ -38,6 +39,14 @@ struct ModelOptions { // be returned. bool strict_shape_type_inference; + CheckLoadCancellationFn check_load_cancellation_fn; + + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference, + CheckLoadCancellationFn check_load_cancellation_fn) + : allow_released_opsets_only(allow_released_opsets_only), + strict_shape_type_inference(strict_shape_type_inference), + check_load_cancellation_fn(std::move(check_load_cancellation_fn)) {} + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference) : allow_released_opsets_only(allow_released_opsets_only), strict_shape_type_inference(strict_shape_type_inference) {} @@ -102,6 +111,11 @@ class Model { #endif // !defined(ORT_MINIMAL_BUILD) + // Check for load cancellation. + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #if !defined(ORT_MINIMAL_BUILD) // Get model's IR version. // Return if not specified. @@ -343,5 +357,7 @@ class Model { // Main graph of the model. 
std::unique_ptr graph_; + + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index ea9d8605e2417..71c8667a89b1d 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -121,7 +121,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index 039283bb2d4e1..83c3f70799987 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -27,6 +27,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor } for (unsigned step = 0; step < steps_; ++step) { + if (IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph transformation canceled due to user request."); + } bool graph_changed = false; for (const auto& transformer : transformers->second) { if (step > 0 && transformer->ShouldOnlyApplyOnce()) diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index ed66302434ab2..eab57f12bfcbb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -24,6 +24,16 @@ 
class GraphTransformerManager { // Get the maximum number of graph transformation steps common::Status GetSteps(unsigned& steps) const; + // Set the cancellation flag ptr from session_options + void SetLoadCancellationFn(CheckLoadCancellationFn check_load_cancellation_fn) { + check_load_cancellation_fn_ = std::move(check_load_cancellation_fn); + } + + // Get the cancellation flag ptr + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + // Register a transformer with a level. common::Status Register(std::unique_ptr transformer, TransformerLevel level); @@ -38,5 +48,6 @@ class GraphTransformerManager { InlinedHashMap>> level_to_transformer_map_; InlinedHashMap transformers_info_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 9684394da0520..eae2a464cef7e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -296,17 +296,19 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider}; + const InlinedHashSet cpu_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kRocmExecutionProvider, + 
onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider}; const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider, onnxruntime::kAclExecutionProvider}; @@ -338,7 +340,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); transformers.emplace_back(std::make_unique(cpu_acl_eps)); - transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); + transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_webgpu_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 93b2acb5b002c..26642459a6863 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -77,6 +77,7 @@ Status CreateNodeArgs(const std::vector& names, const OnnxTensorInfo& tensor_info = tensor_info_table.at(name); std::unique_ptr tensor_type = Factory::Create(); tensor_type->mutable_tensor_type()->set_elem_type(tensor_info.data_type_); + tensor_type->mutable_tensor_type()->mutable_shape(); for (size_t j = 0; j < tensor_info.shape_.size(); ++j) { tensor_type->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]); } diff --git 
a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 9d5e16caa361d..bc8905c225822 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -611,6 +611,8 @@ struct ProviderHost { virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; + virtual void InferShapes(const std::string& m, const std::string& save_path) = 0; + virtual void InferShapes(ONNX_NAMESPACE::ModelProto& m) = 0; virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0; virtual void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) = 0; virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0; @@ -1010,6 +1012,7 @@ struct ProviderHost { virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0; virtual Graph* Graph__MutableParentGraph(Graph* p) = 0; virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0; + virtual void Graph__SetName(Graph* p, const std::string& name) const noexcept = 0; virtual const std::filesystem::path& Graph__ModelPath(const Graph* p) const = 0; virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; virtual bool Graph__IsSubgraph(const Graph* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index e2af144f455e4..5f0f9ca4c8584 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1050,6 +1050,7 @@ struct Graph final { const Graph* 
ParentGraph() const { return g_host->Graph__ParentGraph(this); } Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); } const std::string& Name() const noexcept { return g_host->Graph__Name(this); } + void SetName(const std::string& name) noexcept { return g_host->Graph__SetName(this, name); } const std::filesystem::path& ModelPath() const { return g_host->Graph__ModelPath(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 6547f00cd47c7..33aa8fa2b31b8 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -360,10 +360,19 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); }; the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_set_name = [](Graph& graph, const char* name) -> void { return graph.SetName(std::string(name)); }; the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, const auto& enter, const auto& leave, const auto& stop) { graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; + + the_global_api.graph_infer_shapes_from_filepath = [](const std::string& m, const std::string& save_path) -> auto { return Provider_GetHost()->InferShapes(m, save_path); }; + the_global_api.graph_to_graph_proto = [](const Graph& graph) -> ONNX_NAMESPACE::GraphProto* { + return graph.ToGraphProto().release(); + }; + the_global_api.graph_proto_delete = [](ONNX_NAMESPACE::GraphProto* p) { delete p; }; + the_global_api.graph_infer_shapes = [](ONNX_NAMESPACE::ModelProto& m) -> auto { return 
Provider_GetHost()->InferShapes(m); }; + // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 85a1262d8489b..6c9c728d8ffad 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -20,6 +20,7 @@ struct NodeAttributes; namespace ONNX_NAMESPACE { struct AttributeProto; struct TensorProto; +struct GraphProto; struct ModelProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { @@ -71,6 +72,7 @@ enum AttributeProto_AttributeType : int { namespace vaip_core { class GraphHolder; using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::GraphProto; using ONNX_NAMESPACE::ModelProto; using ONNX_NAMESPACE::TensorProto; using onnxruntime::Graph; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 0becc41d861f7..d40da70726b43 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (14u) +#define VAIP_ORT_API_MAJOR (16u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -249,7 +249,13 @@ struct OrtApiForVaip { const std::function& leave, const std::function& comp, const std::function& - stop); // [103] + stop); // [103] + void (*graph_set_name)(Graph& graph, const char* name); // [104] + void (*graph_infer_shapes_from_filepath)( + const std::string& m, const std::string& save_path); // [105] + GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106] + void (*graph_proto_delete)(GraphProto* p); // [107] + void (*graph_infer_shapes)(ModelProto& m); // 
[108] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 91cae111a708a..315d0cd75e946 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,15 +13,15 @@ void* GpuBufferAllocator::Alloc(size_t size) { return nullptr; } - WGPUBuffer buffer; - if (!session_initialized_ && context_.SupportsBufferMapExtendedUsages()) { - buffer = context_.BufferManager().CreateUMA(size); - } else { - buffer = context_.BufferManager().Create(size); + stats_.num_allocs++; + +#if !defined(__wasm__) + if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) { + return context_.BufferManager().CreateUMA(size); } +#endif // !defined(__wasm__) - stats_.num_allocs++; - return buffer; + return context_.BufferManager().Create(size); } void GpuBufferAllocator::Free(void* p) { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index adb37f54f2e8f..1d8c689cbd909 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -56,14 +56,27 @@ class LazyReleaseCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { + Release(); pending_buffers_.clear(); } - std::vector pending_buffers_; + public: + ~LazyReleaseCacheManager() { + Release(); + } + + protected: + void Release() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + } + + std::vector pending_buffers_; }; class SimpleCacheManager : public IBufferCacheManager { @@ -74,7 +87,7 @@ class SimpleCacheManager : public IBufferCacheManager { WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { auto it = 
buffers_.find(buffer_size); if (it != buffers_.end() && !it->second.empty()) { - auto buffer = it->second.back().MoveToCHandle(); + auto buffer = it->second.back(); it->second.pop_back(); return buffer; } @@ -87,18 +100,31 @@ class SimpleCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { for (auto& buffer : pending_buffers_) { - buffers_[static_cast(buffer.GetSize())].emplace_back(std::move(buffer)); + buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } pending_buffers_.clear(); } - std::map> buffers_; - std::vector pending_buffers_; + public: + ~SimpleCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buffers_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + + protected: + std::map> buffers_; + std::vector pending_buffers_; }; // TODO: maybe use different bucket size for storage and uniform buffers? @@ -155,7 +181,7 @@ class BucketCacheManager : public IBufferCacheManager { WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { auto it = buckets_.find(buffer_size); if (it != buckets_.end() && !it->second.empty()) { - auto buffer = it->second.back().MoveToCHandle(); + auto buffer = it->second.back(); it->second.pop_back(); return buffer; } @@ -167,31 +193,44 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { // TODO: consider graph capture. 
currently not supported for (auto& buffer : pending_buffers_) { - auto buffer_size = static_cast(buffer.GetSize()); + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); auto it = buckets_.find(buffer_size); if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(std::move(buffer)); + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); } } pending_buffers_.clear(); } + ~BucketCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buckets_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + protected: void Initialize() { buckets_keys_.reserve(buckets_limit_.size()); buckets_.reserve(buckets_limit_.size()); for (const auto& pair : buckets_limit_) { buckets_keys_.push_back(pair.first); - buckets_.emplace(pair.first, std::vector()); + buckets_.emplace(pair.first, std::vector()); } std::sort(buckets_keys_.begin(), buckets_keys_.end()); @@ -205,8 +244,8 @@ class BucketCacheManager : public IBufferCacheManager { #endif } std::unordered_map buckets_limit_; - std::unordered_map> buckets_; - std::vector pending_buffers_; + std::unordered_map> buckets_; + std::vector pending_buffers_; std::vector buckets_keys_; }; @@ -255,11 +294,10 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { // If the buffer is mapped, we can directly write to it. 
- wgpu::Buffer dst_buffer = dst; - auto mapped_data = dst_buffer.GetMappedRange(); + void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped if (mapped_data) { memcpy(mapped_data, src, size); - dst_buffer.Unmap(); + wgpuBufferUnmap(dst); return; } @@ -288,9 +326,11 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { EnforceBufferUnmapped(context_, dst); auto buffer_size = NormalizeBufferSize(size); - ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + auto src_size = static_cast(wgpuBufferGetSize(src)); + auto dst_size = static_cast(wgpuBufferGetSize(dst)); + ORT_ENFORCE(buffer_size <= src_size && buffer_size <= dst_size, "Source and destination buffers must have enough space for the copy operation. src_size=", - wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + src_size, ", dst_size=", dst_size, ", copy_size=", buffer_size, "."); auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); @@ -298,7 +338,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { } WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { - auto& cache = GetCacheManager(static_cast(usage)); + auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); auto buffer = cache.TryAcquireCachedBuffer(buffer_size); @@ -310,7 +350,6 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { wgpu::BufferDescriptor desc{}; desc.size = buffer_size; desc.usage = usage; - // desc.label = std::to_string(xx++).c_str(); buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); @@ -320,14 +359,16 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { } WGPUBuffer BufferManager::CreateUMA(size_t size, 
wgpu::BufferUsage usage) { - ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must have storage usage."); - auto& cache = GetCacheManager(static_cast(usage)); + ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer."); + auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); + // Ensure the buffer is mapped for writing at creation. + usage |= wgpu::BufferUsage::MapWrite; + wgpu::BufferDescriptor desc{}; desc.size = buffer_size; - // Ensure the buffer is mapped for writing at creation. - desc.usage = usage | wgpu::BufferUsage::MapWrite; + desc.usage = usage; desc.mappedAtCreation = true; auto buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); @@ -373,12 +414,12 @@ void BufferManager::RefreshPendingBuffers() { default_cache_->OnRefresh(); } -IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { - if (usage & WGPUBufferUsage_Storage) { +IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const { + if (usage & wgpu::BufferUsage::Storage) { return *storage_cache_; - } else if (usage & WGPUBufferUsage_Uniform) { + } else if (usage & wgpu::BufferUsage::Uniform) { return *uniform_cache_; - } else if (usage & WGPUBufferUsage_QueryResolve) { + } else if (usage & wgpu::BufferUsage::QueryResolve) { return *query_resolve_cache_; } else { return *default_cache_; @@ -386,7 +427,8 @@ IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const } IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { - return GetCacheManager(wgpuBufferGetUsage(buffer)); + auto usage = static_cast(wgpuBufferGetUsage(buffer)); + return GetCacheManager(usage); } std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { diff --git 
a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 6a8ebdd60a1ec..b9028ad5de858 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -70,7 +70,7 @@ class BufferManager { void RefreshPendingBuffers(); private: - IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; WebGpuContext& context_; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 9b447d5fdb59a..cdd3909874e7f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -6,8 +6,9 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" - +#include "core/providers/webgpu/nn/fuse_utils.h" #include "core/providers/webgpu/data_transfer.h" + namespace onnxruntime { namespace webgpu { @@ -54,11 +55,12 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string process_bias; if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); - process_bias = "value += output_value_t(bias[row + i]);"; + process_bias = is_channels_last_ ? 
"value += output_value_t(bias[col])" : "value += output_value_t(bias[row + i]);"; } + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | - ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims"); int a_components = a.NumComponents(); @@ -90,6 +92,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n" << " var value = values[i];\n" << process_bias << "\n" + << apply_activation << "\n" << " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n" << " let offset = " << output.IndicesToOffset("cur_indices") << ";\n" << output.SetByOffset("offset", "value") @@ -127,7 +130,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1; TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components}); - MatMulNaiveProgram program{output_rank, output_number, has_bias}; + MatMulNaiveProgram program{Activation(), output_rank, output_number, has_bias}; program .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number)) @@ -147,11 +150,32 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { return context.RunProgram(program); } - int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2); - int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2); + std::vector inputs(has_bias ? 
3 : 2); + inputs[0] = a; + inputs[1] = b; + if (has_bias) { + const auto* bias = context.Input(2); + inputs.push_back(bias); + } + auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false); + + return context.RunProgram(program); +} + +MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + const TensorShape& input_a_reshape, + const TensorShape& input_b_reshape) { + const auto* a = inputs[0]; + const auto* b = inputs[1]; + bool has_bias = inputs.size() > 2; + TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape(); + TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape(); + + MatMulComputeHelper helper; + ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape)); + int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2); + int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2); - TensorShape a_shape = a->Shape(); - TensorShape b_shape = b->Shape(); TensorShape output_shape = helper.OutputShape(); const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2]; @@ -184,9 +208,9 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const int64_t batch_size = outer_dims.Size(); // Get dimensions for matrix multiplication from TensorShape - const int32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension - const int32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension - const int32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension + const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension + const uint32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension + const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right 
matrix first dimension const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0; @@ -194,34 +218,36 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { ? InlinedVector({4, 1, 1}) : InlinedVector({4, 4, 1}); - const uint32_t dispatch_x = narrow((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); - const uint32_t dispatch_y = narrow((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); - const uint32_t dispatch_z = narrow((static_cast(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + const uint32_t dispatch_x = narrow((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); + const uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); const int components = is_vec4 ? 
4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulProgram program{has_bias, is_vec4, elements_per_thread}; + MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last}; program - .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4)) + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4)) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z); + .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z); if (has_bias) { - const auto* bias = context.Input(2); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1}); + auto bias_components = is_channels_last ? 
components : 1; + const auto* bias = inputs[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); } - return context.RunProgram(program); + return program; } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 789e824383189..91216d8e25eec 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -9,16 +9,20 @@ #include "core/providers/webgpu/math/matmul_utils.h" #include "core/providers/webgpu/math/matmul_packed.h" #include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { +MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + const TensorShape& input_a_reshape = TensorShape(), + const TensorShape& input_b_reshape = TensorShape()); + class MatMul final : public WebGpuKernel { public: MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {} Status ComputeInternal(ComputeContext& context) const override; - constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1; @@ -26,8 +30,8 @@ class MatMul final : public WebGpuKernel { class MatMulNaiveProgram final : public Program { public: - MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias) - : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} { + MatMulNaiveProgram(const Activation& activation, const size_t output_rank, int64_t output_number, bool has_bias, bool is_channels_last = false) + : Program{"MatMulNaive"}, activation_(activation), output_rank_(output_rank), 
output_number_(output_number), has_bias_{has_bias}, is_channels_last_(is_channels_last) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -38,9 +42,11 @@ class MatMulNaiveProgram final : public Program { {"K", ProgramUniformVariableDataType::Uint32}); private: + const Activation& activation_; const size_t output_rank_; const int64_t output_number_; const bool has_bias_; + const bool is_channels_last_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 2e5cff923f442..36510eec0cd3b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -5,7 +5,7 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/webgpu_utils.h" - +#include namespace onnxruntime { namespace webgpu { @@ -13,7 +13,8 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, - const ShaderIndicesHelper& batch_dims) const { + const ShaderIndicesHelper& batch_dims, + std::string activation_snippet) const { int components = is_vec4_ ? 
4 : 1; const std::string data_type = "a_element_t"; const std::string type_string = MakeScalarOrVectorType(components, data_type); @@ -23,7 +24,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, << "fn mm_readA(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" << " var value = " << type_string << "(0.0);\n" << " let col = colIn * " << components << ";\n" - << " if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) {\n" + << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n" << " var a_indices: a_indices_t;\n" << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices") << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" @@ -38,7 +39,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, << "fn mm_readB(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" << " var value = " << type_string << "(0.0);\n" << " let col = colIn * " << components << ";\n" - << " if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) {\n" + << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" << " var b_indices: b_indices_t;\n" << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices") << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" @@ -52,14 +53,16 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, shader.AdditionalImplementation() << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: " << type_string << ") {\n" << " let col = colIn * " << components << ";\n" - << " if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {\n" + << " if (row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n" << " var value = valueIn;\n" << " let coords = vec3(batch, row, colIn);\n"; if (has_bias_) { - shader.AdditionalImplementation() << " value = value 
+ " << type_string << "(bias[row]);\n"; + shader.AdditionalImplementation() << " value = value + " << (is_channels_last_ ? "bias[colIn]" : type_string + "(bias[row])") << ";\n"; } + shader.AdditionalImplementation() << " " << activation_snippet << "\n"; + shader.AdditionalImplementation() << output.SetByIndices("vec3(coords)", "value") << "\n" << " }\n" @@ -67,29 +70,36 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, } Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y) { + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a, + uint32_t tile_inner, + bool split_k, + uint32_t splitted_dim_inner) { + ORT_UNUSED_PARAMETER(split_k); + ORT_UNUSED_PARAMETER(splitted_dim_inner); + std::string write_data_to_sub_a_vec4_snippet = + transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart / innerElementSize + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n" + : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n"; // elements per thread const auto elements_per_thread_x = elements_per_thread[0]; const auto elements_per_thread_y = elements_per_thread[1]; - const decltype(elements_per_thread_x) tile_inner = 32; const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; - const auto tile_a_width = tile_inner; - - const auto tile_a_height = tile_a_outer; + const auto tile_a_width = transpose_a ? tile_a_outer : tile_inner; + const auto tile_a_height = transpose_a ? 
tile_inner : tile_a_outer; const auto inner_elements_size = tile_a_width / workgroup_size_x; const auto row_per_thread_b = tile_inner / workgroup_size_y; - const std::string data_type = "a_element_t"; - - if (!((inner_elements_size == 3 || inner_elements_size == 4) && - tile_a_width % workgroup_size_x == 0 && - tile_inner % workgroup_size_y == 0 && - elements_per_thread_x == 4)) { + if (!(((transpose_a && inner_elements_size == 4 && elements_per_thread[1] == 4) || + (!transpose_a && (inner_elements_size == 3 || inner_elements_size == 4))) && + tile_a_width % workgroup_size_x == 0 && + tile_inner % workgroup_size_y == 0 && + elements_per_thread_x == 4)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid matrix multiplication configuration inner_elements_size: ", inner_elements_size, " must be 3 or 4. tile_a_width: ", tile_a_width, " must be divisible by WorkgroupSizeX: ", @@ -112,7 +122,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let globalRow = i32(global_id.y) * rowPerThread;\n" << " let globalCol = i32(global_id.x);\n" << " let batch = i32(global_id.z);\n" - << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -121,14 +131,14 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, // Loop over shared dimension. shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n" - << " for (var t = 0; t < num_tiles; t = t + 1) {\n"; + << " for (var t = 0; t < i32(num_tiles); t = t + 1) {\n"; // Load one tile of A into local memory.
shader.MainFunctionBody() << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" << " let inputRow = tileRow + innerRow;\n" << " let inputCol = tileCol;\n" - << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol, batchIndices);\n" + << " " << write_data_to_sub_a_vec4_snippet << " }\n"; // Load one tile of B into local memory. @@ -136,7 +146,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n" << " let inputRow = tileRowB + innerRow;\n" << " let inputCol = tileCol;\n" - << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol, batchIndices);\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol" << (nullptr != batch_dims ? ", batchIndices" : "") << ");\n" << " }\n" << " kStart = kStart + tileInner;\n" << " workgroupBarrier();\n"; @@ -152,15 +162,29 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];\n"; } - shader.MainFunctionBody() - << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" - << " let ACached = mm_Asub[tileRow + i][k];\n" - << " acc[i] = BCached0 * ACached.x + acc[i];\n" - << " acc[i] = BCached1 * ACached.y + acc[i];\n" - << " acc[i] = BCached2 * ACached.z + acc[i];\n" - << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n" - << " }\n"; - + if (transpose_a) { + shader.MainFunctionBody() + << " let ACached0 = mm_Asub[k * innerElementSize][localRow];\n" + << " let ACached1 = mm_Asub[k * innerElementSize + 1][localRow];\n" + << " let ACached2 = mm_Asub[k * innerElementSize + 2][localRow];\n" + << (inner_elements_size == 3 ?
"" : " let ACached3 = mm_Asub[k * innerElementSize + 3][localRow];\n") + << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" + << " let ACached = mm_Asub[tileCol][i];\n" + << " acc[i] = BCached0 * ACached0[i] + acc[i];\n" + << " acc[i] = BCached1 * ACached1[i] + acc[i];\n" + << " acc[i] = BCached2 * ACached2[i] + acc[i];\n" + << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached3[i] + acc[i];") << "\n" + << " }\n"; + } else { + shader.MainFunctionBody() + << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" + << " let ACached = mm_Asub[tileRow + i][k];\n" + << " acc[i] = BCached0 * ACached.x + acc[i];\n" + << " acc[i] = BCached1 * ACached.y + acc[i];\n" + << " acc[i] = BCached2 * ACached.z + acc[i];\n" + << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n" + << " }\n"; + } shader.MainFunctionBody() << " workgroupBarrier();\n" << " }\n"; // main for loop @@ -174,13 +198,22 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, return Status::OK(); } -Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderIndicesHelper& batch_dims, +Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y) { + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a, + uint32_t tile_inner, + bool split_k, + uint32_t splitted_dim_inner, + bool sequentially_access_by_threads) { + ORT_UNUSED_PARAMETER(split_k); + ORT_UNUSED_PARAMETER(splitted_dim_inner); + const auto elements_per_thread_x = elements_per_thread[0]; const auto elements_per_thread_y = elements_per_thread[1]; - const decltype(elements_per_thread_x) tile_inner = 32; const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; @@ -194,12 +227,11 @@ Status
MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI ", tile_inner: ", tile_inner, " must be divisible by WorkgroupSizeY: ", workgroup_size_y); } - const std::string data_type = "a_element_t"; - const auto row_per_thread_a = tile_a_height / workgroup_size_y; const auto col_per_thread_a = tile_a_width / workgroup_size_x; const auto row_per_thread_b = tile_inner / workgroup_size_y; - + std::string write_data_to_sub_a_snippet = transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n" + : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n"; shader.AdditionalImplementation() << "var mm_Asub: array, " << tile_a_height << ">;\n" << "var mm_Bsub: array, " << tile_inner << ">;\n" @@ -208,93 +240,142 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI << "const tileInner = " << tile_inner << ";\n"; shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" - << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << (nullptr != batch_dims ? 
" let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" << " var acc: array, rowPerThread>;\n"; - shader.MainFunctionBody() - << "let tileRow = i32(local_id.y) * rowPerThread;\n" - << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" - << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" - << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; - - // Loop over shared dimension. - shader.MainFunctionBody() - << "for (var t = 0; t < num_tiles; t = t + 1) {\n"; - - // Load one tile of A into local memory. - shader.MainFunctionBody() - << " for (var innerRow = 0; innerRow < " << row_per_thread_a << "; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < " << col_per_thread_a << "; innerCol = innerCol + 1) {\n" - << " let inputRow = tileRowA + innerRow;\n" - << " let inputCol = tileColA + innerCol;\n" - << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol, batchIndices);\n" - << " }\n" - << " }\n"; - - // Load one tile of B into local memory. - shader.MainFunctionBody() - << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " let inputRow = tileRowB + innerRow;\n" - << " let inputCol = tileCol + innerCol;\n" - << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol, batchIndices);\n" - << " }\n" - << " }\n" - << " kStart = kStart + tileInner;\n" - << " workgroupBarrier();\n"; - - // Compute acc values for a single thread. 
- shader.MainFunctionBody() - << "var BCached: array<" << data_type << ", colPerThread>;\n" - << " for (var k = 0; k < tileInner; k = k + 1) {\n" - << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n" - << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n" - << " }\n" - << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" - << " let ACached = mm_Asub[tileRow + innerRow][k];\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n" - << " }\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - // Write the results to the output buffer - shader.MainFunctionBody() - << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n" - << " }\n" - << "}\n"; - + if (sequentially_access_by_threads) { + shader.MainFunctionBody() << "let localRow = i32(local_id.y);\n" + << "let localCol = i32(local_id.x);\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "\n" + << "// Loop over shared dimension.\n" + << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n" + << " // Load one tile of A into local memory.\n" + << " for (var inputRow = localRow; inputRow < " << tile_a_height << "; inputRow = inputRow + " << workgroup_size_y << ") {\n" + << " for (var inputCol = localCol; inputCol < " << tile_a_width << "; inputCol = inputCol + " << workgroup_size_x << ") {\n" + << " " << write_data_to_sub_a_snippet << "\n" + << " }\n" + << " }\n" + << " // Load one tile of B into local memory.\n" + << " for (var inputRow = localRow; inputRow < " << tile_inner << "; inputRow = inputRow + " << workgroup_size_y 
<< ") {\n" + << " for (var inputCol = localCol; inputCol < " << tile_b_outer << "; inputCol = inputCol + " << workgroup_size_x << ") {\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch,\n" + << " kStart + inputRow,\n" + << " globalColStart + inputCol" << (batch_dims ? ", batchIndices" : "") << ");\n " + << " }\n" + << " }\n" + << " kStart = kStart + tileInner;\n" + << " workgroupBarrier();\n" + << "\n" + << " // Compute acc values for a single thread.\n" + << " var BCached : array<" << data_type << ", colPerThread>;\n" + << " for (var k = 0; k < tileInner; k = k + 1) {\n" + << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n" + << " BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << "];\n" + << " }\n" + << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " let ACached = " << (transpose_a ? "mm_Asub[k][localCol + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n" + << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" + << " acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n" + << " ACached * BCached[innerCol];\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " let gRow = globalRowStart + localRow + innerRow * " << workgroup_size_y << ";\n" + << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" + << " let gCol = globalColStart + localCol + innerCol * " << workgroup_size_x << ";\n" + << " mm_write(batch, gRow, gCol, acc[innerRow][innerCol]);\n" + << " }\n" + << "}\n"; + } else { + shader.MainFunctionBody() + << "let tileRow = i32(local_id.y) * rowPerThread;\n" + << "let tileCol = i32(local_id.x) * colPerThread;\n" + << "let globalRow = i32(global_id.y) * rowPerThread;\n" + << "let globalCol = i32(global_id.x) * 
colPerThread;\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" + << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" + << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; + + // Loop over shared dimension. + shader.MainFunctionBody() + << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n"; + + // Load one tile of A into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_a << "); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(" << col_per_thread_a << "); innerCol = innerCol + 1) {\n" + << " let inputRow = tileRowA + innerRow;\n" + << " let inputCol = tileColA + innerCol;\n" + << " " << write_data_to_sub_a_snippet << "\n" + << " }\n" + << " }\n"; + + // Load one tile of B into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_b << "); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " let inputRow = tileRowB + innerRow;\n" + << " let inputCol = tileCol + innerCol;\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol" << (nullptr != batch_dims ? ", batchIndices" : "") << ");\n" + << " }\n" + << " }\n" + << " kStart = kStart + tileInner;\n" + << " workgroupBarrier();\n"; + + // Compute acc values for a single thread. 
+ shader.MainFunctionBody() + << "var BCached: array<" << data_type << ", colPerThread>;\n" + << " for (var k = 0; k < tileInner; k = k + 1) {\n" + << " for (var inner = 0; inner < i32(colPerThread); inner = inner + 1) {\n" + << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n" + << " }\n" + << " for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n" + << " let ACached = mm_Asub[tileRow + innerRow][k];\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + // Write the results to the output buffer + shader.MainFunctionBody() + << "for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n" + << " }\n" + << "}\n"; + } return Status::OK(); } Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } - + std::string 
apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); // declare the read and write functions - MatMulReadWriteFnSource(shader, a, b, output, batch_dims); - + MatMulReadWriteFnSource(shader, a, b, output, batch_dims, apply_activation); + std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY())); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } else { - ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY())); + ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index ea76468944066..d3a68ff8a57fa 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -7,38 +7,53 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/math/matmul_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(bool bias, bool is_vec4, const gsl::span& elements_per_thread) : Program{"MatMul"}, - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + 
elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Int32}, - {"dim_b_outer", ProgramUniformVariableDataType::Int32}, - {"dim_inner", ProgramUniformVariableDataType::Int32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}); static Status MakeMatMulPackedVec4Source(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y); + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a = false, + uint32_t tile_inner = 32, + bool split_k = false, + uint32_t splitted_dim_inner = 32); static Status MakeMatMulPackedSource(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y); + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a = false, + uint32_t tile_inner = 32, + bool split_k = false, + uint32_t splitted_dim_inner = 32, + bool sequentially_access_by_threads = false); private: + const Activation& activation_; const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; - - void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims) const; + bool is_channels_last_ = false; + void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& 
output, const ShaderIndicesHelper& batch_dims, std::string apply_activation) const; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.cc b/onnxruntime/core/providers/webgpu/nn/activation_util.cc new file mode 100644 index 0000000000000..b5c31d98cda93 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/activation_util.cc @@ -0,0 +1,25 @@ +#include "core/providers/webgpu/nn/activation_util.h" +#include "core/common/common.h" +namespace onnxruntime { +namespace webgpu { +std::string TypeSnippet(uint32_t component, std::string data_type) { + switch (component) { + case 1: + return data_type; + case 2: + return "vec2<" + data_type + ">"; + case 3: + return "vec3<" + data_type + ">"; + case 4: + return "vec4<" + data_type + ">"; + default: + ORT_THROW("Component ", component, " is not supported."); + } +} + +std::string BiasSnippet(bool has_bias) { + return has_bias ? "value = value + getBiasByOutputCoords(coords);" : ""; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.h b/onnxruntime/core/providers/webgpu/nn/activation_util.h new file mode 100644 index 0000000000000..1c9fd93e35384 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/activation_util.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace webgpu { + +extern std::string TypeSnippet(uint32_t component, std::string data_type); +extern std::string BiasSnippet(bool has_bias); + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc new file mode 100644 index 0000000000000..0edad3eebe2ea --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. +#include "core/providers/webgpu/nn/conv.h" +#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/webgpu/nn/grouped_conv.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/matmul.h" +namespace onnxruntime { +namespace webgpu { + +Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm) { + // Transpose weights + auto rank = kernel_shape.NumDimensions(); + TensorShapeVector transposed_kernel_shape_vector(rank); + for (size_t i = 0; i < rank; ++i) { + transposed_kernel_shape_vector[i] = kernel_shape[perm[i]]; + } + uint32_t output_size = onnxruntime::narrow(kernel_shape.Size()); + TensorShape transposed_kernel_shape(transposed_kernel_shape_vector); + *transposed_kernel = context.CreateGPUTensor(kernel->DataType(), transposed_kernel_shape); + bool use_shared = false; + TransposeProgram program{perm, use_shared}; + program + .CacheHint(absl::StrJoin(perm, "-")) + .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) + .AddOutput({transposed_kernel, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariable({output_size}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + +template +Status Conv::ComputeInternal(ComputeContext& context) const { + bool has_bias = context.InputCount() > 2; + const auto* input = context.Input(0); + const auto* kernel = context.Input(1); + const auto* bias = has_bias ? 
context.Input(2) : nullptr; + TensorShape input_shape = input->Shape(); + TensorShape kernel_shape = kernel->Shape(); + ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end()); + TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end()); + TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end()); + TensorShapeVector kernel_spacial_shape_vector; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(kernel_shape, kernel_spacial_shape_vector, false)); + if (local_pads.empty()) { + local_pads.resize(kernel_spacial_shape_vector.size() * 2, 0); + } + if (local_dilations.empty()) { + local_dilations.resize(kernel_spacial_shape_vector.size(), 1); + } + if (local_strides.empty()) { + local_strides.resize(kernel_spacial_shape_vector.size(), 1); + } + TensorShapeVector input_shape_vector = input_shape.AsShapeVector(); + auto batch = input_shape[0]; + TensorShapeVector output_shape_vector = {batch}; + TensorShape input_spacial_shape = is_channels_last ? 
TensorShape(TensorShapeVector(std::next(input_shape_vector.begin()), std::prev(input_shape_vector.end()))) : input_shape.Slice(2); + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_spacial_shape, kernel_spacial_shape_vector, local_strides, local_dilations, local_pads, output_shape_vector)); + auto output_channels = kernel_shape[0]; + if (is_channels_last) { + output_shape_vector.push_back(output_channels); + } else { + output_shape_vector.insert(output_shape_vector.begin() + 1, output_channels); + } + auto output_shape = TensorShape(output_shape_vector); + auto* output = context.Output(0, output_shape); + std::vector strides; + std::vector pads; + std::vector dilations; + auto transform_dim = [](int64_t dim) { return static_cast(dim); }; + std::transform(local_pads.begin(), local_pads.end(), std::back_inserter(pads), transform_dim); + std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim); + std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); + auto rank = input_shape.NumDimensions(); + const InlinedVector perm = {2, 3, 1, 0}; + if (rank > 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported."); + } else if (rank == 4) { + // Conv2D + } else if (rank == 3) { + // Conv1D + TensorShapeVector kernel_shape_vector = kernel_shape.AsShapeVector(); + input_shape_vector.insert(input_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + output_shape_vector.insert(output_shape_vector.begin() + (is_channels_last ? 
1 : 2), 1, 1); + kernel_shape_vector.insert(kernel_shape_vector.begin() + 2, 1); + input_shape = TensorShape(input_shape_vector); + kernel_shape = TensorShape(kernel_shape_vector); + pads.insert(pads.begin(), 0); + pads.insert(pads.begin() + 2, 0); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input and kernel tensors must have at least 3 dimensions"); + } + std::vector inputs(has_bias ? 3 : 2); + inputs[0] = input; + inputs[1] = kernel; + if (has_bias) { + inputs[2] = bias; + } + std::vector modified_input_output_shapes = {input_shape, kernel_shape}; + if (has_bias) { + modified_input_output_shapes.push_back(bias->Shape()); + } + modified_input_output_shapes.push_back(TensorShape(output_shape_vector)); + uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0; + auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; + auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; + std::vector updated_pads{pad0, pad1}; + if (conv_attrs_.group > 1) { + Tensor transposed_kernel; + if (is_channels_last) { + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + inputs[1] = &transposed_kernel; + modified_input_output_shapes[1] = transposed_kernel.Shape(); + } + auto output_channels_per_group = output_channels / conv_attrs_.group; + auto components = static_cast(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1); + auto output_size = output_shape.Size() / components; + GroupedConvProgram program(activation_, has_bias, is_channels_last); + auto reduced_kernel_shape = ReduceShapeByComponents(modified_input_output_shapes[1], components); + auto reduced_output_shape = ReduceShapeByComponents(modified_input_output_shapes[has_bias ? 
3 : 2], components); + program.CacheHint(activation_.ToString(), std::to_string(components), std::to_string(is_channels_last)) + .AddInput({inputs[0], ProgramTensorMetadataDependency::TypeAndRank, modified_input_output_shapes[0], 1}) + .AddInput({inputs[1], ProgramTensorMetadataDependency::TypeAndRank, reduced_kernel_shape, components}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components}) + .AddUniformVariables({{static_cast(output_size)}, {dilations}, {strides}, {updated_pads}, {static_cast(output_channels_per_group)}, {static_cast(components)}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (has_bias) { + auto reduced_bias_shape = ReduceShapeByComponents(modified_input_output_shapes[2], components); + program.AddInput({inputs[2], ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + return context.RunProgram(program); + } + const auto input_height = input_shape[is_channels_last ? 1 : 2]; + const auto input_width = input_shape[is_channels_last ? 2 : 3]; + const auto input_channels = input_shape[is_channels_last ? 3 : 1]; + const auto kernel_height = kernel_shape[2]; + const auto kernel_width = kernel_shape[3]; + const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; + const auto output_width = output_shape_vector[is_channels_last ? 
2 : 3]; + + const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0; + if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) { + Tensor transposed_kernel; + TensorShape input_reshape; + TensorShape kernel_reshape; + TensorShape matmul_output_shape; + std::vector matmul_inputs; + std::vector matmul_input_reshapes; + if (is_channels_last) { + // Transpose weights + + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + inputs[1] = &transposed_kernel; + if (same_size) { + const auto shared_dim = input_height * input_width * input_channels; + input_reshape = TensorShape({1, batch, shared_dim}); + kernel_reshape = TensorShape({1, shared_dim, output_channels}); + matmul_output_shape = TensorShape({1, batch, output_channels}); + } else { + input_reshape = TensorShape({batch, input_height * input_width, input_channels}); + kernel_reshape = TensorShape({1, input_channels, output_channels}); + matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels}); + } + matmul_inputs.push_back(input); + matmul_inputs.push_back(&transposed_kernel); + matmul_input_reshapes.push_back(input_reshape); + matmul_input_reshapes.push_back(kernel_reshape); + } else { + input_reshape = TensorShape({batch, input_channels, input_height * input_width}); + kernel_reshape = TensorShape({1, output_channels, input_channels}); + matmul_output_shape = TensorShape({batch, output_channels, output_height * output_width}); + matmul_inputs.push_back(kernel); + matmul_inputs.push_back(input); + matmul_input_reshapes.push_back(kernel_reshape); + matmul_input_reshapes.push_back(input_reshape); + } + if (has_bias) { + matmul_inputs.push_back(bias); + } + auto N = matmul_output_shape[2]; + auto matmul_first_input_numdims = matmul_input_reshapes[0].NumDimensions(); + auto K = 
matmul_input_reshapes[0].GetDims()[matmul_first_input_numdims - 1]; + if (N < 8 && K < 8) { + const auto components = GetMaxComponents(N); + const auto a_components = GetMaxComponents(K); + const auto output_number = GetMaxComponents(output_shape[1]); + uint32_t output_size = static_cast(output_shape.Size() / components / output_number); + const size_t output_rank = matmul_output_shape.NumDimensions(); + TensorShape outer_dims = output_rank > 2 ? matmul_output_shape.Slice(0, output_rank - 2) : TensorShape({}); + MatMulNaiveProgram program(activation_, output_rank, output_number, has_bias); + program + .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number)) + .AddInputs({{matmul_inputs[0], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[0], a_components), int(a_components)}, + {matmul_inputs[1], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[1], components), int(components)}}); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, bias->Shape(), components}); + } + program + .AddOutputs({{output, ProgramTensorMetadataDependency::None, ReduceShapeByComponents(matmul_output_shape, components), int(components)}}) + .SetDispatchGroupSize(static_cast((output_size + 63) / 64)) + .AddIndices(outer_dims) + .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); + return context.RunProgram(program); + } else { + MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + return context.RunProgram(program); + } + } + const bool sequentially_access_by_threads = true; + // Transpose weights + Tensor transposed_kernel; + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + auto dim_a_outer = 
static_cast(is_channels_last ? output_height * output_width : output_channels); + auto dim_b_outer = static_cast(is_channels_last ? output_channels : output_height * output_width); + auto dim_inner = static_cast(kernel_height * kernel_width * input_channels); + inputs[1] = &transposed_kernel; + TensorShape transposed_kernel_shape = transposed_kernel.Shape(); + modified_input_output_shapes[1] = transposed_kernel.Shape(); + Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, sequentially_access_by_threads, modified_input_output_shapes); + return context.RunProgram(conv2d_mm_program); +} + +// Explicit template instantiation for FusedConv +template class Conv; +template class Conv; +template class Conv; +template class Conv; + +#define WEBGPU_ONNX_CONV_OPERATOR_KERNEL(VERSION_FROM) \ + ONNX_OPERATOR_KERNEL_EX( \ + Conv, \ + kMSInternalNHWCDomain, \ + VERSION_FROM, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); \ + \ + ONNX_OPERATOR_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + VERSION_FROM, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); + +#define WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(VERSION_FROM, VERSION_TO) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); \ + \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Conv, \ + kMSInternalNHWCDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); + +WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(1, 10) +WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(11, 21) +WEBGPU_ONNX_CONV_OPERATOR_KERNEL(22) + 
+} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h new file mode 100644 index 0000000000000..cafaa272c0613 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Conv : public WebGpuKernel { + public: + Conv(const OpKernelInfo& info) : WebGpuKernel(info), conv_attrs_(info) { + if (is_fused) { + ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + } + } + Status ComputeInternal(ComputeContext& context) const override; + + protected: + ConvAttributes conv_attrs_; + Activation activation_; +}; + +Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc new file mode 100644 index 0000000000000..24e49304cf532 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include +#include +#include +#include +#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/activation_util.h" +#include "core/providers/webgpu/math/matmul_packed.h" +#include "core/providers/webgpu/nn/conv_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace webgpu { +std::string Conv2dMMProgram::Conv2dCommonSnippet(const ShaderVariableHelper& x, const ShaderVariableHelper& w, const Activation& activation, std::string data_type, uint32_t inner_element_size_x, uint32_t inner_element_size_w, uint32_t inner_element_size) const { + auto get_x_snippet = [&](int32_t inner_element_size) -> std::string { + switch (inner_element_size) { + case 1: + return "resData = " + x.GetByOffset("xIndex") + ";"; + case 3: + return "resData = vec3(" + x.GetByOffset("xIndex") + ", " + x.GetByOffset("xIndex + 1") + ", " + x.GetByOffset("xIndex + 2") + ");"; + case 4: + return "resData = " + x.GetByOffset("xIndex") + ";\n "; + default: + ORT_THROW("inner_element_size", inner_element_size, " is not supported."); + } + }; + auto get_w_snippet = [&](int32_t inner_element_size) -> std::string { + switch (inner_element_size) { + case 1: + return "return " + w.GetByOffset("row * i32(uniforms.w_shape[3]) + colIn") + ";\n"; + case 4: + return "return " + w.GetByOffset("row * i32(uniforms.w_shape[3]) + colIn") + ";\n"; + default: + ORT_THROW("inner_element_size ", inner_element_size, " is not supported."); + } + }; + const std::string coord_a_snippet = is_channels_last_ ? "let coord = vec4(batch, xRow, xCol, xCh / " + std::to_string(inner_element_size_x == 3 ? 4 : inner_element_size_x) + ");" : "let coord = vec4(batch, xCh, xRow, xCol);"; + const std::string coord_res_snippet = is_channels_last_ ? 
"let coords = vec4(batch, row / outWidth, row % outWidth, col / " + std::to_string(inner_element_size) + ");" : "let coords = vec4(batch, row, col / outWidth, col % outWidth);"; + + const std::string xHeight = is_channels_last_ ? "i32(uniforms.x_shape[1])" : "i32(uniforms.x_shape[2])"; + const std::string xWidth = is_channels_last_ ? "i32(uniforms.x_shape[2])" : "i32(uniforms.x_shape[3])"; + const std::string row = is_channels_last_ ? "row" : "col"; + const std::string col = is_channels_last_ ? "col" : "row"; + std::stringstream read_x_snippet; + read_x_snippet + << "let inChannels = i32(uniforms.w_shape[2]);\n" + << "let outWidth = " << (is_channels_last_ ? "i32(uniforms.result_shape[2])" : "i32(uniforms.result_shape[3])") << ";\n" + << "let outRow = " << row << " / outWidth;\n " + << "let outCol = " << row << " % outWidth;\n" + << "let WRow = " << col << " / (i32(uniforms.w_shape[1]) * inChannels);\n" + << "let WCol = " << col << " / inChannels % i32(uniforms.w_shape[1]);\n" + << "let xRow = outRow * i32(uniforms.strides[0]) + i32(uniforms.dilations[0]) * WRow - i32(uniforms.pads[0]);\n" + << "let xCol = outCol * i32(uniforms.strides[1]) + i32(uniforms.dilations[1]) * WCol - i32(uniforms.pads[1]);\n" + << "let xCh = " << col << " % inChannels;\n" + << "var resData = " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n " + << "// The bounds checking is always needed since we use it to pad zero for\n" + << "// the \" same \" padding type.\n" + << "if (xRow >= 0 && xRow < " << xHeight << " && xCol >= 0 && xCol < " << xWidth << ") {\n" + << " " << coord_a_snippet << "\n" + << " let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape));\n" + << " " << get_x_snippet(inner_element_size_x) + << "}\n" + << "return resData;"; + std::stringstream sample_x; + if (is_channels_last_) { + if (fit_a_outer_ && fit_inner_) { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << read_x_snippet.str(); + } else { + sample_x << "let col = colIn * " 
<< inner_element_size_x << ";\n" + << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n" + << " " << read_x_snippet.str() << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n"; + } + } else { + if (fit_inner_ && fit_b_outer_) { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << read_x_snippet.str(); + } else { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << "if (row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << read_x_snippet.str() << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n"; + } + } + std::stringstream sample_w; + if (is_channels_last_) { + if (fit_inner_ && fit_b_outer_) { + sample_w << get_w_snippet(inner_element_size_w); + } else { + sample_w << "let col = colIn * " << inner_element_size_w << ";\n" + << "if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << get_w_snippet(inner_element_size_w) << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_w, data_type) << "(0.0);\n"; + } + } else { + sample_w << "let col = colIn * " << inner_element_size_w << ";\n" + << "if (row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << get_w_snippet(inner_element_size_w) << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_w, data_type) << "(0.0);\n"; + } + const std::string res_type = TypeSnippet(inner_element_size, data_type); + const std::string a_type = is_channels_last_ ? TypeSnippet(inner_element_size_x, data_type) : TypeSnippet(inner_element_size_w, data_type); + const std::string b_type = is_channels_last_ ? 
TypeSnippet(inner_element_size_w, data_type) : TypeSnippet(inner_element_size_x, data_type); + const std::string apply_activation = GetActivationSnippet(activation, res_type, data_type); + std::stringstream user_code; + user_code << "fn mm_readA(batch : i32, row : i32, colIn : i32) -> " << a_type << " {\n" + << (is_channels_last_ ? sample_x.str() : sample_w.str()) + << "}\n" + << "\n" + << "fn mm_readB(batch : i32, row : i32, colIn : i32) -> " << b_type << " {\n" + << (is_channels_last_ ? sample_w.str() : sample_x.str()) + << "}\n" + << "\n" + << "fn mm_write(batch : i32, row : i32, colIn : i32, valueIn : " << res_type << ") {\n" + << " let col = colIn * " << inner_element_size << ";\n" + << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n" + << " var value = valueIn;\n" + << " let outWidth = " << (is_channels_last_ ? " i32(uniforms.result_shape[2]) " : " i32(uniforms.result_shape[3]) ") << ";\n" + << " " << coord_res_snippet << "\n" + << " " << BiasSnippet(has_bias_) << "\n" + << " " << apply_activation << "\n" + << " setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value);\n" + << " }\n" + << "}\n"; + return user_code.str(); +} + +Status Conv2dMMProgram::GenerateShaderCode(ShaderHelper& shader) const { + std::stringstream declaration_functions; + declaration_functions << "fn setOutputAtIndex(flatIndex : i32, value : " << (is_vec4_ ? "vec4" : "x_element_t") << ") {\n" + << " result[flatIndex] = " << (is_vec4_ ? "vec4" : "x_element_t") << "(value);\n" + << "}\n" + << "fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : " << (is_vec4_ ? 
"vec4" : "x_element_t") << "){\n" + << " let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3));\n" + << " setOutputAtIndex(flatIndex, value);\n" + << "}\n"; + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + std::vector inputs = {&x, &w}; + ORT_IGNORE_RETURN_VALUE(shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias)); + if (has_bias_) { + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + inputs.push_back(&bias); + declaration_functions << "fn getBiasByOutputCoords(coords : vec4) -> bias_value_t {" << "\n" + << " return bias[" << (is_channels_last_ ? "coords.w" : "coords.y") << "];\n" + << "}"; + } + shader.AdditionalImplementation() + << UtilFunctions("uniforms.result_stride") + << declaration_functions.str() + << Conv2dCommonSnippet(x, w, activation_, "x_element_t", element_size_[0], element_size_[1], element_size_[2]); + std::string data_type = "x_element_t"; + return is_vec4_ ? 
MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, false, tile_inner_, false, 0, sequentially_access_by_threads_); +} + +Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector& input_output_shapes) { + const auto* input = inputs[0]; + const auto* weight = inputs[1]; + bool has_bias = inputs.size() > 2; + const auto* bias = has_bias ? inputs[2] : nullptr; + const auto& input_shape = input_output_shapes[0]; + auto in_channels = is_channels_last ? input_shape[3] : input_shape[1]; + const auto& output_shape = has_bias ? input_output_shapes[3] : input_output_shapes[2]; + auto batch_size = output_shape[0]; + const auto output_width = is_channels_last ? output_shape[2] : output_shape[3]; + const auto output_height = is_channels_last ? output_shape[1] : output_shape[2]; + const auto output_channels = is_channels_last ? output_shape[3] : output_shape[1]; + // TODO: enable vec4 for NCHW + const bool is_vec4 = is_channels_last && (in_channels % 4 == 0 || in_channels % 3 == 0) && output_channels % 4 == 0; + + // TODO: fine tune size + const auto dispatch_x = is_channels_last ? output_channels : output_width * output_height; + const auto dispatch_y = is_channels_last ? output_width * output_height : output_channels; + std::vector workgroup_size = {8, 8, 1}; + InlinedVector elements_per_thread = {4, static_cast(dim_a_outer <= 8 ? 
1 : 4), 1}; + auto integer_ceil = [](int64_t a, int64_t b) -> int64_t { return (a + b - 1) / b; }; + + const std::vector dispatch = { + static_cast(integer_ceil(integer_ceil(dispatch_x, workgroup_size[0]), elements_per_thread[0])), + static_cast(integer_ceil(integer_ceil(dispatch_y, workgroup_size[1]), elements_per_thread[1])), + static_cast(integer_ceil(integer_ceil(batch_size, workgroup_size[2]), elements_per_thread[2])), + }; + + uint32_t inner_element_size = is_vec4 ? (is_channels_last && in_channels % 4 != 0 ? 3 : 4) : 1; + auto tile_a_outer = static_cast(workgroup_size[1] * elements_per_thread[1]); + auto tile_b_outer = static_cast(workgroup_size[0] * elements_per_thread[0]); + auto tile_inner = std::max(workgroup_size[0] * inner_element_size, workgroup_size[1]); + bool fit_a_outer = dim_a_outer % tile_a_outer == 0; + bool fit_b_outer = dim_b_outer % tile_b_outer == 0; + bool fit_inner = dim_inner % tile_inner == 0; + std::vector element_size = {is_vec4 ? inner_element_size : 1, static_cast(is_vec4 ? 4 : 1), static_cast(is_vec4 ? 4 : 1)}; + const auto components = is_vec4 ? 4 : 1; + const auto input_components = static_cast(inner_element_size == 3 ? 1 : inner_element_size); + Conv2dMMProgram program(activation, tile_inner, fit_a_outer, fit_b_outer, fit_inner, is_channels_last, is_vec4, has_bias, std::move(element_size), std::move(elements_per_thread), sequentially_access_by_threads); + TensorShape reduced_input_shape = ReduceShapeByComponents(input_output_shapes[0], input_components); + TensorShape reduced_weight_shape = ReduceShapeByComponents(input_output_shapes[1], components); + TensorShape reduced_output_shape = ReduceShapeByComponents(input_output_shapes[has_bias ? 
3 : 2], components); + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, reduced_input_shape, input_components}, {weight, ProgramTensorMetadataDependency::TypeAndRank, reduced_weight_shape, components}}); + if (has_bias) { + TensorShape reduced_bias_shape = ReduceShapeByComponents(input_output_shapes[2], components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + const auto stringify = [](const std::vector& vec) -> std::string { + std::ostringstream oss; + std::transform(vec.begin(), vec.end(), std::ostream_iterator(oss, ","), [](uint32_t i) { return std::to_string(i); }); + return oss.str(); + }; + program.CacheHint(activation.ToString(), stringify({inner_element_size, static_cast(is_vec4 ? 1 : 0), fit_a_outer, fit_b_outer, fit_inner, tile_a_outer, tile_a_outer, tile_inner, static_cast(components)})) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components}) + .SetDispatchGroupSize(dispatch[0], dispatch[1], dispatch[2]) + .SetWorkgroupSize(workgroup_size[0], workgroup_size[1], workgroup_size[2]) + .AddUniformVariables({{static_cast(dim_a_outer)}, + {static_cast(dim_b_outer)}, + {static_cast(dim_inner)}, + {pads}, + {strides}, + {dilations}}); + + return program; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h new file mode 100644 index 0000000000000..0087d11db179d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { +class Conv2dMMProgram final : public Program { + public: + Conv2dMMProgram(const Activation& activation, uint32_t tile_inner, bool fit_a_outer, bool fit_b_outer, bool fit_inner, bool is_channels_last, bool is_vec4, bool has_bias, std::vector&& element_size, InlinedVector&& elements_per_thread, bool sequentially_access_by_threads) : Program("Conv2dMM"), + activation_(activation), + tile_inner_(tile_inner), + fit_a_outer_(fit_a_outer), + fit_b_outer_(fit_b_outer), + fit_inner_(fit_inner), + is_channels_last_(is_channels_last), + is_vec4_(is_vec4), + has_bias_(has_bias), + element_size_(std::move(element_size)), + elements_per_thread_(std::move(elements_per_thread)), + sequentially_access_by_threads_(sequentially_access_by_threads) {} + + std::string Conv2dCommonSnippet(const ShaderVariableHelper& x, const ShaderVariableHelper& w, const Activation& activation, std::string data_type, uint32_t inner_element_size_x = 4, uint32_t inner_element_size_w = 4, uint32_t inner_element_size = 4) const; + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + uint32_t tile_inner_; + bool 
fit_a_outer_; + bool fit_b_outer_; + bool fit_inner_; + bool is_channels_last_; + bool is_vec4_; + bool has_bias_; + std::vector element_size_; + InlinedVector elements_per_thread_; + bool sequentially_access_by_threads_; +}; + +Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector& modified_input_output_shapes); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc new file mode 100644 index 0000000000000..74f3e0dcc85f5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" +#include "core/providers/webgpu/webgpu_utils.h" +namespace onnxruntime { +namespace webgpu { + +Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& dy = shader.AddInput("dy", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + if (has_bias_) { + shader.AddInput("bias"); + } + auto row_dim = is_channels_last_ ? 1 : 2; + auto col_dim = is_channels_last_ ? 2 : 3; + auto channel_dim = is_channels_last_ ? 
3 : 1; + auto calculate_result = [&]() -> std::string { + std::stringstream ss; + if (pack_input_as4_) { + if (a_components_ == 4) { + ss << "let xValue = " << dy.GetByOffset("x_offset") << ";\n" + << "let wValue = " << w.GetByOffset("w_offset") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 1;\n" + << "w_offset += 1;\n"; + } else if (a_components_ == 2) { + ss << "let xValue = vec4(" << dy.GetByOffset("x_offset") << ", " << dy.GetByOffset("x_offset + 1") << ");\n" + << "let wValue = vec4(" << w.GetByOffset("w_offset") << ", " << w.GetByOffset("w_offset + 1u") << ");\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 2;\n" + << "w_offset += 2;\n"; + } else if (a_components_ == 1) { + ss << "let xValue = vec4(" << dy.GetByOffset("x_offset") << ", " << dy.GetByOffset("x_offset + 1u") << ", " << dy.GetByOffset("x_offset + 2u") << ", " << dy.GetByOffset("x_offset + 3u") << ");\n" + << "let wValue = vec4(" << w.GetByOffset("x_offset") << ", " << w.GetByOffset("x_offset + 1u") << ", " << w.GetByOffset("x_offset + 2u") << ", " << w.GetByOffset("x_offset + 3u") << ");\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 4;\n" + << "w_offset += 4;\n"; + } + } else { + if (is_channels_last_) { + ss << "let xValue = " << dy.GetByIndices("dy_indices_t(batch, idyR, idyC, inputChannel / " + std::to_string(a_components_)) << ");\n"; + } else { + ss << "let xValue = " << dy.GetByIndices("dy_indices_t(batch, inputChannel, idyR, idyC)") << ";\n"; + } + if (a_components_ == 1) { + ss << "let wValue = " << w.GetByIndices("w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)") << ";\n" + << "dotProd = dotProd + xValue * wValue;\n"; + } else if (a_components_ == b_components_ && components_ == 1) { + ss << "let wValue = " << w.GetByIndices("w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n"; + } else { + for (uint32_t i = 0; i < 
a_components_; ++i) { + ss << "let w_indices" << i << " = w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel + d2 + " << i << ", wOutChannel);\n " + << "let w_offset" << i << " = " << w.IndicesToOffset("w_indices" + std::to_string(i)) << ";\n" + << "let wValue" << i << " = " << w.GetByIndices("w_indices" + std::to_string(i)) << ";\n" + << "dotProd = dotProd + xValue[" << i << "] * wValue" << i << ";\n"; + } + } + } + return ss.str(); + }; + auto calculate_remainder = [&]() -> std::string { + std::stringstream ss; + if (input_channels_remainder_ > 0) { + ORT_ENFORCE(pack_input_as4_, "Invalid input_channels_remainder: ", input_channels_remainder_); + if (a_components_ == 1) { + for (uint32_t i = 0; i < input_channels_remainder_; ++i) { + ss << "dotProd = dotProd + " << dy.GetByOffset("x_offset + " + std::to_string(i)) << ";\n"; + } + } else if (a_components_ == 2) { + if (input_channels_remainder_ != 2) { + ORT_THROW("Invalid input_channels_remainder: ", input_channels_remainder_); + } + ss << "let xValue = " << dy.GetByOffset("x_offset") << ";\n" + << "let wValue = " << w.GetByOffset("w_offset") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n"; + } + } + return ss.str(); + }; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let outputIndices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch = " << output.IndicesGet("outputIndices", 0) << ";\n" + << "let d1 = " << output.IndicesGet("outputIndices", channel_dim) << ";\n" + << "let r = " << output.IndicesGet("outputIndices", row_dim) << ";\n" + << "let c = " << output.IndicesGet("outputIndices", col_dim) << ";\n" + << "let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads);\n" + << "let dyRCorner = dyCorner.x;\n" + << "let dyCCorner = dyCorner.y;\n" + << "let groupId = d1 / (uniforms.output_channels_per_group / " << components_ << ");\n" + << "let wOutChannel = d1 - groupId * (uniforms.output_channels_per_group / " << components_ 
<< ");\n" + << "// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n" + << "// ? = to be determined. : = across all values in that axis.\n" + << "var dotProd = output_value_t(0.0);\n" + << "var wR: u32 = 0;\n" + << "if (uniforms.dilations.x == 1) {\n" + << " // Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0\n" + << " wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner);\n" + << "}\n" + << "for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {\n" + << " if (wR % uniforms.dilations.x != 0) {\n" + << " continue;\n" + << " }\n" + << " let dyR = (dy_element_t(dyRCorner) + dy_element_t(wR)) / dy_element_t(uniforms.strides[0]);\n" + << " let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;\n" + << " if (dyR < 0.0 || dyR >= dy_element_t(uniforms.dy_shape[" << row_dim << "]) || fract(dyR) > 0.0 || wRPerm < 0) {\n" + << " continue;\n" + << " }\n" + << " let idyR: u32 = u32(dyR);\n" + << " var wC: u32 = 0;\n" + << " if (uniforms.dilations.y == 1) {\n" + << " // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0\n" + << " wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);\n" + << " }\n" + << " for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {\n" + << " if (wC % uniforms.dilations.y != 0) {" + << " continue;\n" + << " }\n" + << " let dyC = (dy_element_t(dyCCorner) + dy_element_t(wC)) / dy_element_t(uniforms.strides.y);\n" + << " let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;\n" + << " if (dyC < 0.0 || dyC >= dy_element_t(uniforms.dy_shape[" << col_dim << "]) ||\n" + << " fract(dyC) > 0.0 || wCPerm < 0) {\n" + << " continue;\n" + << " }\n" + << " let idyC: u32 = u32(dyC);\n" + << " var inputChannel = groupId * uniforms.input_channels_per_group;\n"; + if (pack_input_as4_) { + shader.MainFunctionBody() << " let 
dy_indices = dy_indices_t(batch, idyR, idyC, inputChannel);\n" + << " let w_indices = w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel);\n" + << " var x_offset = " << dy.IndicesToOffset("dy_indices") << ";\n" + << " var w_offset = " << w.IndicesToOffset("w_indices") << ";\n"; + } + + shader.MainFunctionBody() << " for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + " << (pack_input_as4_ ? 4 : a_components_) << ") {\n" + << " " << calculate_result() << "\n" + << " inputChannel = inputChannel + " << (pack_input_as4_ ? 4 : 1) << ";\n" + << " }\n" + << " " << calculate_remainder() << "\n" + << " wC = wC + uniforms.strides.y - 1;\n" + << " }\n" + << " wR = wR + uniforms.strides.x - 1;\n" + << "}\n" + << "let value = dotProd" << (has_bias_ ? " + bias[d1]" : "") << ";\n" + << output.SetByOffset("global_idx", "value") << "\n"; + return Status::OK(); +} + +ConvTranspose2DProgram CreateConvTranspose2DProgram(const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, bool is_channels_last, const std::vector& modified_input_output_shapes, uint32_t groups) { + bool has_bias = inputs.size() > 2; + const auto* input = inputs[0]; + const auto* weight = inputs[1]; + const auto& input_shape = modified_input_output_shapes[0]; + const auto& weight_shape = modified_input_output_shapes[1]; + const auto& output_shape = modified_input_output_shapes[has_bias ? 3 : 2]; + auto input_channels_per_group = weight_shape[2] / groups; + auto output_channels_per_group = weight_shape[3]; + auto a_components = is_channels_last ? GetMaxComponents(input_channels_per_group) : 1; + bool pack_input_as4 = is_channels_last && output_channels_per_group == 1 && input_channels_per_group >= 4; + auto input_channels_per_group_int = pack_input_as4 ? 
((input_channels_per_group + 3) / 4) * 4 : (input_channels_per_group / a_components) * a_components; + auto input_channels_remainder = input_channels_per_group - input_channels_per_group_int; + auto components = is_channels_last ? GetMaxComponents(output_channels_per_group) : 1; + auto b_components = is_channels_last ? (output_channels_per_group == 1 ? a_components : components) : 1; + TensorShape reduced_input_shape = ReduceShapeByComponents(input_shape, a_components); + TensorShape reduced_weight_shape = ReduceShapeByComponents(weight_shape, b_components); + TensorShape reduced_output_shape = ReduceShapeByComponents(output_shape, components); + auto output_size = reduced_output_shape.Size(); + std::vector kernel_dims = {static_cast(weight_shape[0]), static_cast(weight_shape[1])}; + std::vector effective_kernel_dims = {kernel_dims[0] + ((dilations[0] <= 1) ? 0 : ((kernel_dims[0] - 1) * (dilations[0] - 1))), kernel_dims[1] + ((dilations[1] <= 1) ? 0 : ((kernel_dims[1] - 1) * (dilations[1] - 1)))}; + std::vector local_pads = {effective_kernel_dims[0] - 1 - pads[0], effective_kernel_dims[1] - 1 - pads[1]}; + ConvTranspose2DProgram program(is_channels_last, has_bias, components, a_components, b_components, uint32_t(input_channels_remainder), pack_input_as4); + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, reduced_input_shape, a_components}, {weight, ProgramTensorMetadataDependency::TypeAndRank, reduced_weight_shape, b_components}}); + if (has_bias) { + const auto* bias = inputs[2]; + const auto& bias_shape = modified_input_output_shapes[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias_shape, components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + program.AddOutput({output, ProgramTensorMetadataDependency::Rank, reduced_output_shape, components}) + .CacheHint(std::to_string(input_channels_remainder) + "-" + std::to_string(pack_input_as4) + 
std::to_string(components) + + "-" + std::to_string(b_components) + "-" + std::to_string(a_components) + "-" + std::to_string(is_channels_last ? 1 : 0)) + .AddUniformVariables({{static_cast(output_size)}, {strides}, {kernel_dims}, {dilations}, {effective_kernel_dims}, {local_pads}, {static_cast(input_channels_per_group_int)}, {static_cast(input_channels_per_group)}, {static_cast(output_channels_per_group)}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + + return program; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h new file mode 100644 index 0000000000000..6c784e4825a65 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { + +class ConvTranspose2DProgram : public Program { + public: + ConvTranspose2DProgram(bool is_channels_last, bool has_bias, uint32_t components, uint32_t a_components, uint32_t b_components, uint32_t input_channels_remainder, bool pack_input_as4) : Program("ConvTranspose2D"), is_channels_last_(is_channels_last), has_bias_(has_bias), components_(components), a_components_(a_components), b_components_(b_components), input_channels_remainder_(input_channels_remainder), pack_input_as4_(pack_input_as4) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", 
ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"filter_dims", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"effective_filter_dims", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"input_channels_per_group_int", ProgramUniformVariableDataType::Uint32}, + {"input_channels_per_group", ProgramUniformVariableDataType::Uint32}, + {"output_channels_per_group", ProgramUniformVariableDataType::Uint32}); + + private: + bool is_channels_last_; + bool has_bias_; + uint32_t components_; + uint32_t a_components_; + uint32_t b_components_; + uint32_t input_channels_remainder_; + bool pack_input_as4_; +}; + +ConvTranspose2DProgram CreateConvTranspose2DProgram(const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, bool is_channels_last, const std::vector& modified_input_output_shapes, uint32_t groups); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc new file mode 100644 index 0000000000000..9cd290ef56013 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "conv.h" +#include "conv_transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" + +namespace onnxruntime { +namespace webgpu { +// kernel shape is the spacial dims of the filter. +// ie. filter shape with batch and channel. 
kernel shape dimension is 2 less than the filter dimension + +template +Status ConvTranspose::ComputeInternal(ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* filter = context.Input(1); + TensorShape input_shape = input->Shape(); + TensorShape filter_shape = filter->Shape(); + const InlinedVector perm = {2, 3, 0, 1}; + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); + ConvAttributes::ConvPadVector local_pads(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + TensorShapeVector local_dilations(conv_transpose_attrs_.dilations.begin(), conv_transpose_attrs_.dilations.end()); + TensorShapeVector local_strides(conv_transpose_attrs_.strides.begin(), conv_transpose_attrs_.strides.end()); + TensorShapeVector kernel_shape_vector; + auto rank = input_shape.NumDimensions(); + TensorShape input_spacial_shape = input_shape.Slice(is_channels_last ? 1 : 2, is_channels_last ? 
rank - 1 : rank); + local_pads.reserve(2 * (input_spacial_shape.NumDimensions())); + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(filter_shape, kernel_shape_vector, false)); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape_vector.size(), 0); + } + if (local_pads.empty()) { + local_pads.resize(kernel_shape_vector.size() * 2, 0); + } + if (local_dilations.empty()) { + local_dilations.resize(kernel_shape_vector.size(), 1); + } + if (local_strides.empty()) { + local_strides.resize(kernel_shape_vector.size(), 1); + } + auto group = conv_transpose_attrs_.group; + auto num_output_channels = group * filter_shape[1]; + auto batch_size = input_shape[0]; + TensorShapeVector output_shape_vector; + conv_transpose_attrs_.ComputePadsAndOutputShape(input_spacial_shape, num_output_channels, kernel_shape_vector, local_strides, local_dilations, local_output_padding, batch_size, &local_pads, &output_shape_vector, is_channels_last); + TensorShape computed_output_shape(output_shape_vector); + std::vector strides; + std::vector pads; + std::vector dilations; + auto transform_dim = [](int64_t dim) { return static_cast(dim); }; + std::transform(local_pads.begin(), local_pads.end(), std::back_inserter(pads), transform_dim); + std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim); + std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); + + bool has_bias = context.InputCount() > 2; + const auto* bias = has_bias ? context.Input(2) : nullptr; + if (input_shape.NumDimensions() == 3 && filter_shape.NumDimensions() == 3) { + // ConvTranspose1D + TensorShapeVector input_shape_vector = input_shape.AsShapeVector(); + TensorShapeVector filter_shape_vector = filter_shape.AsShapeVector(); + input_shape_vector.insert(input_shape_vector.begin() + (is_channels_last ? 
1 : 2), 1, 1); + output_shape_vector.insert(output_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + filter_shape_vector.insert(filter_shape_vector.begin() + 2, 1); + input_shape = TensorShape(input_shape_vector); + filter_shape = TensorShape(filter_shape_vector); + pads.insert(pads.begin(), 0); + pads.insert(pads.begin() + 2, 0); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } + if (input_shape.NumDimensions() > 4 || filter_shape.NumDimensions() > 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv2d or Conv1d are supported."); + } else if (input_shape.NumDimensions() < 2 || filter_shape.NumDimensions() < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input and kernel tensors must have at least 3 dimensions"); + } + // Transpose weights + Tensor transposed_filter; + ORT_RETURN_IF_ERROR(TransposeKernel(context, filter, filter_shape, &transposed_filter, perm)); + TensorShape output_shape(output_shape_vector); + TensorShape transposed_filter_shape = transposed_filter.Shape(); + std::vector inputs = {input, &transposed_filter}; + std::vector input_output_shapes = {input_shape, transposed_filter_shape}; + if (has_bias) { + inputs.push_back(bias); + input_output_shapes.push_back(bias->Shape()); + } + uint32_t auto_pad_adjust = conv_transpose_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0; + auto pad0 = conv_transpose_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; + auto pad1 = conv_transpose_attrs_.auto_pad == AutoPadType::NOTSET ? 
pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; + Tensor* output = context.Output(0, computed_output_shape); + input_output_shapes.push_back(output_shape); + auto program = CreateConvTranspose2DProgram(inputs, {pad0, pad1}, strides, dilations, output, is_channels_last, input_output_shapes, static_cast(conv_transpose_attrs_.group)); + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.h b/onnxruntime/core/providers/webgpu/nn/conv_transpose.h new file mode 100644 index 0000000000000..a97b3f5947303 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/common.h" + +#include "core/providers/cpu/nn/conv_transpose_attributes.h" +#include "core/providers/webgpu/webgpu_kernel.h" +namespace onnxruntime { +namespace webgpu { + +template +class ConvTranspose final : public WebGpuKernel { + public: + ConvTranspose(const OpKernelInfo& info) : WebGpuKernel(info), conv_transpose_attrs_(info) { + } + Status ComputeInternal(ComputeContext& context) const override; + + protected: + ConvTransposeAttributes conv_transpose_attrs_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_utils.cc b/onnxruntime/core/providers/webgpu/nn/conv_utils.cc new file mode 100644 index 0000000000000..233662c10bfb8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_utils.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/nn/conv_utils.h" +namespace onnxruntime { +namespace webgpu { +std::string UtilFunctions(std::string stride_string) { + std::stringstream ss; + ss << "fn getIndexFromCoords3D(coords : vec3, shape : vec3) -> i32 {\n" + << " return dot(coords, vec3(shape.y * shape.z, shape.z, 1));\n" + << "}\n" + << "fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 {\n" + << " return dot(coords, vec4(shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));\n" + << "}\n" + << "fn getOutputIndexFromCoords(coords : vec4) -> i32 {\n" + << " return dot(coords, vec4(i32(" << stride_string << ".x), i32(" << stride_string << ".y), i32(" << stride_string << ".z), 1));\n" + << "}\n"; + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_utils.h b/onnxruntime/core/providers/webgpu/nn/conv_utils.h new file mode 100644 index 0000000000000..ad8aa868ff7f0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_utils.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace webgpu { + +std::string UtilFunctions(std::string stride_string); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc new file mode 100644 index 0000000000000..38db604695a54 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/nn/fuse_utils.h" +#include +namespace onnxruntime { +namespace webgpu { + +Status GetFusedActivationAttr(const OpKernelInfo& info, Activation& activation) { + activation.activation_kind_ = ActivationKind::None; + + std::string activation_type; + if (info.GetAttr("activation", &activation_type).IsOK()) { + if (activation_type == "Relu") { + activation.activation_kind_ = ActivationKind::Relu; + } else if (activation_type == "Tanh") { + activation.activation_kind_ = ActivationKind::Tanh; + } else if (activation_type == "Sigmoid") { + activation.activation_kind_ = ActivationKind::Sigmoid; + } else { + // The remaining activation types have additional parameters to be pulled out. + size_t activation_params_count; + if (activation_type == "LeakyRelu") { + activation.activation_kind_ = ActivationKind::LeakyRelu; + activation_params_count = 1; + } else if (activation_type == "Clip") { + activation.activation_kind_ = ActivationKind::Clip; + activation_params_count = 2; + } else if (activation_type == "HardSigmoid") { + activation.activation_kind_ = ActivationKind::HardSigmoid; + activation_params_count = 2; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type); + } + + std::vector activation_params; + common::Status status = info.GetAttrs("activation_params", activation_params); + if (!status.IsOK()) { + return status; + } else if (activation_params_count != activation_params.size()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch"); + } + for (size_t i = 0; i < activation_params_count; i++) { + activation.activation_params_.values_[i] = activation_params[i]; + } + } + } + + return Status::OK(); +} + +std::string GetActivationSnippet(const Activation& activation, std::string value_type, std::string base_type) { + std::string snippet; + auto base_type_cast = [base_type](float value) -> std::string { + return base_type + "(" 
+ std::to_string(value) + ")"; + }; + auto value_type_cast = [base_type_cast, value_type](float f) -> std::string { + return value_type + "(" + base_type_cast(f) + ")"; + }; + switch (activation.activation_kind_) { + case ActivationKind::Relu: + return "value = max(value, " + value_type_cast(0.0) + ");"; + case ActivationKind::Sigmoid: + return "value = " + value_type_cast(1.0) + " / (" + value_type_cast(1.0) + " + exp(-value));"; + case ActivationKind::Clip: + return "value = clamp(value, " + value_type_cast(activation.activation_params_.Clip.minimum_) + ", " + value_type_cast(activation.activation_params_.Clip.maximum_) + ");"; + case ActivationKind::HardSigmoid: + return "value = clamp(" + value_type_cast(activation.activation_params_.HardSigmoid.alpha_) + " * value + " + value_type_cast(activation.activation_params_.HardSigmoid.beta_) + ", 0.0" + ", 1.0" + ");"; + case ActivationKind::LeakyRelu: + return "value = select(" + base_type_cast(activation.activation_params_.LeakyRelu.alpha_) + " * value, value, value >= " + value_type_cast(0.0) + ");"; + case ActivationKind::Tanh: + return "value = tanh(value);"; + default: + return ""; + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h new file mode 100644 index 0000000000000..f5d2585bb9b45 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include +#include "core/providers/webgpu/webgpu_kernel.h" + +#pragma once +namespace onnxruntime { +namespace webgpu { +enum class ActivationKind { + None, + Relu, + Sigmoid, + Clip, + HardSigmoid, + LeakyRelu, + Tanh +}; + +using Activation = struct Activation { + std::string ToString() const { + std::stringstream oss; + oss << "ActivationKind: " << static_cast(activation_kind_) << ";"; + oss << "ActivationParams: " << activation_params_.values_[0] << ";"; + oss << "ActivationParams: " << activation_params_.values_[1] << ";"; + return oss.str(); + } + using ActivationParameters = union ActivationParameters { + struct { + float alpha_; + } LeakyRelu; + struct { + float minimum_; + float maximum_; + } Clip; + struct { + float alpha_; + float beta_; + } HardSigmoid; + float values_[2]; + }; + ActivationParameters activation_params_ = {}; + ActivationKind activation_kind_ = ActivationKind::None; +}; + +Status GetFusedActivationAttr(const OpKernelInfo& info, Activation& activation); +std::string GetActivationSnippet(const Activation& activation, std::string value_type, std::string base_type); +// Status AppendActivationUniformsData(const Activation& activation, std::vector& variables); +// Status AppendActivationUniforms(const Activation& activation, std::vector& data); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc b/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc new file mode 100644 index 0000000000000..4dc0b82cdd7eb --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include "core/providers/webgpu/nn/grouped_conv.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CanculateResult(const ShaderVariableHelper& x, const ShaderVariableHelper& w, bool is_channels_last) { + std::stringstream ss; + if (is_channels_last) { + ss << "for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) {\n" + << " let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];\n" + << " if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) {\n" + << " continue;\n" + << " }\n" + << "" + << " for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) {\n" + << " let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];\n" + << " if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) {\n" + << " continue;\n" + << " }\n" + << "" + << " for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) {\n" + << " let input_channel = in_channel_offset + wInChannel;\n" + << " let x_indices = x_indices_t(batch, xHeight, xWidth, input_channel);\n" + << " let w_indices = w_indices_t(wHeight, wWidth, wInChannel, output_channel);\n" + << " let xVal = " << x.GetByIndices("x_indices") << ";\n" + << " let wVal = " << w.GetByIndices("w_indices") << ";\n" + << " value += xVal * wVal;\n" + << " }\n" + << " }\n" + << "}\n"; + } else { + ss << "for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {\n" + << " let input_channel = in_channel_offset + wInChannel;\n" + << " for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {\n" + << " let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];\n" + << "" + << " if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) {\n" + << " continue;\n" + << " }\n" + << "" + << " for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {\n" + << " let xWidth = xRCCorner.y + 
wWidth * uniforms.dilations[1];\n" + << " if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) {\n" + << " continue;\n" + << " }\n" + << "" + << " let x_indices = x_indices_t(batch, input_channel, xHeight, xWidth);\n" + << " let w_indices = w_indices_t(output_channel, wInChannel, wHeight, wWidth);\n" + << " let xVal = " << x.GetByIndices("x_indices") << ";\n" + << " let wVal = " << w.GetByIndices("w_indices") << ";\n" + << " value += xVal * wVal;\n" + << " }\n" + << " }\n" + << "}\n"; + } + return ss.str(); +} + +Status GroupedConvProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch: u32 = output_indices[0];\n" + << "let output_channel: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "3" : "1") << ";\n" + << "let xRCCorner_x: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "1" : "2") << ";\n" + << "let xRCCorner_y: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "2" : "3") << ";\n" + << "let xRCCorner: vec2 = vec2(xRCCorner_x, xRCCorner_y) * uniforms.strides - uniforms.pads;\n" + << "let group_id = output_channel * uniforms.components / uniforms.output_channels_per_group;\n" + << "let in_channel_offset = group_id * " << w.IndicesGet("uniforms.w_shape", is_channels_last_ ? 
2 : 1) << ";\n" + << "var value: output_value_t = output_value_t(0);\n" + << CanculateResult(x, w, is_channels_last_); + if (has_bias_) { + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.MainFunctionBody() << "value += " + b.GetByIndices("output_channel") + ";\n"; + } + shader.MainFunctionBody() << apply_activation << "\n"; + shader.MainFunctionBody() << output.SetByOffset("global_idx", "value"); + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/grouped_conv.h b/onnxruntime/core/providers/webgpu/nn/grouped_conv.h new file mode 100644 index 0000000000000..d09f9679eecf5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/grouped_conv.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +class GroupedConvProgram final : public Program { + public: + GroupedConvProgram(const Activation& activation, bool has_bias, bool is_channels_last) : Program("GroupedConv"), activation_(activation), has_bias_(has_bias), is_channels_last_(is_channels_last) { + } + Status GenerateShaderCode(ShaderHelper& shader) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"output_channels_per_group", ProgramUniformVariableDataType::Uint32}, + {"components", ProgramUniformVariableDataType::Uint32}); + + private: + const 
Activation& activation_; + bool has_bias_; + bool is_channels_last_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc index cb019892b006f..f24578a145aae 100644 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -168,9 +168,9 @@ Status Pad::ComputeInternal(ComputeContext& context) const { PadProgram program{mode_, dim_value_zero, is_float16}; if (!dim_value_zero) { - program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}); + program.AddInput({input_tensor, ProgramTensorMetadataDependency::Rank}); } - program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::TypeAndRank}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .CacheHint(std::to_string(static_cast(mode_)), dim_value_zero) .AddUniformVariables({{gsl::span(lower_pads.data(), lower_pads.size())}, {output_size}, {value_uint32}}); diff --git a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc index f68ace3c1d8a1..75a7f859c965f 100644 --- a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc +++ b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc @@ -122,7 +122,7 @@ void CalcNearestPixel(std::ostream& os, ResizeNearestMode mode) { body = "select(i32(round(x_original)), i32(floor(x_original)), x_original == f32(i32(x_original)) + 0.5)"; break; case ResizeNearestMode::ROUND_PREFER_CEIL: - body = "i32(round(x_original))"; + body = "select(i32(round(x_original)), i32(ceil(x_original)), x_original == f32(i32(x_original)) + 0.5)"; break; case ResizeNearestMode::FLOOR: body = "i32(floor(x_original))"; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 
955b54e873261..2987d3905fe54 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -136,6 +136,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi LOGS_DEFAULT(VERBOSE) << "WebGPU EP Context is created for: Instance=" << instance_.Get() << ", Device=" << device_.Get() << "."; + // cache device queue + device_queue_ = device_.GetQueue(); // cache adapter info ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits @@ -147,10 +149,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi device_features_.insert(supported_features.features[i]); } -#if !defined(__wasm__) - supports_buffer_map_extended_usages_ = device_.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages); -#endif - // create buffer manager buffer_mgr_ = BufferManagerFactory::Create(*this, buffer_cache_config.storage.mode, @@ -404,7 +402,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { } uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); - device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); + device_queue_.WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } const auto& compute_pass_encoder = GetComputePassEncoder(); @@ -696,7 +694,7 @@ void WebGpuContext::Flush() { } auto command_buffer = current_command_encoder_.Finish(); - Device().GetQueue().Submit(1, &command_buffer); + device_queue_.Submit(1, &command_buffer); BufferManager().RefreshPendingBuffers(); current_command_encoder_ = nullptr; num_pending_dispatches_ = 0; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 2f044400afee2..8ebb122103177 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ 
b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -145,8 +145,6 @@ class WebGpuContext final { Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); - bool SupportsBufferMapExtendedUsages() const { return supports_buffer_map_extended_usages_; } - private: enum class TimestampQueryType { None = 0, @@ -207,6 +205,7 @@ class WebGpuContext final { webgpu::ValidationMode validation_mode_; + wgpu::Queue device_queue_; wgpu::AdapterInfo adapter_info_; wgpu::Limits device_limits_; std::unordered_set device_features_; @@ -237,7 +236,6 @@ class WebGpuContext final { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) std::unique_ptr pix_frame_generator_ = nullptr; #endif // ENABLE_PIX_FOR_WEBGPU_EP - bool supports_buffer_map_extended_usages_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 2427bf62cc658..eb65e998c81c5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -250,6 +250,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -257,9 +259,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInt class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 21, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 21, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 22, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); @@ -578,21 +582,25 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc 
b/onnxruntime/core/providers/webgpu/webgpu_utils.cc new file mode 100644 index 0000000000000..9b16767475c0c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/webgpu_utils.h" +namespace onnxruntime { +namespace webgpu { + +TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components) { + // Reduce the last dimensions by components creating a new tensor shape. + TensorShapeVector shape_vector = shape.AsShapeVector(); + auto reduce_index = shape_vector.size() - 1; + // Find the last dimension that is divisible by components. + while (shape_vector[reduce_index] % components != 0 && reduce_index > 0) { + ORT_ENFORCE(components % shape_vector[reduce_index] == 0, "The components must divide dims"); + components /= shape_vector[reduce_index]; + shape_vector[reduce_index] = 1; + reduce_index--; + } + ORT_ENFORCE(reduce_index >= 0 && shape_vector[reduce_index] % components == 0, "The last non-unit dimension of the input shape must be divisible by the number of components."); + shape_vector[reduce_index] /= components; + return TensorShape(shape_vector); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 5f6f18f34b7f5..e02d9266e8a0e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -1,8 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#pragma once #include +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" namespace onnxruntime { namespace webgpu { @@ -44,5 +47,7 @@ inline std::string MakeScalarOrVectorType(int components, std::string_view data_ } } +TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components); + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 2e733f67a888c..e50ee5738c30e 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -340,3 +340,11 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* return nullptr; API_IMPL_END } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool is_cancel) { + API_IMPL_BEGIN + options->value.SetLoadCancellationFlag(is_cancel); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5ea562ce3535..0cb361bae563b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(const Graph& graph) { +static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp(const Graph& graph) { bool nodes_on_cpu_and_cuda_and_js_and_dml_eps_only = true; for (const auto& node : graph.Nodes()) { @@ -164,6 +164,7 @@ static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(const Graph& graph) { !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider || node_provider == kJsExecutionProvider || + node_provider == kWebGpuExecutionProvider || node_provider == kDmlExecutionProvider) && node_provider != kCpuExecutionProvider) { 
nodes_on_cpu_and_cuda_and_js_and_dml_eps_only = false; @@ -383,6 +384,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options ORT_THROW_IF_ERROR(graph_transformer_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps)); + graph_transformer_mgr_.SetLoadCancellationFn(this->check_load_cancellation_fn_); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1004,11 +1006,13 @@ common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) { std::copy(std::begin(interop_domains_), std::end(interop_domains_), std::back_inserter(domain_ptrs)); ORT_RETURN_IF_ERROR(AddCustomOpDomains(domain_ptrs)); #endif + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; common::Status st = LoadWithLoader(loader, "model_loading_uri"); @@ -1101,7 +1105,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? 
&custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_array"); @@ -1139,7 +1144,8 @@ common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) { // This call will move model_proto to the constructed model instance return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_proto"); @@ -1172,7 +1178,8 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; ModelOptions model_opts(allow_released_opsets_only, - strict_shape_type_inference); + strict_shape_type_inference, + check_load_cancellation_fn_); std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); @@ -1211,7 +1218,8 @@ common::Status InferenceSession::Load() { // Pass on ownership of the parsed ModelProto to the Model instance (its job here is done by this stage) return Model::Load(std::move(this->model_proto_), model_location_, model, HasLocalSchema() ? 
&custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(allow_released_opsets_only, strict_shape_type_inference)); + ModelOptions(allow_released_opsets_only, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_from_saved_proto"); @@ -1239,7 +1247,8 @@ common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { std::unique_ptr tmp_model; ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, - ModelOptions(true, strict_shape_type_inference), + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_), *session_logger_, tmp_model)); model_ = std::move(tmp_model); @@ -1283,7 +1292,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool auto graph_optimizer_registry = std::make_unique(&session_options_, execution_providers_.Get(onnxruntime::kCpuExecutionProvider), session_logger_); - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry), + check_load_cancellation_fn_); // Run Ahead Of time function inlining if (const bool disable_aot_function_inlining = @@ -1711,7 +1721,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, providers.Get(onnxruntime::kCpuExecutionProvider), &logger); - GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry), + [&sess_options]() -> bool { return sess_options.IsLoadCancellationFlagSet(); }); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, @@ -1784,6 +1795,11 @@ common::Status 
InferenceSession::HasInvalidCombinationOfExecutionProviders() con #pragma warning(disable : 26117) #endif common::Status InferenceSession::Initialize() { + if (session_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Session initialization canceled due to user request."); + } + Status status = Status::OK(); TimePoint tp; if (session_profiler_.IsEnabled()) { @@ -2009,6 +2025,10 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + if (session_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Session initialization canceled due to user request."); + } // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. // @@ -2041,6 +2061,7 @@ common::Status InferenceSession::Initialize() { onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; for (auto& it : graph_support_ep_list) { @@ -2063,12 +2084,13 @@ common::Status InferenceSession::Initialize() { if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kWebGpuExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kDmlExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we 
care about. - if (!AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(graph)) { + if (!AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp(graph)) { LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " << " as all compute graph nodes have not been partitioned to the " << target_ep->Type(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 5b484103c9ecf..7b5d98c38a0fa 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -781,6 +781,10 @@ class InferenceSession { // the session options are released after the individual operators are destroyed. SessionOptions session_options_; + CheckLoadCancellationFn check_load_cancellation_fn_ = [this]() { + return session_options_.IsLoadCancellationFlagSet(); + }; + /// Logging manager if provided. logging::LoggingManager* logging_manager_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 0e23d7a791bec..ac67a3ce5c1a2 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -720,22 +720,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr sess; - OrtStatus* status = nullptr; *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); - - *out = reinterpret_cast(sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, 
*sess)); + *out = reinterpret_cast(sess.release()); + return nullptr; API_IMPL_END } @@ -743,22 +732,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr sess; - OrtStatus* status = nullptr; - *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); - - *out = reinterpret_cast(sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + *out = reinterpret_cast(sess.release()); + return nullptr; API_IMPL_END } @@ -2810,6 +2787,7 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::GetModelEditorApi, &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, + &OrtApis::SessionOptionsSetLoadCancellationFlag, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9d8aeb18a782f..0a87036a0dd1d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -549,4 +549,7 @@ ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* ONNXTensorElementDataType type, _Outptr_ OrtValue** out); +ORT_API_STATUS_IMPL(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool is_cancel); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index fb77376809d31..8539394f321bd 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -43,6 +43,7 @@ #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" #include +#include "onnx/shape_inference/implementation.h" #ifdef ENABLE_TRAINING #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -771,6 +772,12 @@ struct ProviderHostImpl : ProviderHost { int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); } ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); } + void InferShapes(const std::string& m, const std::string& save_path) override { + return ONNX_NAMESPACE::shape_inference::InferShapes(m, save_path); + } + void InferShapes(ONNX_NAMESPACE::ModelProto& m) override { + return ONNX_NAMESPACE::shape_inference::InferShapes(m); + } void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override { auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); const auto& domain_to_version_map = domain_instance.Map(); @@ -1268,6 +1275,7 @@ struct ProviderHostImpl : ProviderHost { const Graph* Graph__ParentGraph(const Graph* p) const override { return 
p->ParentGraph(); } Graph* Graph__MutableParentGraph(Graph* p) override { return p->MutableParentGraph(); } const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); } + void Graph__SetName(Graph* p, const std::string& name) const noexcept override { return p->SetName(name); } const std::filesystem::path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 975502063ac2a..a069cfa0b4713 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1753,6 +1753,12 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") options->value.execution_mode = execution_mode; }, R"pbdoc(Sets the execution mode. 
Default is sequential.)pbdoc") + .def( + "set_load_cancellation_flag", + [](PySessionOptions* options, bool value) -> void { + options->value.SetLoadCancellationFlag(value); + }, + R"pbdoc(Request inference session load cancellation)pbdoc") .def_property( "execution_order", [](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; }, diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index 63bf6410f86c3..fe93f5cd358bf 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if node_before_layer_norm is None: continue child = self.model.find_first_child_by_type( - node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False + node_before_layer_norm, + "LayerNormalization", + input_name_to_nodes, + False, ) if child is None: continue @@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], - [1, 1, 0, 0, 0], + [1, None, 0, 0, 0], ) if qkv_nodes is None: logger.debug("fuse_attention: failed to match qkv path") return - - reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[2], qkv_nodes[3], qkv_nodes[-1] + reshape_qkv, transpose_qkv, matmul_qkv = ( + qkv_nodes[2], + qkv_nodes[3], + qkv_nodes[-1], + ) v_nodes = self.model.match_parent_path( - matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] + matmul_qkv, + ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, None], ) if v_nodes is None: - v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1]) + v_nodes = self.model.match_parent_path( + matmul_qkv, 
["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None] + ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return @@ -182,17 +192,30 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) if qk_nodes is None: qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) - if qk_nodes is None: - qk_nodes = self.model.match_parent_path( - matmul_qkv, ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0, 0, 0] - ) - if qk_nodes is None: - logger.debug("fuse_attention: failed to match qk path") - return - else: - add_mask = qk_nodes[3] - else: + if qk_nodes is not None: add_mask = qk_nodes[1] + else: + # If attention mask is not used, we can still match the qk path. + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0]) + if qk_nodes is None: + # Cast nodes are added in the model for fp16. + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], + [0, 0, 0, 0, 0, 0], + ) + if qk_nodes is not None: + add_mask = qk_nodes[3] + else: + # If attention mask is not used, we can still match the qk path. 
+ qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Cast", "Cast", "Softmax", "Mul", "MatMul"], + [0, 0, 0, 0, 0], + ) + if qk_nodes is None: + logger.debug("fuse_attention: failed to match qk path") + return else: assert len(add_mask_indices) == 1 causal_mask_input_index = 1 - add_mask_indices[0] @@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qk = qk_nodes[-1] q_nodes = self.model.match_parent_path( - matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None] + matmul_qk, + ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, None, None], ) if q_nodes is None: - q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 1]) + q_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None] + ) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_q, matmul_q = q_nodes[-2], q_nodes[-1] k_nodes = self.model.match_parent_path( - matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, 0, None], ) if k_nodes is None: - k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1]) + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None] + ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return @@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4D Add after Q x K' add_qk_nodes = self.model.match_parent_path( add_mask, - ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze", "Reshape", "Reshape", "Cast"], + [ + 
"Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], [1, 2, 1, 0, 0, 0, 0, 0, 0], ) if add_qk_nodes is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index 210f10e2eadd4..728bd03244758 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -177,13 +177,12 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict return mul_after_mul_half = children[0] + # root_node could be None when root_input is graph input root_node = self.model.get_parent( mul_after_mul_half, 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1, output_name_to_node, ) - if root_node is None: - return mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) if mul_before_tanh is None: @@ -197,7 +196,13 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict if add_before_tanh is None: return - mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node]) + mul_after_pow = self.model.match_parent( + add_before_tanh, + "Mul", + None, + output_name_to_node, + exclude=[root_node] if root_node else [], + ) if mul_after_pow is None: return @@ -212,7 +217,9 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict if not self.model.has_constant_input(pow, 3.0): return - if pow.input[0] != root_node.output[0]: + root_input = mul_after_mul_half.input[0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1] + + if pow.input[0] != root_input: return subgraph_nodes = [ @@ -236,7 +243,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict self.nodes_to_remove.extend(subgraph_nodes) fused_node = helper.make_node( "FastGelu", - inputs=[root_node.output[0]], + inputs=[root_input], 
outputs=mul_after_mul_half.output, name=self.model.create_node_name("FastGelu"), ) diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc index e6fe0ec0e45a3..0dd69a49972e8 100644 --- a/onnxruntime/test/contrib_ops/fused_conv_test.cc +++ b/onnxruntime/test/contrib_ops/fused_conv_test.cc @@ -33,14 +33,16 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, bool disable_cpu = false, bool disable_cuda = false, bool disable_rocm = false, + bool disable_webgpu = false, bool use_float16 = false, bool weight_is_initializer = false) { bool enable_cuda = HasCudaEnvironment(0) && !use_float16 && !disable_cuda; // Only ROCm EP supports float16. bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; - if (enable_cuda || enable_rocm || enable_cpu) { + if (enable_cuda || enable_rocm || enable_cpu || enable_webgpu) { OpTester test("FusedConv", 1, onnxruntime::kMSDomain); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -96,6 +98,10 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, execution_providers.push_back(DefaultRocmExecutionProvider()); } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + if (enable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } @@ -110,15 +116,16 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes, const vector& expected_output_shape, bool disable_cpu = false, bool disable_cuda = false, - bool disable_rocm = false) { + bool disable_rocm = false, + bool disable_webgpu = false) { bool weight_is_initializer = false; bool use_float16 = false; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - 
disable_cpu, disable_cuda, disable_rocm, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); use_float16 = true; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - disable_cpu, disable_cuda, disable_rocm, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); } TEST(FusedConvTest, Conv2D_HardSigmoid) { @@ -139,7 +146,7 @@ TEST(FusedConvTest, Conv2D_HardSigmoid) { vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 2, 2}; auto expected_vals = {0.8f, 0.9f, 1.0f, 1.0f, 0.2f, 0.1f, 0.0f, 0.0f}; - RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true); + RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true, true); } TEST(FusedConvTest, Conv2D_Relu) { @@ -233,7 +240,7 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) { vector Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}; vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f}; - RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true); + RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true, true); } #endif diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index c4536fc56a22f..0dfe194e893e2 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -15,6 +15,27 @@ namespace onnxruntime { namespace test { +// When uint8_t data type is used GatherBlockQuantize applies MatMulNBit's conventions for storing the data. +// That is when no zero points are specified a default zero point of 8 is used. 
This convertor hence +// compensates for that by adding 8 to the data values, so that the outputs match the results that +// we be seen with non uint8_t data types. +template +void PackDataForUint8TypeIfNecessary(std::vector& data, std::vector& data_shape) { + if (!std::is_same_v) { + return; + } + // For uint8_t, we need to pack each pair of values (after adding 8) into a single uint8_t + std::vector packed_data; + for (size_t i = 0; i < data.size(); i += 2) { + int low_nibble = (data[i] + 8) & 0xF; + int high_nibble = ((i + 1) < data.size()) ? ((data[i + 1] + 8) & 0xF) : 0; + int packed = (high_nibble << 4) | low_nibble; + packed_data.push_back(packed); + } + data = packed_data; + data_shape[data_shape.size() - 1] = (data_shape[data_shape.size() - 1] + 1) / 2; +} + // Combinations: types, gather_axis, quantize_axis, block_size, indices, scale shape vs data shape template void RunGatherBlockQuantized(const std::vector& data, @@ -96,6 +117,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -123,7 +145,6 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); - Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); @@ -134,21 +155,70 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); + Test_Fail_WithZeroPoints(0, 2, 16); +} + +template +void Test_Fail_WithoutZeroPoints(int64_t gather_axis, + int64_t quantize_axis, + int64_t block_size) { + std::vector data = {-8, -7, -6, -5, + -4, -3, -2, -1, + 0, 1, 2, 3, + 4, 5, 6, 7, + 4, 5, 6, 7, + -4, 
-3, -2, -1}; + std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); + std::vector indices = {1}; + std::vector indices_shape = {1}; + std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; + std::vector scales_shape = {2, 3, 1}; + std::vector output = {8.f, 10.f, 12.f, 14.f, + 3.f, 4.f, 5.f, 6.f, + -6.f, -4.f, -2.f, 0.f}; + std::vector output_shape = {1, 3, 4}; + + RunGatherBlockQuantized(ToType(data), + data_shape, + ToType(indices), + indices_shape, + ToType(scales), + scales_shape, + {}, + gather_axis, + quantize_axis, + block_size, + ToType(output), + output_shape, + OpTester::ExpectResult::kExpectFailure); +} + +TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) { + // T1 uint8_t with zero points is not yet supported. + Test_Fail_WithZeroPoints(0, 2, 16); + Test_Fail_WithZeroPoints(0, 2, 16); + // Gather on axis other than 0 is not supported with uint8_t + Test_Fail_WithoutZeroPoints(1, 2, 16); + Test_Fail_WithoutZeroPoints(1, 2, 16); } TEST(GatherBlockQuantizedOpTest, InvalidBlockSize) { Test_Fail_WithZeroPoints(0, 2, 8); Test_Fail_WithZeroPoints(0, 2, 17); + Test_Fail_WithZeroPoints(0, 2, 17); } TEST(GatherBlockQuantizedOpTest, InvalidGatherAxis) { Test_Fail_WithZeroPoints(3, 2, 16); Test_Fail_WithZeroPoints(-4, 2, 16); + Test_Fail_WithZeroPoints(-4, 2, 16); } TEST(GatherBlockQuantizedOpTest, InvalidQuantizeAxis) { Test_Fail_WithZeroPoints(0, 3, 16); Test_Fail_WithZeroPoints(0, -4, 16); + Test_Fail_WithZeroPoints(0, -4, 16); } template @@ -160,6 +230,7 @@ void Test_ShapeMismatch_WithZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f}; @@ -188,6 +259,7 @@ void Test_ShapeMismatch_WithZeroPoints() { TEST(GatherBlockQuantizedOpTest, ShapeMismatch) { Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); + 
Test_ShapeMismatch_WithZeroPoints(); } template @@ -199,6 +271,7 @@ void Test_InvalidIndices_WithZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {2}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -227,6 +300,7 @@ void Test_InvalidIndices_WithZeroPoints() { TEST(GatherBlockQuantizedOpTest, InvalidIndices) { Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); + Test_InvalidIndices_WithZeroPoints(); } template @@ -298,6 +372,7 @@ void Test_GatherAxis0_NoZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -340,6 +415,10 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) { Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); } template diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 95101c8075fc2..dc776f74d8758 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -498,6 +499,30 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } +TEST(InferenceSessionTests, RequestLoadCancellation) { + { + // Explicit cancel during load, small model is fine + SessionOptions so; + so.session_logid = "InferenceSessionTests.TestLoadCancellation"; + + const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx"); 
+ InferenceSession session_object{so, GetEnvironment()}; + so.SetLoadCancellationFlag(true); + ASSERT_FALSE(session_object.Load(model_uri).IsOK()); + } + { + // Explicit cancel during initialize, small model is fine + const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx"); + SessionOptions so; + so.session_logid = "InferenceSessionTests.TestLoadCancellation"; + so.SetLoadCancellationFlag(false); + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + so.SetLoadCancellationFlag(true); + ASSERT_FALSE(session_object.Initialize().IsOK()); + } +} + #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a66964de17c72..a33b3148014f1 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2204,6 +2204,10 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2232,6 +2236,10 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2330,6 +2338,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for 
(auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2351,6 +2363,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); } +#elif defined(USE_WEBGPU) + std::set webgpu_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu", "HardSigmoid"}; + if (webgpu_supported.find(model.second) == webgpu_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); + } #else ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); #endif diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 1630f63822b6a..b685b170c163f 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -421,12 +421,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "Select from 'gpu', or 'npu' \n"); } } else if (key == "performance_preference") { - std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; + std::set ov_supported_values = {"default", "high_performance", "minimum_power"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. 
" - "Select from 'default', 'high_performance' or 'minimal_power' \n"); + "Select from 'default', 'high_performance' or 'minimum_power' \n"); } } else if (key == "disable_metacommands") { std::set ov_supported_values = {"true", "True", "false", "False"}; diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 9c0b779870c70..a5fd37361a255 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -576,11 +576,10 @@ TEST(Loop, InfiniteLoopTermination) { test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options); // Disable TensorRT on unsupported data type BOOL - // call get to propagate any exception - terminator_result.get(); - // done with the thread terminator_thread.join(); + // call get to propagate any exception + terminator_result.get(); } // Add basic test to trigger types override logic in Graph::InferAndVerifySubgraphTypes as well as diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index d1350db8ec12e..1404071928e09 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -518,7 +518,7 @@ TEST(ConvFp16Test, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = { @@ -557,7 +557,7 @@ TEST(ConvFp16Test, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = { @@ -601,7 +601,7 @@ TEST(ConvFp16Test, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, 
// strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = { @@ -1082,7 +1082,7 @@ TEST(ConvFp16Test, Pointwise_3D) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = { diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index a3a3dd939cbf0..06434d5b59ec6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -489,7 +489,7 @@ TEST(ConvTest, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {-0.43337246775627136f, -0.48385289311408997f, -0.30954962968826294f, @@ -526,7 +526,7 @@ TEST(ConvTest, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {0.010772407054901123f, -0.43806642293930054f, 0.455391526222229f, -0.28657248616218567f, @@ -569,7 +569,7 @@ TEST(ConvTest, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {0.46796226501464844f, -0.4613912105560303f, 0.33512794971466064f, -0.4010460674762726f, @@ -916,7 +916,7 @@ TEST(ConvTest, ConvDimWithZero) { vector{1, 1}, // kernel_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = vector(); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 83b27f10fe04f..198fa07ae4ed0 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ 
b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -933,7 +933,7 @@ TEST(ConvTransposeTest, DimWithZero) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, - kAclExecutionProvider, kQnnExecutionProvider}); + kAclExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_3D) { @@ -1068,7 +1068,7 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCudaExecutionProvider, - kCudaNHWCExecutionProvider, kQnnExecutionProvider}); + kCudaNHWCExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { diff --git a/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx new file mode 100644 index 0000000000000..49e805169ee45 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx new file mode 100644 index 0000000000000..7fca335f13731 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx differ diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py index 94b969ad5377d..11ae1401ff8ed 100644 --- a/onnxruntime/test/python/transformers/test_gelu_fusions.py +++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py @@ -3,6 +3,7 @@ import unittest 
import torch +from parameterized import parameterized from parity_utilities import find_transformers_source if find_transformers_source(): @@ -43,16 +44,6 @@ def forward(self, x): return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) -test_cases = [ - ("huggingface", "Gelu", HuggingfaceGelu), - ("huggingface", "FastGelu", HuggingfaceFastGelu), - ("huggingface", "QuickGelu", HuggingfaceQuickGelu), - ("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), - ("megatron", "Gelu", MegatronGelu), - ("megatron", "FastGelu", MegatronFastGelu), -] - - class TestGeluFusions(unittest.TestCase): def verify_node_count(self, bert_model, expected_node_count, test_name): for op_type, count in expected_node_count.items(): @@ -62,25 +53,47 @@ def verify_node_count(self, bert_model, expected_node_count, test_name): print(f"{op}: {len(bert_model.get_nodes_by_op_type(op))} expected={counter}") self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count) - def test_fusions(self): - for test_case in test_cases: - source, operator, model_class = test_case - model = model_class() - dummy_input = torch.ones(3, dtype=torch.float32) - test_name = f"{operator}_{source}" - onnx_path = f"{test_name}.onnx" - torch.onnx.export( - model, - (dummy_input), - onnx_path, - input_names=["input"], - output_names=["output"], - ) - optimizer = optimize_model(onnx_path, "bert") - # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") - os.remove(onnx_path) - expected_node_count = {operator: 1} - self.verify_node_count(optimizer, expected_node_count, test_name) + @parameterized.expand( + [ + (("huggingface", "Gelu", HuggingfaceGelu), True), + (("huggingface", "FastGelu", HuggingfaceFastGelu), True), + (("huggingface", "QuickGelu", HuggingfaceQuickGelu), True), + (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), True), + (("megatron", "Gelu", MegatronGelu), True), + (("megatron", "FastGelu", MegatronFastGelu), True), + (("huggingface", "Gelu", 
HuggingfaceGelu), False), + (("huggingface", "FastGelu", HuggingfaceFastGelu), False), + (("huggingface", "QuickGelu", HuggingfaceQuickGelu), False), + (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), False), + (("megatron", "Gelu", MegatronGelu), False), + (("megatron", "FastGelu", MegatronFastGelu), False), + ] + ) + def test_fusions(self, test_case, dynamo): + source, operator, model_class = test_case + model = model_class() + dummy_input = torch.ones(3, dtype=torch.float32) + test_name = f"{operator}_{source}" + onnx_path = f"{test_name}.onnx" + torch.onnx.export( + model, + (dummy_input,), + onnx_path, + input_names=["input"], + output_names=["output"], + dynamo=dynamo, + optimize=True, # Only meaningful when dynamo is True + ) + optimizer = optimize_model(onnx_path, "bert") + # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") + os.remove(onnx_path) + # Remove the associated .data file (dynamo) + data_path = onnx_path + ".data" + if os.path.exists(data_path): + os.remove(data_path) + expected_node_count = {operator: 1} + + self.verify_node_count(optimizer, expected_node_count, test_name) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_gemma3_vision.py b/onnxruntime/test/python/transformers/test_gemma3_vision.py new file mode 100644 index 0000000000000..4727d2c8030d2 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gemma3_vision.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +import torch +from parameterized import parameterized +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from dynamo_onnx_helper import DynamoOnnxHelper + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py#L363 +class SiglipAttention(torch.nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self): + super().__init__() + self.embed_dim = 20 + self.num_heads = 2 + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + + self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + + self.k_proj.weight.data.fill_(1) + self.v_proj.weight.data.fill_(1) + self.q_proj.weight.data.fill_(1) + self.out_proj.weight.data.fill_(1) + self.k_proj.bias.data.fill_(1) + self.v_proj.bias.data.fill_(1) + self.q_proj.bias.data.fill_(1) + self.out_proj.bias.data.fill_(1) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = 
torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Gemma3VSIGLIPAttentionAndLayerNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn = SiglipAttention() + self.ln = torch.nn.LayerNorm(20, eps=1e-05) + + def forward(self, x): + # SkipLayerNorm ------+ + # | | + # Attention | + # | | + # MatMul | + # | | + # SkipLayerNorm ------+ + + # SkipLayerNorm + x = x + x + x = self.ln(x) + residual = x + + # Attention + MatMul + x, _ = self.attn(x) + + # SkipLayerNorm + x = residual + x + x = self.ln(x) + return x + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), + expected_initializer, + ) + ) + + def export(self, model, inputs) -> onnx.ModelProto: + with torch.no_grad(): + 
onnx_program = torch.onnx.export( + model, + args=inputs, + # f=os.path.join(os.path.dirname(__file__), "export.onnx"), + dynamo=True, + optimize=True, + ) + return onnx_program.model_proto # type: ignore + + def tearDown(self): + paths = [ + os.path.join(os.path.dirname(__file__), "export.onnx"), + os.path.join(os.path.dirname(__file__), "export.onnx.data"), + ] + for path in paths: + if os.path.exists(path): + os.remove(path) + + @parameterized.expand( + [ + (torch.float32, "gemma3-vision-attention_fp32.onnx"), + (torch.float16, "gemma3-vision-attention_fp16.onnx"), + ] + ) + def test_gemma3_vision_attention(self, dtype, model_name): + model = Gemma3VSIGLIPAttentionAndLayerNorm().eval().to(dtype) + inputs = (torch.randn(1, 2, 20, dtype=dtype),) + original_model = self.export(model, inputs) + + # TODO(titaiwang): Upstream these processings to onnxscript pass + onnx_model_wrapper = DynamoOnnxHelper(original_model) + onnx_model_wrapper.convert_constants_to_initializers() + onnx_model_wrapper.clear_metadata() + model_path = os.path.join(os.path.dirname(__file__), "export.onnx") + onnx_model_wrapper.model.save_model_to_file( + model_path, + use_external_data_format=True, + all_tensors_to_one_file=True, + convert_attribute=True, + ) + + options = FusionOptions("clip") + optimized_model = optimize_model( + model_path, + model_type="clip", + num_heads=2, + hidden_size=20, + optimization_options=options, + opt_level=0, + ) + self.verify_fusion(optimized_model, model_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b517ba7032886..e00606af1c086 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -4669,6 +4670,34 @@ TEST(CApiTest, RunBaseLoraModel) { } } +TEST(CApiTest, RequestLoadCancellation) { + constexpr const 
ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions session_options; + + auto terminator = [&session_options]() { + session_options.SetLoadCancellationFlag(true); + return; + }; + + std::packaged_task task{terminator}; + std::future terminator_result = task.get_future(); + std::thread terminator_thread{std::move(task)}; + bool terminated = false; + try { + Ort::Session session(env, model_path, session_options); + } catch (const Ort::Exception& ex) { + terminated = OrtErrorCode::ORT_MODEL_LOAD_CANCELED == ex.GetOrtErrorCode(); + } + // done with the thread + terminator_thread.join(); + + // call get to propagate any exception + terminator_result.get(); + + ASSERT_TRUE(terminated); +} + struct MockGQA : public OrtCustomOp { MockGQA() { OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) { diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 722e7696ba738..8f1189b05858c 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -147,6 +147,29 @@ extends: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: stages/nodejs-win-packaging-stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + ArtifactName: 'drop-onnxruntime-nodejs-win-x64' + StageName: 'Windows_Nodejs_Packaging_x64' + BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildArch: 'x64' + EnvSetupScript: 'setup_env.bat' + sln_platform: 'x64' + DoEsrp: ${{ parameters.DoEsrp }} + PublishWebGpuBuildTools: true + + - template: 
stages/nodejs-win-packaging-stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + ArtifactName: 'drop-onnxruntime-nodejs-win-arm64' + StageName: 'Windows_Nodejs_Packaging_arm64' + BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildArch: 'x64' + EnvSetupScript: 'setup_env.bat' + sln_platform: 'arm64' + DoEsrp: ${{ parameters.DoEsrp }} + DependsOnStageName: Windows_Nodejs_Packaging_x64 - template: nuget/templates/dml-vs-2022.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml new file mode 100644 index 0000000000000..e1247565d8f5b --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml @@ -0,0 +1,57 @@ +parameters: +- name: CudaVersion + type: string + default: '12.2' + +stages: +- stage: Linux_Nodejs_Packaging_x64 + dependsOn: [] + jobs: + - job: Linux_Nodejs_Packaging_x64 + dependsOn: [] + workspace: + clean: all + timeoutInMinutes: 180 + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + os: linux + variables: + - template: ../templates/common-variables.yml + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} + steps: + - checkout: self + clean: true + submodules: recursive + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ 
variables.CUDA_VERSION_MAJOR }}/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }} + DockerBuildArgs: " + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build + - template: ../templates/set-version-number-variables-step.yml + + - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_nodejs_package.sh + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Build Node.js binding Package' + + - template: ../templates/nodejs-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'x64' + os: 'linux' + artifactName: 'drop-onnxruntime-nodejs-linux-x64' + + - template: ../templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml new file mode 100644 index 0000000000000..73e650eb07992 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -0,0 +1,192 @@ +parameters: + BuildCommand: '' + StageName: 'Windows_Nodejs_Packaging' + ArtifactName: 'drop-onnxruntime-nodejs-win' + DoEsrp: 'false' + BuildArch: 'x64' # Optional. Options: x86, x64 + sln_platform: 'x64' # Options: Win32, x64, arm, arm64 + AgentDemands: [] + BuildConfigurations: ['RelWithDebInfo'] # Options: Debug, RelWithDebInfo + EnableLto: true + # Controls whether unreleased onnx opsets are allowed. 
Default is set to 0 + AllowReleasedOpsetOnly: '0' + IsReleaseBuild: false + PublishWebGpuBuildTools: false + WebGpuBuildToolsArtifactName: 'Windows_WebGPU_BuildTools_x64' + DependsOnStageName: '' + +stages: +- stage: ${{ parameters.StageName }} + dependsOn: + - Setup + - ${{ if ne(parameters.DependsOnStageName, '') }}: + - ${{ parameters.DependsOnStageName }} + + jobs: + - job: ${{ parameters.StageName }} + timeoutInMinutes: 200 + strategy: + maxParallel: 2 + matrix: + ${{ each BuildConfiguration in parameters.BuildConfigurations }}: + ${{ BuildConfiguration }}: + BuildConfig: ${{ BuildConfiguration }} + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-2022 + demands: ${{ parameters.AgentDemands }} + variables: + buildDirectory: '$(Build.BinariesDirectory)' + OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' + runCodesignValidationInjection: ${{ parameters. DoEsrp}} #For the others, code sign is in a separated job + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + BuildCommandExtra: '' + ${{ if eq(parameters.EnableLto, true) }}: + build_py_lto_flag: --enable_lto + + steps: + - checkout: self + clean: true + submodules: none + + - powershell: | + if($env:TELEMETRYGUID) + { + $length = $env:TELEMETRYGUID.length + $fileContent = "#define TraceLoggingOptionMicrosoftTelemetry() \ + TraceLoggingOptionGroup("+$env:TELEMETRYGUID.substring(1, $length-2)+")" + New-Item -Path "$(Build.SourcesDirectory)\include\onnxruntime\core\platform\windows\TraceLoggingConfigPrivate.h" -ItemType "file" -Value "$fileContent" -Force + Write-Output "Enabling TELEMETRY" + } + displayName: 'Create TraceLoggingConfigPrivate.h For WinML Telemetry' + env: + TELEMETRYGUID: $(TELEMETRYGUID) + + - task: NodeTool@0 + inputs: + versionSpec:
'20.x' + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: ${{ parameters.BuildArch }} + + # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly + - task: UseDotNet@2 + inputs: + version: 8.x + env: + PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + + - task: BatchScript@1 + displayName: 'Setup VS2022 env vars' + inputs: + filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' + arguments: ${{ parameters.BuildArch }} + modifyEnvironment: true + + - ${{ if and(ne(parameters.WebGpuBuildToolsArtifactName, ''), eq(parameters.sln_platform, 'arm64')) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download WebGPU build tools from x64 build' + inputs: + artifactName: '${{ parameters.WebGpuBuildToolsArtifactName }}' + targetPath: '$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}' + - script: | + @echo ##vso[task.setvariable variable=LLVM_TABLEGEN_PATH]$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}\llvm-tblgen.exe + @echo ##vso[task.setvariable variable=CLANG_TABLEGEN_PATH]$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}\clang-tblgen.exe + displayName: 'Set tablegen paths' + - powershell: | + Write-Host "Using LLVM_TABLEGEN_PATH: $(LLVM_TABLEGEN_PATH)" + Write-Host "Using CLANG_TABLEGEN_PATH: $(CLANG_TABLEGEN_PATH)" + Write-Host "##vso[task.setvariable variable=BuildCommandExtra]--cmake_extra_defines LLVM_TABLEGEN=$(LLVM_TABLEGEN_PATH) CLANG_TABLEGEN=$(CLANG_TABLEGEN_PATH)" + displayName: 'Set build flags for WebGPU cross-compilation' + + - powershell: | + python tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) ${{ parameters.BuildCommand }} $(BuildCommandExtra) --use_binskim_compliant_compile_flags --parallel --build --update --config $(BuildConfig) --msbuild_extra_options IncludeMobileTargets=false ${{ variables.build_py_lto_flag }} + + - ${{ if 
notIn(parameters['sln_platform'], 'Win32', 'x64') }}: + # Use cross-compiled protoc + - script: | + @echo ##vso[task.setvariable variable=ProtocDirectory]$(Build.BinariesDirectory)\installed\bin + + # The Configuration variable is required to build C# + - script: | + @echo ##vso[task.setvariable variable=Configuration]$(BuildConfig) + displayName: 'Set Configuration variable' + + # Node.js Publish + - task: BatchScript@1 + displayName: 'Setup VS env vars' + inputs: + filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' + arguments: ${{ parameters.BuildArch }} + modifyEnvironment: true + - task: CopyFiles@2 + displayName: 'Copy DirectML binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' + Contents: 'DirectML.dll' + TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + - powershell: | + $dxcZipUrl = "https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2502/dxc_2025_02_20.zip" + $dxcZipPath = "$(Build.BinariesDirectory)\dxc.zip" + $dxcExtractPath = "$(Build.BinariesDirectory)\dxc_extracted" + $targetArch = "${{ parameters.sln_platform }}" + + # Download the DXC package + Write-Host "Downloading DXC release from $dxcZipUrl" + Invoke-WebRequest -Uri $dxcZipUrl -OutFile $dxcZipPath + + # Create extraction directory + if (-not (Test-Path $dxcExtractPath)) { + New-Item -Path $dxcExtractPath -ItemType Directory -Force + } + + # Extract the zip file + Write-Host "Extracting DXC package to $dxcExtractPath" + Expand-Archive -Path $dxcZipPath -DestinationPath $dxcExtractPath -Force + + # Copy the necessary DLLs to the target directory + $sourcePath = Join-Path $dxcExtractPath "bin\$targetArch" + $targetPath = "$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\$targetArch" + + Write-Host "Copying dxil.dll and dxcompiler.dll from 
$sourcePath to $targetPath" + Copy-Item -Path "$sourcePath\dxil.dll" -Destination $targetPath -Force + Copy-Item -Path "$sourcePath\dxcompiler.dll" -Destination $targetPath -Force + + Write-Host "DXC DLLs successfully copied to the target directory" + displayName: 'Download and Copy DXC Binaries' + - template: ../templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + DisplayName: 'ESRP - Sign Node.js binding binaries' + DoEsrp: ${{ parameters.DoEsrp }} + Pattern: '*.dll,*.node' + + - script: | + del /Q $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}\CodeSignSummary-*.* + call npm pack + copy $(Build.SourcesDirectory)\js\node\onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\js\node' + displayName: 'Create NPM Package' + + - task: 1ES.PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + artifactName: ${{ parameters.ArtifactName }} + + - ${{ if and(eq(parameters.PublishWebGpuBuildTools, true), eq(parameters.sln_platform, 'x64')) }}: + - script: | + mkdir $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\llvm-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\clang-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + displayName: 'Copy WebGPU build tools' + - task: 1ES.PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}' + artifactName: ${{ parameters.WebGpuBuildToolsArtifactName }} diff --git 
a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 893bf3f1ec394..a4fe78a7088e3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -39,6 +39,11 @@ stages: buildJava: ${{ parameters.buildJava }} buildNodejs: ${{ parameters.buildNodejs }} +- ${{ if eq(parameters.buildNodejs, 'true') }}: + - template: nodejs-linux-packaging-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + - template: nuget-win-cuda-packaging-stage.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 8560817331475..e36fe98fe0ac2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -32,9 +32,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }} - DockerBuildArgs: " - --build-arg BUILD_UID=$( id -u ) - " + DockerBuildArgs: " --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}build - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -113,13 +111,6 @@ stages: nativeLibraryName: 'libonnxruntime4j_jni.so' is1ES: true - - ${{ if eq(parameters.buildNodejs, 'true') }}: - - template: ../templates/nodejs-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'x64' - os: 'linux' - artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' 
- - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: buildConfig: 'Release' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index bb789edc1cf21..9b1d7b705e741 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -182,10 +182,10 @@ stages: buildArch: x64 msbuildPlatform: arm64 packageName: arm64 - buildparameter: --build_nodejs --arm64 ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} + buildparameter: --arm64 ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} runTests: false buildJava: false - buildNodejs: true + buildNodejs: false - template: win-ci.yml parameters: @@ -194,10 +194,10 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64 - buildparameter: --build_java --build_nodejs ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} + buildparameter: --build_java ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true - buildNodejs: true + buildNodejs: false - stage: Jar_Packaging dependsOn: @@ -506,14 +506,11 @@ stages: - stage: Nodejs_Packaging dependsOn: - - Windows_CI_GPU_DML_Dev - - Windows_CI_GPU_DML_Dev_arm64 + - Windows_Nodejs_Packaging_x64 + - Windows_Nodejs_Packaging_arm64 + - Linux_Nodejs_Packaging_x64 - Linux_C_API_Packaging_CPU - - Linux_C_API_Packaging_GPU - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x86_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: Nodejs_Packaging @@ -544,74 +541,78 @@ stages: # Node.js binding artifacts preparation # # This stage prepares Node.js binding artifacts for 
publishing. The artifacts support the following platforms: - # - Windows x64 with DML support - # - Windows arm64 with DML support - # - Linux x64 with TensorRT support + # - Windows x64 (CPU, DML, WebGPU) + # - Windows arm64 (CPU, DML, WebGPU) + # - Linux x64 (CPU, CUDA, TensorRT, WebGPU) # - Linux arm64 (CPU only) - # - macOS x64 (CPU only) - # - macOS arm64 (CPU only) + # - macOS x64 (CPU, CoreML, WebGPU) + # - macOS arm64 (CPU, CoreML, WebGPU) + # + # File manifest: + # - Windows x64 (CPU, DML, WebGPU): + # dependency: Windows_Nodejs_Packaging_x64 (drop-onnxruntime-nodejs-win-x64) + # files: + # - onnxruntime_binding.node + # - onnxruntime.dll + # - DirectML.dll + # - dxil.dll + # - dxcompiler.dll + # + # - Windows arm64 (CPU, DML, WebGPU): + # dependency: Windows_Nodejs_Packaging_arm64 (drop-onnxruntime-nodejs-win-arm64) + # files: + # - onnxruntime_binding.node + # - onnxruntime.dll + # - DirectML.dll + # - dxil.dll + # - dxcompiler.dll # - # ORT Node.js binding artifacts contain 2 parts: - # 1. ONNX Runtime native shared libraries and their dependencies - # - Windows (x64, arm64): - # - onnxruntime.dll - # - DirectML.dll - # - Linux (x64, arm64): - # - libonnxruntime.so{.version} - # - libonnxruntime_providers_shared.so - # - libonnxruntime_providers_{provider}.so - # - macOS (x64, arm64): - # - libonnxruntime.dylib - # 2. ONNX Runtime Node.js binding - # - onnxruntime_binding.node + # - Linux x64 (CPU, CUDA, TensorRT, WebGPU): + # dependency: Linux_Nodejs_Packaging_x64 (drop-onnxruntime-nodejs-linux-x64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.so.1 + # - libonnxruntime_providers_shared.so + # - libonnxruntime_providers_cuda.so + # - libonnxruntime_providers_tensorrt.so # - # For windows platform, the artifact is named as 'onnxruntime-nodejs-win-x64-dml' for x64, and - # 'onnxruntime-nodejs-win-arm64-dml' for arm64. Each artifact contains both (1) and (2). 
+ # - Linux arm64 (CPU only): + # dependency: Linux_C_API_Packaging_CPU_aarch64 (drop-onnxruntime-nodejs-linux-aarch64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.so.1 # - # For Linux and macOS platforms, (1) and (2) are packed into separate artifacts. - # The following artifacts contain (1): - # - onnxruntime-osx - # - onnxruntime-linux-x64-tensorrt - # - onnxruntime-linux-aarch64 - # The following artifacts contain (2): - # - drop-onnxruntime-nodejs-linux-x64-tensorrt - # - drop-onnxruntime-nodejs-linux-aarch64 - # - drop-onnxruntime-nodejs-osx-x86_64 - # - drop-onnxruntime-nodejs-osx-arm64 + # - macOS x64 (CPU, CoreML, WebGPU): + # dependency: MacOS_C_API_Packaging_CPU_x86_64 (drop-onnxruntime-nodejs-osx-x86_64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.{version}.dylib # - # All binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': + # - macOS arm64 (CPU, CoreML, WebGPU): + # dependency: MacOS_C_API_Packaging_CPU_arm64 (drop-onnxruntime-nodejs-osx-arm64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.{version}.dylib + # + # The following files will be excluded from the further packaging because they are too large to be included in the + # NPM package: + # - linux/x64/libonnxruntime_providers_cuda.so + # + # Rest binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': # $(Build.SourcesDirectory)\js\node\bin\napi-v3\{os}\{cpu_arch}\ # # {os} is one of 'win32', 'darwin', 'linux' and {cpu_arch} is one of 'x64', 'arm64'. 
- - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (OSX)' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Linux x64)' - inputs: - artifactName: 'onnxruntime-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Linux aarch64)' - inputs: - artifactName: 'onnxruntime-linux-aarch64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-x64-dml' + artifactName: 'drop-onnxruntime-nodejs-win-x64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/' - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win ARM64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-arm64-dml' + artifactName: 'drop-onnxruntime-nodejs-win-arm64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/arm64/' - task: DownloadPipelineArtifact@0 @@ -629,7 +630,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Linux x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' + artifactName: 'drop-onnxruntime-nodejs-linux-x64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/x64/' - task: DownloadPipelineArtifact@0 @@ -638,15 +639,9 @@ stages: artifactName: 'drop-onnxruntime-nodejs-linux-aarch64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/arm64/' - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\extract_nuget_files.ps1 - - script: | - dir - workingDirectory: 
'$(Build.BinariesDirectory)/nuget-artifact' + dir /S + workingDirectory: '$(Build.BinariesDirectory)/nodejs-artifacts' displayName: 'List artifacts' - script: | @@ -683,61 +678,43 @@ stages: TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' # Node.js binding linux/x64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-tensorrt\lib' - Contents: | - libonnxruntime.so.1 - libonnxruntime_providers_shared.so - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\linux\x64' - Contents: '*.node' + Contents: | + libonnxruntime.so.1 + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' # Node.js binding linux/arm64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-aarch64\lib' - Contents: 'libonnxruntime.so.1' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\linux\arm64' - Contents: '*.node' + Contents: | + libonnxruntime.so.1 + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64' # Node.js binding darwin/x64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64\' - inputs: - SourceFolder: 
'$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-osx-x86_64\lib' - Contents: 'libonnxruntime.*.dylib' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\darwin\x64' - Contents: '*.node' + Contents: | + libonnxruntime.*.dylib + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64' # Node.js binding darwin/arm64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-osx-arm64\lib' - Contents: 'libonnxruntime.*.dylib' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\darwin\arm64' - Contents: '*.node' + Contents: | + libonnxruntime.*.dylib + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64' - task: PowerShell@2 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 7ac2e3a8addb6..fb1c63e1f8a24 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -34,7 +34,7 @@ stages: PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} - PackageNodeJS: ${{ parameters.PackageNodeJS }} + PackageNodeJS: false - template: c-api-linux-cpu.yml parameters: diff 
--git a/tools/ci_build/github/linux/build_nodejs_package.sh b/tools/ci_build/github/linux/build_nodejs_package.sh new file mode 100755 index 0000000000000..29ee91a122e39 --- /dev/null +++ b/tools/ci_build/github/linux/build_nodejs_package.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -e -x +mkdir -p $HOME/.onnx +docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ +--volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/python/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt index e51cfb38f57a3..1a580b848a55a 100644 --- a/tools/ci_build/github/linux/python/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -7,4 +7,5 @@ onnx==1.17.0 ; python_version < '3.13' protobuf==4.21.12 sympy==1.12 flatbuffers -psutil \ No newline at end of file +psutil +onnxscript==0.2.3 diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 200b9c2e50288..2b222c4b1d4a4 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -8,3 +8,4 @@ protobuf==4.21.12 sympy==1.12 flatbuffers psutil +onnxscript==0.2.3 diff 
--git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 0fb37e3a1550a..47286c364a90f 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -11,3 +11,4 @@ parameterized>=0.8.1 sentencepiece psutil einops +onnxscript==0.2.3