diff --git a/.github/actions/webgpu-validate-shader-key/action.yml b/.github/actions/webgpu-validate-shader-key/action.yml
index 7b341d38ea906..86406a2e91877 100644
--- a/.github/actions/webgpu-validate-shader-key/action.yml
+++ b/.github/actions/webgpu-validate-shader-key/action.yml
@@ -22,7 +22,7 @@ runs:
working-directory: ${{ github.action_path }}
- name: Validate shader keys (native log)
- if: ${{ !inputs.is_chromium_log != 'true' }}
+ if: ${{ inputs.is_chromium_log != 'true' }}
shell: cmd
run: |
node validate-shader-key.js < "${{ inputs.log_file_path }}"
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 8d966d358de01..21baab0fd191c 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -26,7 +26,7 @@ jobs:
level: info
filter_mode: diff_context
- name: shellcheck # Static check shell scripts
- uses: reviewdog/action-shellcheck@v1
+ uses: reviewdog/action-shellcheck@v1.30.0
with:
github_token: ${{ secrets.github_token }}
reporter: github-pr-check
diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml
index ce0e5167eb0a0..57f687d8502ff 100644
--- a/.github/workflows/windows-web-ci-workflow.yml
+++ b/.github/workflows/windows-web-ci-workflow.yml
@@ -200,7 +200,6 @@ jobs:
- name: Validate shader keys - WebGPU EP
if: ${{ inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }}
- continue-on-error: true
uses: ./.github/actions/webgpu-validate-shader-key
with:
log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log
diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml
index e8cea1c5805a3..8b3b8a2fcde54 100644
--- a/.github/workflows/windows_webgpu.yml
+++ b/.github/workflows/windows_webgpu.yml
@@ -126,7 +126,6 @@ jobs:
dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log
- name: Validate shader keys
- continue-on-error: true
uses: ./.github/actions/webgpu-validate-shader-key
with:
log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index b64a5c3e5a4a2..77c35aac65b92 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -3,6 +3,7 @@
using System;
using System.Runtime.InteropServices;
+using static Microsoft.ML.OnnxRuntime.NativeMethods;
namespace Microsoft.ML.OnnxRuntime
{
@@ -325,6 +326,16 @@ public struct OrtApi
public IntPtr CreateLoraAdapterFromArray;
public IntPtr ReleaseLoraAdapter;
public IntPtr RunOptionsAddActiveLoraAdapter;
+ public IntPtr SetEpDynamicOptions;
+ public IntPtr ReleaseValueInfo;
+ public IntPtr ReleaseNode;
+ public IntPtr ReleaseGraph;
+ public IntPtr ReleaseModel;
+ public IntPtr GetValueInfoName;
+ public IntPtr GetValueInfoTypeInfo;
+ public IntPtr GetModelEditorApi;
+ public IntPtr CreateTensorWithDataAndDeleterAsOrtValue;
+ public IntPtr SessionOptionsSetLoadCancellationFlag;
}
internal static class NativeMethods
@@ -404,6 +415,7 @@ static NativeMethods()
OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions));
OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions));
OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode));
+ OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag));
OrtSetOptimizedModelFilePath = (DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath));
OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, typeof(DOrtEnableProfiling));
OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling));
@@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
ExecutionMode execution_mode);
public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options,
+ bool value);
+ public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag;
+
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath);
public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index bd450451a1265..9b0f183f03681 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode
}
private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL;
+ /// <summary>
+ /// Sets the load cancellation flag for the session. Default is set to false.
+ /// Provides an opportunity for the user to cancel model loading.
+ /// </summary>
+ /// <param name="value">true to request cancellation, false to proceed</param>
+ public void SetLoadCancellationFlag(bool value)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value));
+ }
+
#endregion
#region Private Methods
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f582abca34706..0308b5c79c508 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
- If `zero_points` is not provided, 0 is the zero point.
+ If `zero_points` is not provided, the default zero point is 0, except when `data` is of uint8 type, in which case the default zero point is 8.
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+ 5. For uint8 data, the `gather_axis` must be 0.
#### Version
@@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T1 : tensor(int4), tensor(uint4)
+- T1 : tensor(int4), tensor(uint4), tensor(uint8)
- Constrain quantized types.
- T2 : tensor(float), tensor(float16), tensor(bfloat16)
- Constrain dequantized types.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 60d9e8e747eeb..a20333e2340c4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -515,7 +515,7 @@ Do not modify directly.*
|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|FusedGemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
-|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)
**T2** = tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)|
+|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)|
|GatherND|*in* data:**T**
*in* indices:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h
index 0822eba950f50..10f658f52e0d9 100644
--- a/include/onnxruntime/core/common/common.h
+++ b/include/onnxruntime/core/common/common.h
@@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
abort(); \
} while (false)
+#define ORT_THROW_FROM_STATUS(status) \
+ do { \
+ ::onnxruntime::PrintFinalMessage( \
+ ::onnxruntime::OnnxRuntimeException( \
+ ORT_WHERE_WITH_STACK, status.ToString()) \
+ .what()); \
+ abort(); \
+ } while (false)
+
+#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
+ do { \
+ ::onnxruntime::PrintFinalMessage( \
+ ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
+ ::onnxruntime::MakeString(__VA_ARGS__), \
+ ::onnxruntime::common::category, \
+ ::onnxruntime::common::code) \
+ .what()); \
+ abort(); \
+ } while (false)
+
#else
#define ORT_TRY try
@@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
#define ORT_THROW_EX(ex, ...) \
throw ex(__VA_ARGS__)
+#define ORT_THROW_FROM_STATUS(status) \
+ throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \
+ static_cast<::onnxruntime::common::StatusCode>(status.Code()))
+
+#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
+ throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
+ ::onnxruntime::MakeString(__VA_ARGS__), \
+ ::onnxruntime::common::category, \
+ ::onnxruntime::common::code)
+
#endif
#define ORT_MAKE_STATUS(category, code, ...) \
@@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
auto _status = (expr); \
if ((!_status.IsOK())) { \
::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \
- ORT_THROW(_status); \
+ ORT_THROW_FROM_STATUS(_status); \
} \
} while (0)
diff --git a/include/onnxruntime/core/common/exceptions.h b/include/onnxruntime/core/common/exceptions.h
index 494a770b8db98..6d0f6edd6e7c4 100644
--- a/include/onnxruntime/core/common/exceptions.h
+++ b/include/onnxruntime/core/common/exceptions.h
@@ -11,6 +11,7 @@
#include
#include "core/common/common.h"
+#include "core/common/status.h"
#include "core/common/code_location.h"
namespace onnxruntime {
@@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception {
/**
Create a new exception that captures the location it was thrown from.
@param location Location in the source code the exception is being thrown from
+ @param message Message containing additional information about the exception cause.
+ @param category Error category
+ @param code Error code
+ */
+
+ OnnxRuntimeException(const CodeLocation& location,
+ const std::string& message,
+ common::StatusCategory category,
+ common::StatusCode code) noexcept
+ : OnnxRuntimeException(location, nullptr, message, category, code) {
+ }
+
+ /**
+ Create a new exception that captures the location it was thrown from.
+ The instance will be created with ONNXRUNTIME category and FAIL code.
+ @param location Location in the source code the exception is being thrown from
@param failed_condition Optional string containing the condition that failed.
e.g. "tensor.Size() == input.Size()". May be nullptr.
@param msg Message containing additional information about the exception cause.
*/
- OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
- : location_{location} {
+ OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept
+ : OnnxRuntimeException(location, failed_condition, msg,
+ common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) {
+ }
+
+ /**
+ Create a new exception that captures the location it was thrown from.
+ @param location Location in the source code the exception is being thrown from
+ @param failed_condition Optional string containing the condition that failed.
+ e.g. "tensor.Size() == input.Size()". May be nullptr.
+ @param msg Message containing additional information about the exception cause.
+ @param category Error category
+ @param code Error code
+ */
+ OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg,
+ common::StatusCategory category,
+ common::StatusCode code)
+ : location_{location}, category_(category), code_(code) {
std::ostringstream ss;
ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous
@@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception {
what_ = ss.str();
}
+ common::StatusCategory Category() const noexcept {
+ return category_;
+ }
+
+ common::StatusCode Code() const noexcept {
+ return code_;
+ }
+
const char* what() const noexcept override {
return what_.c_str();
}
@@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception {
const CodeLocation location_;
const std::vector stacktrace_;
std::string what_;
+ common::StatusCategory category_;
+ common::StatusCode code_;
};
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h
index 8f171daabbb1e..b222e411d7804 100644
--- a/include/onnxruntime/core/common/status.h
+++ b/include/onnxruntime/core/common/status.h
@@ -43,7 +43,8 @@ enum StatusCode {
MODEL_LOADED = 8,
NOT_IMPLEMENTED = 9,
INVALID_GRAPH = 10,
- EP_FAIL = 11
+ EP_FAIL = 11,
+ MODEL_LOAD_CANCELED = 12,
};
constexpr const char* StatusCodeToString(StatusCode status) noexcept {
@@ -72,6 +73,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept {
return "INVALID_GRAPH";
case StatusCode::EP_FAIL:
return "EP_FAIL";
+ case StatusCode::MODEL_LOAD_CANCELED:
+ return "MODEL_LOAD_CANCELED";
default:
return "GENERAL ERROR";
}
@@ -104,6 +107,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
case StatusCode::EP_FAIL:
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+ case StatusCode::MODEL_LOAD_CANCELED:
+ return HRESULT_FROM_WIN32(ERROR_CANCELLED);
default:
return E_FAIL;
}
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 6d4cc8a1f2fa9..3bf0d5e19c525 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -255,6 +255,7 @@ typedef enum OrtErrorCode {
ORT_NOT_IMPLEMENTED,
ORT_INVALID_GRAPH,
ORT_EP_FAIL,
+ ORT_MODEL_LOAD_CANCELED,
} OrtErrorCode;
typedef enum OrtOpAttrType {
@@ -4898,6 +4899,24 @@ struct OrtApi {
_In_ const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type,
_Outptr_ OrtValue** out);
+
+ /** \brief Sets the load cancellation flag to abort the session loading process.
+ *
+ * \param[in] options instance that was passed to the session at creation time.
+ * \param[in] cancel setting this to true after the model loading process was initiated will
+ * attempt to cancel the loading process. If cancellation is successful, CreateSession(),
+ * CreateSessionFromArray() or any other session creation API that takes session options as an
+ * argument will return an OrtStatus indicating that session loading was canceled at the user's request,
+ * with error code ORT_MODEL_LOAD_CANCELED.
+ * The APIs above would not return any valid session instance. This is a best-effort mechanism and the result
+ * is not guaranteed. The session may have already been created and initialized
+ * before the cancellation request was issued.
+ *
+ * \snippet{doc} snippets.dox OrtStatus
+ *
+ */
+ ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
+ _In_ bool cancel);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 979b478e2fbb4..ce7dc1c45b05e 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl {
SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
+ SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag
+
SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 48c5e52e33c53..524e3ecc92936 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -747,6 +747,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionM
return *this;
}
+template
+inline SessionOptionsImpl& SessionOptionsImpl::SetLoadCancellationFlag(bool value) {
+ ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
+ return *this;
+}
+
template
inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) {
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json
index 9e4730a407d57..708e458748b3a 100644
--- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json
+++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json
@@ -12,7 +12,7 @@
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.2.1",
- "vite": "^6.2.4"
+ "vite": "^6.2.5"
}
},
"node_modules/@babel/helper-string-parser": {
@@ -1069,9 +1069,9 @@
}
},
"node_modules/vite": {
- "version": "6.2.4",
- "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.4.tgz",
- "integrity": "sha512-veHMSew8CcRzhL5o8ONjy8gkfmFJAd5Ac16oxBUjlwgX3Gq2Wqr+qNC3TjPIpy7TPV/KporLga5GT9HqdrCizw==",
+ "version": "6.2.5",
+ "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.5.tgz",
+ "integrity": "sha512-j023J/hCAa4pRIUH6J9HemwYfjB5llR2Ps0CWeikOtdR8+pAURAk0DoJC5/mm9kd+UgdnIy7d6HE4EAvlYhPhA==",
"dev": true,
"license": "MIT",
"dependencies": {
diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json b/js/web/test/e2e/exports/testcases/vite-default/package.json
index e06733f917e3f..904db7a41de9c 100644
--- a/js/web/test/e2e/exports/testcases/vite-default/package.json
+++ b/js/web/test/e2e/exports/testcases/vite-default/package.json
@@ -13,6 +13,6 @@
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.2.1",
- "vite": "^6.2.4"
+ "vite": "^6.2.5"
}
}
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 345b5e793a764..1a737f3a9d251 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
index 5935663f114a3..b83164d806ffc 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -16,6 +16,21 @@
namespace onnxruntime {
namespace contrib {
+namespace {
+template
+int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
+ return static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1)));
+}
+
+template <>
+int32_t GetDataElement(const uint8_t* data_ptr, int64_t data_idx) {
+ const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
+ // Weights are stored as (nibble2)(nibble1) in uint8_t.
+ auto data_val = static_cast((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
+ return data_val;
+}
+} // namespace
+
template
class GatherBlockQuantized : public OpKernel {
public:
@@ -98,6 +113,12 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex
for (int64_t i = p.gather_axis + 1; i < static_cast(data_rank); ++i)
shape.push_back(data_shape[narrow(i)]);
+ // When data is stored as uint8_t, each element packs two int4 values.
+ // The shape in the onnx model reflects that by having the last dimension be half the number of values.
+ // Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536.
+ // However, the outputs still need to be of size 2000x3072. Therefore we double the last dimension here.
+ uint32_t components = (std::is_same_v) ? 2 : 1;
+ shape[shape.size() - 1] = shape.back() * components;
p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
// validate quantization parameters
@@ -106,7 +127,7 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex
"data and scales must have the same rank.");
for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(i == static_cast(p.quantize_axis)
- ? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
+ ? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
: data_shape[i] == scales_shape[i],
"data and scales do not match shapes.");
}
@@ -165,16 +186,22 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr,
int64_t output_idx = output_idx_base;
int64_t data_idx = data_idx_base;
for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
- auto data_val = static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1)));
+ auto data_val = GetDataElement(data_ptr, data_idx);
int64_t x = data_idx / quantize_full_block;
int64_t y = data_idx % quantize_full_block / quantize_N;
int64_t z = data_idx % quantize_N;
int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
auto scale_val = static_cast(scales_ptr[scale_idx]);
- auto zp_val = static_cast(zero_points_ptr
- ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1))
- : 0);
+ int32_t zp_val;
+ if constexpr (std::is_same_v) {
+ // The default zero point for uint8 weights as stored by MatMulNBits op is 8.
+ zp_val = 8;
+ } else {
+ zp_val = static_cast(zero_points_ptr
+ ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1))
+ : 0);
+ }
output_ptr[output_idx] = static_cast(static_cast(data_val - zp_val) * scale_val);
}
@@ -205,7 +232,7 @@ template
Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
-
+ auto components = (std::is_same_v) ? 2 : 1;
const auto& data_shape = p.data_tensor->Shape();
// re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
// re-shape the indices tensor to [gather_N]
@@ -215,7 +242,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
// 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
// 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
// 4> pick the element from the block: value_i = data_blk[blk_ele_i]
- const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1);
+ const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1) * components;
const int64_t gather_axis_dim = data_shape[narrow(p.gather_axis)];
const int64_t gather_M = data_shape.SizeToDimension(narrow(p.gather_axis));
const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +256,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
// data_i % (quantize_axis_dim * quantize_N) / quantize_N,
// data_i % quantize_N)
// 4> get scale index: (x, y / block_size_, z)
- const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)];
+ const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)] * components;
const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt(p.quantize_axis) + 1);
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
@@ -273,6 +300,8 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
.TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \
GatherBlockQuantized);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 0d4afc8c13f4b..6e7919f281fb6 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -99,23 +99,24 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "var tileK: array;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n";
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
- << "let m = workgroup_id.y * TILE_SIZE;\n"
- << "let n = workgroup_id.x * TILE_SIZE;\n"
- << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
- << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
+ << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
+ << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
+ << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n"
+ << "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
+ << "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
- shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
+ shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
- shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
+ shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
}
shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
- " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
+ " if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
" }\n"
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
@@ -123,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
- << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
+ << " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
@@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< "}\n";
- shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
- << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
- << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
+ shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
+ << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
+ << " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
@@ -181,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
- components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+ components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
}
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
- program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
- (parameters.sequence_length_ + tile_size - 1) / tile_size,
- parameters.batch_size_ * parameters.num_heads_)
+ const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+ const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
+ program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile)
.SetWorkgroupSize(tile_size, tile_size)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
.AddUniformVariables({{static_cast(parameters.sequence_length_)},
@@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
{static_cast(parameters.kv_sequence_length_)},
{static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast(parameters.n_reps)},
- {static_cast(parameters.is_first_prompt_ ? 1 : 0)}})
+ {static_cast(parameters.is_first_prompt_ ? 1 : 0)},
+ {num_total_seq_length_tile},
+ {num_seq_length_tile}})
.SetOverridableConstants({{static_cast(tile_size)}});
return context.RunProgram(program);
@@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AdditionalImplementation() << "var thread_max: array;\n"
<< "var thread_sum: array;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n";
- shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
- << "let sequence_length = uniforms.sequence_length;\n"
+ shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
+ << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
- << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
- << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
+ << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
+ << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
@@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size)
- .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
+ .SetDispatchGroupSize(batch_size * num_heads * sequence_length)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast(batch_size)},
{static_cast(num_heads)},
@@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AdditionalImplementation() << "var tileQ: array;\n"
<< "var tileK: array;\n";
- shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
- << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
- << "let m = global_id.y;\n"
- << "let n = global_id.x;\n"
- << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
+ shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n"
+ << "let head_idx = batch_head_idx % uniforms.num_heads;\n"
+ << "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
+ << "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n"
+ << "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n"
+ << "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
- shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
+ shader.MainFunctionBody() << "let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (has_present_value_) {
- shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
+ shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
}
shader.MainFunctionBody() << "var value = output_value_t(0);\n"
@@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
- << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
+ << " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
@@ -400,7 +404,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 2 : 1);
constexpr int tile_size = 12;
int tile_n_size = tile_size * components;
- VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
+ VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_value) {
@@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components});
}
- program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size,
- (parameters.sequence_length_ + tile_size - 1) / tile_size,
- parameters.batch_size_ * parameters.num_heads_)
+ const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size;
+ const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
+ program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast(parameters.sequence_length_)},
@@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
{static_cast(parameters.kv_sequence_length_)},
{static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast(parameters.n_reps)},
- {static_cast(parameters.is_first_prompt_)}})
+ {static_cast(parameters.is_first_prompt_)},
+ {num_head_size_tile},
+ {num_seq_length_tile}})
.SetOverridableConstants({{static_cast(tile_size)}});
return context.RunProgram(program);
diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h
index 164ea72b07d9d..7c0cb40cc7f93 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program
class AttentionProbsProgram final : public Program {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
- bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
- : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+ bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+ : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program {
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
- {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
+ {"is_first_prompt", ProgramUniformVariableDataType::Uint32},
+ {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
+ {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
@@ -60,7 +62,6 @@ class AttentionProbsProgram final : public Program {
bool has_attention_bias_;
int tile_size_;
int components_;
- int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
@@ -90,8 +91,8 @@ class InPlaceSoftmaxProgram final : public Program {
class VxAttentionScoreProgram final : public Program {
public:
- VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
- : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
+ VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
+ : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -106,7 +107,9 @@ class VxAttentionScoreProgram final : public Program {
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
- {"is_first_prompt", ProgramUniformVariableDataType::Uint32});
+ {"is_first_prompt", ProgramUniformVariableDataType::Uint32},
+ {"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
+ {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
@@ -114,7 +117,6 @@ class VxAttentionScoreProgram final : public Program {
bool feed_past_value_;
bool has_present_value_;
int tile_size_;
- int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
diff --git a/onnxruntime/contrib_ops/webgpu/fused_conv.cc b/onnxruntime/contrib_ops/webgpu/fused_conv.cc
new file mode 100644
index 0000000000000..e6b7ac3ec24d4
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/fused_conv.cc
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/nn/conv.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+using onnxruntime::webgpu::Conv;
+template
+class FusedConv final : public Conv {
+ public:
+ FusedConv(const OpKernelInfo& info) : Conv(info) {
+ ORT_ENFORCE(GetFusedActivationAttr(info, Conv::activation_).IsOK());
+ }
+};
+
+ONNX_OPERATOR_KERNEL_EX(
+ FusedConv,
+ kMSDomain,
+ 1,
+ kWebGpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", onnxruntime::webgpu::WebGpuSupportedFloatTypes()),
+ FusedConv);
+
+} // namespace webgpu
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
index 6e63ba3a0caa4..4136477a1d88c 100644
--- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -40,7 +40,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- // BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h
index 703d183ea5c87..b42c6a9ba3e10 100644
--- a/onnxruntime/core/framework/error_code_helper.h
+++ b/onnxruntime/core/framework/error_code_helper.h
@@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c
#ifndef ORT_NO_EXCEPTIONS
#define API_IMPL_BEGIN try {
-#define API_IMPL_END \
- } \
- catch (const onnxruntime::NotImplementedException& ex) { \
- return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
- } \
- catch (const std::exception& ex) { \
- return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
- } \
- catch (...) { \
- return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
+#define API_IMPL_END \
+ } \
+ catch (const onnxruntime::OnnxRuntimeException& ex) { \
+ return OrtApis::CreateStatus(static_cast(ex.Code()), ex.what()); \
+ } \
+ catch (const onnxruntime::NotImplementedException& ex) { \
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
+ } \
+ catch (const std::exception& ex) { \
+ return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
+ } \
+ catch (...) { \
+ return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
}
#else
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index ff4d300f665b1..50f14104cfd7a 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -56,6 +56,7 @@ namespace {
// contains some common parameters used by the partitioning helper functions
struct PartitionParams {
std::reference_wrapper graph;
+ std::reference_wrapper check_load_cancellation_fn;
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
std::reference_wrapper func_mgr;
std::reference_wrapper fused_kernel_registry;
@@ -143,6 +144,7 @@ struct GetCapabilityForEPParams {
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
IResourceAccountant* resource_accountant;
std::reference_wrapper graph_optimizer_registry;
+ std::reference_wrapper check_load_cancellation_fn;
};
auto get_capabilities = [](const IExecutionProvider& ep,
@@ -188,7 +190,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
{
const GraphViewer graph_viewer(graph);
- capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry);
+ capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant,
+ graph_optimizer_registry);
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Graph partitioning was canceled by user request");
+ }
if (capabilities.empty()) {
return Status::OK();
@@ -209,6 +216,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
// Perform layout transformation on the specific EP assigned graph
bool modified = false;
ORT_RETURN_IF_ERROR(params.transform_layout(graph, modified, current_ep, params.debug_graph_fn));
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "GetCapabilities was canceled by user request");
+ }
// It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes
// which are reconstructed to update domain or completely new nodes which are necessary for layout transformation.
@@ -226,7 +237,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
capabilities.clear();
const GraphViewer graph_viewer(graph);
- capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry);
+ capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant,
+ graph_optimizer_registry);
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "GetCapabilities was canceled by user request");
+ }
// all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities
InlinedHashSet new_nodes_in_capabilities;
@@ -405,6 +421,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
int& fused_node_unique_id,
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
const layout_transformation::DebugGraphFn& debug_graph_fn,
+ const CheckLoadCancellationFn& check_load_cancellation_fn,
const logging::Logger& logger, IResourceAccountant* resource_accountant,
const GraphOptimizerRegistry& graph_optimizer_registry) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@@ -420,7 +437,10 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
- transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry));
+ transform_layout_fn, debug_graph_fn,
+ check_load_cancellation_fn,
+ logger, resource_accountant,
+ graph_optimizer_registry));
}
}
@@ -445,7 +465,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::cref(transform_layout_fn),
std::cref(debug_graph_fn),
resource_accountant,
- std::ref(graph_optimizer_registry)};
+ std::ref(graph_optimizer_registry),
+ std::cref(check_load_cancellation_fn)};
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
@@ -532,6 +553,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
}
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
+ ORT_RETURN_IF(check_load_cancellation_fn(),
+ "Graph partitioning is canceled due to user request.");
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
@@ -633,6 +656,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
Graph& graph,
const GraphOptimizerRegistry& graph_optimizer_registry,
const logging::Logger& logger,
+ const CheckLoadCancellationFn& check_load_cancellation_fn,
InlinedHashSet& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@@ -650,6 +674,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
*subgraph,
graph_optimizer_registry,
logger,
+ check_load_cancellation_fn,
not_inlined,
inlined_count));
}
@@ -673,8 +698,13 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
InlinedHashSet claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector> capabilities;
- ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger,
+ ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep,
+ graph_optimizer_registry, logger,
capabilities));
+ if (check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request.");
+ }
+
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
if (nodes.size() == 1) {
@@ -707,6 +737,9 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
}
}
+ if (check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request.");
+ }
}
return Status::OK();
@@ -846,6 +879,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
auto& fused_kernel_registry = partition_params.fused_kernel_registry.get();
auto& fused_node_unique_id = partition_params.fused_node_unique_id.get();
const auto& transform_layout_function = partition_params.transform_layout_function;
+ const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn;
do {
// process full graph with each EP
@@ -861,6 +895,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function,
partition_params.debug_graph_fn,
+ check_load_cancellation_fn,
logger, resource_accountant, graph_optimizer_registry));
}
@@ -915,7 +950,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
std::cref(partition_params.debug_graph_fn),
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
nullptr,
- std::ref(graph_optimizer_registry)
+ std::ref(graph_optimizer_registry),
+ partition_params.check_load_cancellation_fn
};
// clang-format on
@@ -972,6 +1008,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
std::vector single_node_compute_func;
ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}},
single_node_compute_func));
+ if (partition_params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning is canceled due to user request.");
+ }
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element.");
auto& func_mgr = partition_params.func_mgr.get();
@@ -1032,6 +1071,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
return Status::OK();
}
+ auto check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); };
+
auto& graph = model.MainGraph();
InlinedHashSet not_inlined;
do {
@@ -1041,13 +1082,13 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
graph,
*graph_optimizer_registry_,
logger,
+ check_load_cancellation_fn,
not_inlined,
inlined_count));
if (inlined_count == 0) {
break;
}
-
ORT_RETURN_IF_ERROR(graph.Resolve());
} while (true);
@@ -1082,6 +1123,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
}
+ CheckLoadCancellationFn check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); };
+
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// fused_kernel_registry is preparing the kernels created on the fly for fused sub graph.
// It is only visible for current session.
@@ -1092,6 +1135,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
PartitionParams partition_params{
std::ref(graph),
+ std::cref(check_load_cancellation_fn),
std::ref(func_mgr),
std::ref(*fused_kernel_registry),
std::ref(fused_node_unique_id),
@@ -1105,6 +1149,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
ORT_UNUSED_PARAMETER(debug_graph_fn);
PartitionParams partition_params{
std::ref(graph),
+ std::cref(check_load_cancellation_fn),
};
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h
index b9d4022cb5a14..87edc7a64c6b5 100644
--- a/onnxruntime/core/framework/graph_partitioner.h
+++ b/onnxruntime/core/framework/graph_partitioner.h
@@ -33,6 +33,16 @@ class GraphPartitioner {
graph_optimizer_registry_(std::move(graph_optimizer_registry)) {
}
+ GraphPartitioner(KernelRegistryManager& kernel_registry_mgr,
+ const ExecutionProviders& providers,
+ std::unique_ptr graph_optimizer_registry,
+ CheckLoadCancellationFn check_load_cancellation_fn)
+ : kernel_registry_mgr_(kernel_registry_mgr),
+ providers_(providers),
+ graph_optimizer_registry_(std::move(graph_optimizer_registry)),
+ check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) {
+ }
+
// Run partitioning.
Status Partition(Graph& graph, FuncManager& func_mgr,
const layout_transformation::TransformLayoutFunction& transform_layout_function,
@@ -41,6 +51,10 @@ class GraphPartitioner {
Mode mode = Mode::kNormal,
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
+ bool IsLoadCancellationFlagSet() const {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
#ifndef ORT_MINIMAL_BUILD
///
// Ahead of Time Function inlining. The main purpose of the function is to inline as many
@@ -69,6 +83,7 @@ class GraphPartitioner {
KernelRegistryManager& kernel_registry_mgr_;
const ExecutionProviders& providers_;
std::unique_ptr graph_optimizer_registry_;
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 8d4db36106f28..ef323b99b006c 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -8,6 +8,7 @@
#include
#include
#include
+#include
#include
#include "core/common/inlined_containers.h"
#include "core/framework/config_options.h"
@@ -66,6 +67,8 @@ struct FreeDimensionOverride {
int64_t dim_value;
};
+using CheckLoadCancellationFn = std::function;
+
/**
* Configuration information for a session.
*/
@@ -184,6 +187,18 @@ struct SessionOptions {
// User specified logging func and param
OrtLoggingFunction user_logging_function = nullptr;
void* user_logging_param = nullptr;
+
+ void SetLoadCancellationFlag(bool value) noexcept {
+ *load_cancellation_flag = value;
+ }
+
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return *load_cancellation_flag;
+ }
+
+ // The load cancellation flag must live in shared storage because SessionOptions is
+ // copied internally and the flag needs to be accessible across all copies.
+ std::shared_ptr load_cancellation_flag = std::make_shared(false);
};
inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) {
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index d174d6cc72ead..6362a3169f3a3 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -422,6 +422,10 @@ Status SessionState::PrepackConstantInitializedTensors(
auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map](
bool should_cache_prepacked_weights_for_shared_initializers) -> Status {
for (auto& node : GetGraphViewer().Nodes()) {
+ if (sess_options_.IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Weight pre-packing was canceled due to user request.");
+ }
auto kernel = GetMutableKernel(node.Index());
int input_idx = 0;
for (auto& input_def : node.InputDefs()) {
@@ -1541,6 +1545,11 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringname();
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 7b4a45ce8aa0f..d87688a62040c 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ..
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
- If `zero_points` is not provided, 0 is the zero point.
+ If `zero_points` is not provided, 0 is the zero point, except when `data` is of uint8 type, in which case the default zero point is 8.
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+ 5. For uint8 data, the `gather_axis` must be 0.
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
.Input(2, "scales", "quantization scale", "T2")
.Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
.Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
- .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
+ .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
.TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3637,14 +3638,19 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
gather_axis = (gather_axis + r) % r;
quantize_axis = (quantize_axis + r) % r;
+ if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
+ fail_shape_inference("gather_axis must be 0, for uint8 data");
+ }
+
if (scales_shape.dim_size() != r) {
fail_shape_inference("scales must have the same rank as data");
}
+ uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
for (int i = 0; i < r; ++i) {
if (!data_shape.dim(i).has_dim_value() ||
!scales_shape.dim(i).has_dim_value() ||
- (i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
+ (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
fail_shape_inference("data shape and scales shape do not match");
}
@@ -3652,6 +3658,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
// validate zero point shape
if (ctx.hasInput(3)) {
+ if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
+ fail_type_inference("zero_points are not supported for uint8_t data type");
+ }
+
if (!hasInputShape(ctx, 3)) {
fail_shape_inference("zero_points shape must be known");
}
@@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
}
for (int i = 0; i < out_rank; ++i) {
+          // For the uint8_t data type, the last dimension needs to be expanded back to the actual dimension,
+          // because two int4 values are stored packed in a single uint8_t.
+ auto last_dimension_components = (i == out_rank - 1) ? components : 1;
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
(i < gather_axis)
? data_shape.dim(i)
: (i >= gather_axis && i < gather_axis + q)
? indices_shape.dim(i - gather_axis)
- : data_shape.dim(i - q + 1);
+ : data_shape.dim(i - q + 1) * last_dimension_components;
}
});
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 39ffc6a5b0cee..334ecb3887d14 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model,
#endif
}
+ if (owning_model_.IsLoadCancellationFlagSet()) {
+ ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request.");
+ }
+
// Remove constant nodes as they're replaced with initializers above.
const gsl::not_null*> graph_mutable_nodes{graph_proto_->mutable_node()};
graph_mutable_nodes->erase(
@@ -1365,6 +1369,10 @@ Graph::Graph(const Model& owning_model,
}
}
+ if (owning_model_.IsLoadCancellationFlagSet()) {
+ ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request.");
+ }
+
for (auto& graph_output : graph_proto_->output()) {
if (utils::HasName(graph_output) && utils::HasType(graph_output)) {
auto& name = graph_output.name();
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index 7629e40c1b5fe..436af7115eb1a 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name,
const std::vector& model_local_functions,
const logging::Logger& logger,
const ModelOptions& options)
- : model_path_(model_path) {
+ : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) {
model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
model_proto_.mutable_graph()->set_name(graph_name);
model_metadata_ = model_metadata;
@@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path,
Model::Model(ModelProto&& model_proto, const PathString& model_path,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger, const ModelOptions& options)
- : model_path_(model_path) {
+ : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) {
if (!utils::HasGraph(model_proto)) {
ORT_THROW("ModelProto does not have a graph.");
}
@@ -435,6 +435,11 @@ Status Model::Load(const ModelProto& model_proto,
ORT_TRY {
model = std::make_unique(model_proto, model_path, local_registries, logger, options);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
@@ -474,6 +479,11 @@ Status Model::Load(ModelProto&& model_proto,
ORT_TRY {
model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
@@ -509,6 +519,11 @@ static Status LoadModelHelper(const T& file_path, Loader loader) {
ORT_TRY {
status = loader(fd);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, FAIL, ex.what());
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 6fd94c60d6b99..70f82bcfb160b 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -11,6 +11,7 @@
#include "core/common/flatbuffers.h"
+#include "core/framework/session_options.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/ort_format_load_options.h"
#include "core/session/onnxruntime_c_api.h"
@@ -38,6 +39,14 @@ struct ModelOptions {
// be returned.
bool strict_shape_type_inference;
+ CheckLoadCancellationFn check_load_cancellation_fn;
+
+ ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference,
+ CheckLoadCancellationFn check_load_cancellation_fn)
+ : allow_released_opsets_only(allow_released_opsets_only),
+ strict_shape_type_inference(strict_shape_type_inference),
+ check_load_cancellation_fn(std::move(check_load_cancellation_fn)) {}
+
ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference)
: allow_released_opsets_only(allow_released_opsets_only),
strict_shape_type_inference(strict_shape_type_inference) {}
@@ -102,6 +111,11 @@ class Model {
#endif // !defined(ORT_MINIMAL_BUILD)
+ // Check for load cancellation.
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
#if !defined(ORT_MINIMAL_BUILD)
// Get model's IR version.
// Return if not specified.
@@ -343,5 +357,7 @@ class Model {
// Main graph of the model.
std::unique_ptr graph_;
+
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc
index ea9d8605e2417..71c8667a89b1d 100644
--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -121,7 +121,7 @@ class ConvActivationSelector : public NodeSelector {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) {
return std::nullopt;
}
- } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) {
+ } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) {
if (!is_supported_non_cuda_rocm_ep_activation(*next_node) &&
!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) {
return std::nullopt;
diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc
index 039283bb2d4e1..83c3f70799987 100644
--- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc
@@ -27,6 +27,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
}
for (unsigned step = 0; step < steps_; ++step) {
+ if (IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph transformation canceled due to user request.");
+ }
bool graph_changed = false;
for (const auto& transformer : transformers->second) {
if (step > 0 && transformer->ShouldOnlyApplyOnce())
diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h
index ed66302434ab2..eab57f12bfcbb 100644
--- a/onnxruntime/core/optimizer/graph_transformer_mgr.h
+++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h
@@ -24,6 +24,16 @@ class GraphTransformerManager {
// Get the maximum number of graph transformation steps
common::Status GetSteps(unsigned& steps) const;
+  // Set the load-cancellation check function provided via session_options
+ void SetLoadCancellationFn(CheckLoadCancellationFn check_load_cancellation_fn) {
+ check_load_cancellation_fn_ = std::move(check_load_cancellation_fn);
+ }
+
+  // Returns true if model load cancellation has been requested
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
// Register a transformer with a level.
common::Status Register(std::unique_ptr transformer, TransformerLevel level);
@@ -38,5 +48,6 @@ class GraphTransformerManager {
InlinedHashMap>> level_to_transformer_map_;
InlinedHashMap transformers_info_;
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 9684394da0520..eae2a464cef7e 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -296,17 +296,19 @@ InlinedVector> GenerateTransformers(
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider,
onnxruntime::kDmlExecutionProvider};
- const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
- onnxruntime::kRocmExecutionProvider,
- onnxruntime::kAclExecutionProvider,
- onnxruntime::kArmNNExecutionProvider,
- onnxruntime::kJsExecutionProvider};
- const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
- onnxruntime::kCudaExecutionProvider,
- onnxruntime::kRocmExecutionProvider,
- onnxruntime::kAclExecutionProvider,
- onnxruntime::kArmNNExecutionProvider,
- onnxruntime::kJsExecutionProvider};
+ const InlinedHashSet cpu_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
+ onnxruntime::kRocmExecutionProvider,
+ onnxruntime::kAclExecutionProvider,
+ onnxruntime::kArmNNExecutionProvider,
+ onnxruntime::kJsExecutionProvider,
+ onnxruntime::kWebGpuExecutionProvider};
+ const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
+ onnxruntime::kCudaExecutionProvider,
+ onnxruntime::kRocmExecutionProvider,
+ onnxruntime::kAclExecutionProvider,
+ onnxruntime::kArmNNExecutionProvider,
+ onnxruntime::kJsExecutionProvider,
+ onnxruntime::kWebGpuExecutionProvider};
const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kDmlExecutionProvider,
onnxruntime::kAclExecutionProvider};
@@ -338,7 +340,7 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_dml_acl_eps));
transformers.emplace_back(std::make_unique(cpu_acl_eps));
- transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps));
+ transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_webgpu_eps));
transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level));
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 93b2acb5b002c..26642459a6863 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -77,6 +77,7 @@ Status CreateNodeArgs(const std::vector& names,
const OnnxTensorInfo& tensor_info = tensor_info_table.at(name);
std::unique_ptr tensor_type = Factory::Create();
tensor_type->mutable_tensor_type()->set_elem_type(tensor_info.data_type_);
+ tensor_type->mutable_tensor_type()->mutable_shape();
for (size_t j = 0; j < tensor_info.shape_.size(); ++j) {
tensor_type->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]);
}
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 9d5e16caa361d..bc8905c225822 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -611,6 +611,8 @@ struct ProviderHost {
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
+ virtual void InferShapes(const std::string& m, const std::string& save_path) = 0;
+ virtual void InferShapes(ONNX_NAMESPACE::ModelProto& m) = 0;
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
virtual void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) = 0;
virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
@@ -1010,6 +1012,7 @@ struct ProviderHost {
virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
virtual Graph* Graph__MutableParentGraph(Graph* p) = 0;
virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
+ virtual void Graph__SetName(Graph* p, const std::string& name) const noexcept = 0;
virtual const std::filesystem::path& Graph__ModelPath(const Graph* p) const = 0;
virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
virtual bool Graph__IsSubgraph(const Graph* p) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index e2af144f455e4..5f0f9ca4c8584 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -1050,6 +1050,7 @@ struct Graph final {
const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); }
const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
+ void SetName(const std::string& name) noexcept { return g_host->Graph__SetName(this, name); }
const std::filesystem::path& ModelPath() const { return g_host->Graph__ModelPath(this); }
const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }
diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc
index 6547f00cd47c7..33aa8fa2b31b8 100644
--- a/onnxruntime/core/providers/vitisai/imp/global_api.cc
+++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc
@@ -360,10 +360,19 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
};
the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); };
the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); };
+ the_global_api.graph_set_name = [](Graph& graph, const char* name) -> void { return graph.SetName(std::string(name)); };
the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from,
const auto& enter, const auto& leave, const auto& stop) {
graph.ReverseDFSFrom(from, enter, leave, nullptr, stop);
};
+
+ the_global_api.graph_infer_shapes_from_filepath = [](const std::string& m, const std::string& save_path) -> auto { return Provider_GetHost()->InferShapes(m, save_path); };
+ the_global_api.graph_to_graph_proto = [](const Graph& graph) -> ONNX_NAMESPACE::GraphProto* {
+ return graph.ToGraphProto().release();
+ };
+ the_global_api.graph_proto_delete = [](ONNX_NAMESPACE::GraphProto* p) { delete p; };
+ the_global_api.graph_infer_shapes = [](ONNX_NAMESPACE::ModelProto& m) -> auto { return Provider_GetHost()->InferShapes(m); };
+
// node
the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs;
the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args;
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
index 85a1262d8489b..6c9c728d8ffad 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
@@ -20,6 +20,7 @@ struct NodeAttributes;
namespace ONNX_NAMESPACE {
struct AttributeProto;
struct TensorProto;
+struct GraphProto;
struct ModelProto;
#ifndef USE_VITISAI
enum TensorProto_DataType : int {
@@ -71,6 +72,7 @@ enum AttributeProto_AttributeType : int {
namespace vaip_core {
class GraphHolder;
using ONNX_NAMESPACE::AttributeProto;
+using ONNX_NAMESPACE::GraphProto;
using ONNX_NAMESPACE::ModelProto;
using ONNX_NAMESPACE::TensorProto;
using onnxruntime::Graph;
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
index 0becc41d861f7..d40da70726b43 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
@@ -13,7 +13,7 @@ struct OrtApi;
namespace vaip_core {
-#define VAIP_ORT_API_MAJOR (14u)
+#define VAIP_ORT_API_MAJOR (16u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
@@ -249,7 +249,13 @@ struct OrtApiForVaip {
const std::function& leave,
const std::function& comp,
const std::function&
- stop); // [103]
+ stop); // [103]
+ void (*graph_set_name)(Graph& graph, const char* name); // [104]
+ void (*graph_infer_shapes_from_filepath)(
+ const std::string& m, const std::string& save_path); // [105]
+ GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106]
+ void (*graph_proto_delete)(GraphProto* p); // [107]
+ void (*graph_infer_shapes)(ModelProto& m); // [108]
};
#ifndef USE_VITISAI
diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc
index 91cae111a708a..315d0cd75e946 100644
--- a/onnxruntime/core/providers/webgpu/allocator.cc
+++ b/onnxruntime/core/providers/webgpu/allocator.cc
@@ -13,15 +13,15 @@ void* GpuBufferAllocator::Alloc(size_t size) {
return nullptr;
}
- WGPUBuffer buffer;
- if (!session_initialized_ && context_.SupportsBufferMapExtendedUsages()) {
- buffer = context_.BufferManager().CreateUMA(size);
- } else {
- buffer = context_.BufferManager().Create(size);
+ stats_.num_allocs++;
+
+#if !defined(__wasm__)
+ if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) {
+ return context_.BufferManager().CreateUMA(size);
}
+#endif // !defined(__wasm__)
- stats_.num_allocs++;
- return buffer;
+ return context_.BufferManager().Create(size);
}
void GpuBufferAllocator::Free(void* p) {
diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc
index adb37f54f2e8f..1d8c689cbd909 100644
--- a/onnxruntime/core/providers/webgpu/buffer_manager.cc
+++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc
@@ -56,14 +56,27 @@ class LazyReleaseCacheManager : public IBufferCacheManager {
}
void ReleaseBuffer(WGPUBuffer buffer) override {
- pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer));
+ pending_buffers_.emplace_back(buffer);
}
void OnRefresh() override {
+ Release();
pending_buffers_.clear();
}
- std::vector pending_buffers_;
+ public:
+ ~LazyReleaseCacheManager() {
+ Release();
+ }
+
+ protected:
+ void Release() {
+ for (auto& buffer : pending_buffers_) {
+ wgpuBufferRelease(buffer);
+ }
+ }
+
+ std::vector pending_buffers_;
};
class SimpleCacheManager : public IBufferCacheManager {
@@ -74,7 +87,7 @@ class SimpleCacheManager : public IBufferCacheManager {
WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
auto it = buffers_.find(buffer_size);
if (it != buffers_.end() && !it->second.empty()) {
- auto buffer = it->second.back().MoveToCHandle();
+ auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}
@@ -87,18 +100,31 @@ class SimpleCacheManager : public IBufferCacheManager {
}
void ReleaseBuffer(WGPUBuffer buffer) override {
- pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer));
+ pending_buffers_.emplace_back(buffer);
}
void OnRefresh() override {
for (auto& buffer : pending_buffers_) {
- buffers_[static_cast(buffer.GetSize())].emplace_back(std::move(buffer));
+ buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer);
}
pending_buffers_.clear();
}
- std::map> buffers_;
- std::vector pending_buffers_;
+ public:
+ ~SimpleCacheManager() {
+ for (auto& buffer : pending_buffers_) {
+ wgpuBufferRelease(buffer);
+ }
+ for (auto& pair : buffers_) {
+ for (auto& buffer : pair.second) {
+ wgpuBufferRelease(buffer);
+ }
+ }
+ }
+
+ protected:
+ std::map> buffers_;
+ std::vector pending_buffers_;
};
// TODO: maybe use different bucket size for storage and uniform buffers?
@@ -155,7 +181,7 @@ class BucketCacheManager : public IBufferCacheManager {
WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
auto it = buckets_.find(buffer_size);
if (it != buckets_.end() && !it->second.empty()) {
- auto buffer = it->second.back().MoveToCHandle();
+ auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}
@@ -167,31 +193,44 @@ class BucketCacheManager : public IBufferCacheManager {
}
void ReleaseBuffer(WGPUBuffer buffer) override {
- pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer));
+ pending_buffers_.emplace_back(buffer);
}
void OnRefresh() override {
// TODO: consider graph capture. currently not supported
for (auto& buffer : pending_buffers_) {
- auto buffer_size = static_cast(buffer.GetSize());
+ auto buffer_size = static_cast(wgpuBufferGetSize(buffer));
auto it = buckets_.find(buffer_size);
if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) {
- it->second.emplace_back(std::move(buffer));
+ it->second.emplace_back(buffer);
+ } else {
+ wgpuBufferRelease(buffer);
}
}
pending_buffers_.clear();
}
+ ~BucketCacheManager() {
+ for (auto& buffer : pending_buffers_) {
+ wgpuBufferRelease(buffer);
+ }
+ for (auto& pair : buckets_) {
+ for (auto& buffer : pair.second) {
+ wgpuBufferRelease(buffer);
+ }
+ }
+ }
+
protected:
void Initialize() {
buckets_keys_.reserve(buckets_limit_.size());
buckets_.reserve(buckets_limit_.size());
for (const auto& pair : buckets_limit_) {
buckets_keys_.push_back(pair.first);
- buckets_.emplace(pair.first, std::vector());
+ buckets_.emplace(pair.first, std::vector());
}
std::sort(buckets_keys_.begin(), buckets_keys_.end());
@@ -205,8 +244,8 @@ class BucketCacheManager : public IBufferCacheManager {
#endif
}
std::unordered_map buckets_limit_;
- std::unordered_map> buckets_;
- std::vector pending_buffers_;
+ std::unordered_map> buckets_;
+ std::vector pending_buffers_;
std::vector buckets_keys_;
};
@@ -255,11 +294,10 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf
void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) {
// If the buffer is mapped, we can directly write to it.
- wgpu::Buffer dst_buffer = dst;
- auto mapped_data = dst_buffer.GetMappedRange();
+  void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE);  // nullptr if the buffer is not currently mapped
if (mapped_data) {
memcpy(mapped_data, src, size);
- dst_buffer.Unmap();
+ wgpuBufferUnmap(dst);
return;
}
@@ -288,9 +326,11 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) {
EnforceBufferUnmapped(context_, dst);
auto buffer_size = NormalizeBufferSize(size);
- ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst),
+ auto src_size = static_cast(wgpuBufferGetSize(src));
+ auto dst_size = static_cast(wgpuBufferGetSize(dst));
+ ORT_ENFORCE(buffer_size <= src_size && buffer_size <= dst_size,
"Source and destination buffers must have enough space for the copy operation. src_size=",
- wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, ".");
+ src_size, ", dst_size=", dst_size, ", copy_size=", buffer_size, ".");
auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
@@ -298,7 +338,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) {
}
WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
- auto& cache = GetCacheManager(static_cast(usage));
+ auto& cache = GetCacheManager(usage);
auto buffer_size = cache.CalculateBufferSize(size);
auto buffer = cache.TryAcquireCachedBuffer(buffer_size);
@@ -310,7 +350,6 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
wgpu::BufferDescriptor desc{};
desc.size = buffer_size;
desc.usage = usage;
- // desc.label = std::to_string(xx++).c_str();
buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();
ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), ".");
@@ -320,14 +359,16 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) {
}
WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) {
- ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must have storage usage.");
- auto& cache = GetCacheManager(static_cast(usage));
+ ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer.");
+ auto& cache = GetCacheManager(usage);
auto buffer_size = cache.CalculateBufferSize(size);
+ // Ensure the buffer is mapped for writing at creation.
+ usage |= wgpu::BufferUsage::MapWrite;
+
wgpu::BufferDescriptor desc{};
desc.size = buffer_size;
- // Ensure the buffer is mapped for writing at creation.
- desc.usage = usage | wgpu::BufferUsage::MapWrite;
+ desc.usage = usage;
desc.mappedAtCreation = true;
auto buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();
@@ -373,12 +414,12 @@ void BufferManager::RefreshPendingBuffers() {
default_cache_->OnRefresh();
}
-IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const {
- if (usage & WGPUBufferUsage_Storage) {
+IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const {
+ if (usage & wgpu::BufferUsage::Storage) {
return *storage_cache_;
- } else if (usage & WGPUBufferUsage_Uniform) {
+ } else if (usage & wgpu::BufferUsage::Uniform) {
return *uniform_cache_;
- } else if (usage & WGPUBufferUsage_QueryResolve) {
+ } else if (usage & wgpu::BufferUsage::QueryResolve) {
return *query_resolve_cache_;
} else {
return *default_cache_;
@@ -386,7 +427,8 @@ IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const
}
IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const {
- return GetCacheManager(wgpuBufferGetUsage(buffer));
+ auto usage = static_cast(wgpuBufferGetUsage(buffer));
+ return GetCacheManager(usage);
}
std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) {
diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h
index 6a8ebdd60a1ec..b9028ad5de858 100644
--- a/onnxruntime/core/providers/webgpu/buffer_manager.h
+++ b/onnxruntime/core/providers/webgpu/buffer_manager.h
@@ -70,7 +70,7 @@ class BufferManager {
void RefreshPendingBuffers();
private:
- IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const;
+ IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const;
IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const;
WebGpuContext& context_;
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
index 9b447d5fdb59a..cdd3909874e7f 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -6,8 +6,9 @@
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
-
+#include "core/providers/webgpu/nn/fuse_utils.h"
#include "core/providers/webgpu/data_transfer.h"
+
namespace onnxruntime {
namespace webgpu {
@@ -54,11 +55,12 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
std::string process_bias;
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
- process_bias = "value += output_value_t(bias[row + i]);";
+ process_bias = is_channels_last_ ? "value += output_value_t(bias[col])" : "value += output_value_t(bias[row + i]);";
}
+ std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
- ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+ ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& batch_dims = shader.AddIndices("batch_dims");
int a_components = a.NumComponents();
@@ -90,6 +92,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "for (var i = 0u; i < " << output_number_ << "u; i++) {\n"
<< " var value = values[i];\n"
<< process_bias << "\n"
+ << apply_activation << "\n"
<< " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n"
<< " let offset = " << output.IndicesToOffset("cur_indices") << ";\n"
<< output.SetByOffset("offset", "value")
@@ -127,7 +130,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});
- MatMulNaiveProgram program{output_rank, output_number, has_bias};
+ MatMulNaiveProgram program{Activation(), output_rank, output_number, has_bias};
program
.CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
@@ -147,11 +150,32 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
return context.RunProgram(program);
}
- int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
- int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);
+ std::vector inputs(has_bias ? 3 : 2);
+ inputs[0] = a;
+ inputs[1] = b;
+ if (has_bias) {
+ const auto* bias = context.Input(2);
+ inputs.push_back(bias);
+ }
+ auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);
+
+ return context.RunProgram(program);
+}
+
+MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last,
+ const TensorShape& input_a_reshape,
+ const TensorShape& input_b_reshape) {
+ const auto* a = inputs[0];
+ const auto* b = inputs[1];
+ bool has_bias = inputs.size() > 2;
+ TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape();
+ TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape();
+
+ MatMulComputeHelper helper;
+ ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape));
+ int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2);
+ int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2);
- TensorShape a_shape = a->Shape();
- TensorShape b_shape = b->Shape();
TensorShape output_shape = helper.OutputShape();
const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
@@ -184,9 +208,9 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const int64_t batch_size = outer_dims.Size();
// Get dimensions for matrix multiplication from TensorShape
- const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension
- const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension
- const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension
+ const uint32_t dim_a_outer = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension
+ const uint32_t dim_inner = narrow<uint32_t>(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension
+ const uint32_t dim_b_outer = narrow<uint32_t>(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension
const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;
@@ -194,34 +218,36 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
? InlinedVector({4, 1, 1})
: InlinedVector({4, 4, 1});
- const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
- (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
- const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
- (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
- const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
- (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+ const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
+ (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
+ const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
+ (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
+ const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+ (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
const int components = is_vec4 ? 4 : 1;
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});
- MatMulProgram program{has_bias, is_vec4, elements_per_thread};
+ MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last};
program
- .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
+ .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
- .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);
+ .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z);
if (has_bias) {
- const auto* bias = context.Input<Tensor>(2);
- program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
+ auto bias_components = is_channels_last ? components : 1;
+ const auto* bias = inputs[2];
+ TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
+ program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components});
}
- return context.RunProgram(program);
+ return program;
}
} // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h
index 789e824383189..91216d8e25eec 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul.h
@@ -9,16 +9,20 @@
#include "core/providers/webgpu/math/matmul_utils.h"
#include "core/providers/webgpu/math/matmul_packed.h"
#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
namespace onnxruntime {
namespace webgpu {
+MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
+ const TensorShape& input_a_reshape = TensorShape(),
+ const TensorShape& input_b_reshape = TensorShape());
+
class MatMul final : public WebGpuKernel {
public:
MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}
Status ComputeInternal(ComputeContext& context) const override;
-
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1;
@@ -26,8 +30,8 @@ class MatMul final : public WebGpuKernel {
class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
public:
- MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
- : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
+ MatMulNaiveProgram(const Activation& activation, const size_t output_rank, int64_t output_number, bool has_bias, bool is_channels_last = false)
+ : Program{"MatMulNaive"}, activation_(activation), output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias}, is_channels_last_(is_channels_last) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -38,9 +42,11 @@ class MatMulNaiveProgram final : public Program {
{"K", ProgramUniformVariableDataType::Uint32});
private:
+ const Activation& activation_;
const size_t output_rank_;
const int64_t output_number_;
const bool has_bias_;
+ const bool is_channels_last_;
};
} // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
index 2e5cff923f442..36510eec0cd3b 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -5,7 +5,7 @@
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
-
+#include <string>
namespace onnxruntime {
namespace webgpu {
@@ -13,7 +13,8 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader,
const ShaderVariableHelper& a,
const ShaderVariableHelper& b,
const ShaderVariableHelper& output,
- const ShaderIndicesHelper& batch_dims) const {
+ const ShaderIndicesHelper& batch_dims,
+ std::string activation_snippet) const {
int components = is_vec4_ ? 4 : 1;
const std::string data_type = "a_element_t";
const std::string type_string = MakeScalarOrVectorType(components, data_type);
@@ -23,7 +24,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader,
<< "fn mm_readA(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n"
<< " var value = " << type_string << "(0.0);\n"
<< " let col = colIn * " << components << ";\n"
- << " if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) {\n"
+ << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n"
<< " var a_indices: a_indices_t;\n"
<< ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices")
<< a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
@@ -38,7 +39,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader,
<< "fn mm_readB(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n"
<< " var value = " << type_string << "(0.0);\n"
<< " let col = colIn * " << components << ";\n"
- << " if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) {\n"
+ << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n"
<< " var b_indices: b_indices_t;\n"
<< ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices")
<< b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
@@ -52,14 +53,16 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader,
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: " << type_string << ") {\n"
<< " let col = colIn * " << components << ";\n"
- << " if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {\n"
+ << " if (row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n"
<< " var value = valueIn;\n"
<< " let coords = vec3(batch, row, colIn);\n";
if (has_bias_) {
- shader.AdditionalImplementation() << " value = value + " << type_string << "(bias[row]);\n";
+ shader.AdditionalImplementation() << " value = value + " << (is_channels_last_ ? "bias[colIn]" : type_string + "(bias[row])") << ";\n";
}
+ shader.AdditionalImplementation() << " " << activation_snippet << "\n";
+
shader.AdditionalImplementation()
<< output.SetByIndices("vec3(coords)", "value") << "\n"
<< " }\n"
@@ -67,29 +70,36 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader,
}
Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
- const ShaderIndicesHelper& batch_dims,
const InlinedVector& elements_per_thread,
uint32_t workgroup_size_x,
- uint32_t workgroup_size_y) {
+ uint32_t workgroup_size_y,
+ const std::string& data_type,
+ const ShaderIndicesHelper* batch_dims,
+ bool transpose_a,
+ uint32_t tile_inner,
+ bool split_k,
+ uint32_t splitted_dim_inner) {
+ ORT_UNUSED_PARAMETER(split_k);
+ ORT_UNUSED_PARAMETER(splitted_dim_inner);
+ std::string write_data_to_sub_a_vec4_snippet =
+ transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart / innerElementSize + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n"
+ : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n";
// elements per thread
const auto elements_per_thread_x = elements_per_thread[0];
const auto elements_per_thread_y = elements_per_thread[1];
- const decltype(elements_per_thread_x) tile_inner = 32;
const auto tile_a_outer = workgroup_size_y * elements_per_thread_y;
const auto tile_b_outer = workgroup_size_x * elements_per_thread_x;
- const auto tile_a_width = tile_inner;
-
- const auto tile_a_height = tile_a_outer;
+ const auto tile_a_width = transpose_a ? tile_a_outer : tile_inner;
+ const auto tile_a_height = transpose_a ? tile_inner : tile_a_outer;
const auto inner_elements_size = tile_a_width / workgroup_size_x;
const auto row_per_thread_b = tile_inner / workgroup_size_y;
- const std::string data_type = "a_element_t";
-
- if (!((inner_elements_size == 3 || inner_elements_size == 4) &&
- tile_a_width % workgroup_size_x == 0 &&
- tile_inner % workgroup_size_y == 0 &&
- elements_per_thread_x == 4)) {
+ if (!((transpose_a && inner_elements_size == 4 && elements_per_thread[1] == 4) ||
+ (!transpose_a && (inner_elements_size == 3 || inner_elements_size == 4))) &&
+ tile_a_width % workgroup_size_x == 0 &&
+ tile_inner % workgroup_size_y == 0 &&
+ elements_per_thread_x == 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid matrix multiplication configuration inner_elements_size: ", inner_elements_size,
" must be 3 or 4. tile_a_width: ", tile_a_width, " must be divisible by WorkgroupSizeX: ",
@@ -112,7 +122,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< " let globalRow = i32(global_id.y) * rowPerThread;\n"
<< " let globalCol = i32(global_id.x);\n"
<< " let batch = i32(global_id.z);\n"
- << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n"
+ << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
@@ -121,14 +131,14 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
// Loop over shared dimension.
shader.MainFunctionBody()
<< " let tileRowB = localRow * " << row_per_thread_b << ";\n"
- << " for (var t = 0; t < num_tiles; t = t + 1) {\n";
+ << " for (var t = 0; t < i32(num_tiles); t = t + 1) {\n";
// Load one tile of A into local memory.
shader.MainFunctionBody()
<< " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
<< " let inputRow = tileRow + innerRow;\n"
<< " let inputCol = tileCol;\n"
- << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol, batchIndices);\n"
+ << " " << write_data_to_sub_a_vec4_snippet
<< " }\n";
// Load one tile of B into local memory.
@@ -136,7 +146,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
<< " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n"
<< " let inputRow = tileRowB + innerRow;\n"
<< " let inputCol = tileCol;\n"
- << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol, batchIndices);\n"
+ << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol" << (nullptr != batch_dims ? ", batchIndices" : "") << ");\n"
<< " }\n"
<< " kStart = kStart + tileInner;\n"
<< " workgroupBarrier();\n";
@@ -152,15 +162,29 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
shader.MainFunctionBody() << " let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];\n";
}
- shader.MainFunctionBody()
- << " for (var i = 0; i < rowPerThread; i = i + 1) {\n"
- << " let ACached = mm_Asub[tileRow + i][k];\n"
- << " acc[i] = BCached0 * ACached.x + acc[i];\n"
- << " acc[i] = BCached1 * ACached.y + acc[i];\n"
- << " acc[i] = BCached2 * ACached.z + acc[i];\n"
- << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n"
- << " }\n";
-
+ if (transpose_a) {
+ shader.MainFunctionBody()
+ << " let Acached0 = mm_Asub[k * innerElementSize][localRow];\n"
+ << " let Acached1 = mm_Asub[k * innerElementSize + 1][localRow];\n"
+ << " let Acached2 = mm_Asub[k * innerElementSize + 2][localRow];\n"
+ << (inner_elements_size == 3 ? "" : " let Acached3 = mm_Asub[k * innerElementSize + 3][localRow];\n")
+ << " for (var i = 0; i < rowPerThread; i = i + 1) {\n"
+ << " let ACached = mm_Asub[tileCol][i];\n"
+ << " acc[i] = BCached0 * ACached0[i] + acc[i];\n"
+ << " acc[i] = BCached1 * ACached1[i] + acc[i];\n"
+ << " acc[i] = BCached2 * ACached2[i] + acc[i];\n"
+ << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached3[i] + acc[i];") << "\n"
+ << " }\n";
+ } else {
+ shader.MainFunctionBody()
+ << " for (var i = 0; i < rowPerThread; i = i + 1) {\n"
+ << " let ACached = mm_Asub[tileRow + i][k];\n"
+ << " acc[i] = BCached0 * ACached.x + acc[i];\n"
+ << " acc[i] = BCached1 * ACached.y + acc[i];\n"
+ << " acc[i] = BCached2 * ACached.z + acc[i];\n"
+ << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n"
+ << " }\n";
+ }
shader.MainFunctionBody() << " workgroupBarrier();\n"
<< " }\n"; // main for loop
@@ -174,13 +198,22 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
return Status::OK();
}
-Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderIndicesHelper& batch_dims,
+Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
const InlinedVector& elements_per_thread,
uint32_t workgroup_size_x,
- uint32_t workgroup_size_y) {
+ uint32_t workgroup_size_y,
+ const std::string& data_type,
+ const ShaderIndicesHelper* batch_dims,
+ bool transpose_a,
+ uint32_t tile_inner,
+ bool split_k,
+ uint32_t splitted_dim_inner,
+ bool sequentially_access_by_threads) {
+ ORT_UNUSED_PARAMETER(split_k);
+ ORT_UNUSED_PARAMETER(splitted_dim_inner);
+
const auto elements_per_thread_x = elements_per_thread[0];
const auto elements_per_thread_y = elements_per_thread[1];
- const decltype(elements_per_thread_x) tile_inner = 32;
const auto tile_a_outer = workgroup_size_y * elements_per_thread_y;
const auto tile_b_outer = workgroup_size_x * elements_per_thread_x;
@@ -194,12 +227,11 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI
", tile_inner: ", tile_inner, " must be divisible by WorkgroupSizeY: ", workgroup_size_y);
}
- const std::string data_type = "a_element_t";
-
const auto row_per_thread_a = tile_a_height / workgroup_size_y;
const auto col_per_thread_a = tile_a_width / workgroup_size_x;
const auto row_per_thread_b = tile_inner / workgroup_size_y;
-
+ std::string write_data_to_sub_a_snippet = transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n"
+ : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n";
shader.AdditionalImplementation()
<< "var mm_Asub: array, " << tile_a_height << ">;\n"
<< "var mm_Bsub: array, " << tile_inner << ">;\n"
@@ -208,93 +240,142 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI
<< "const tileInner = " << tile_inner << ";\n";
shader.MainFunctionBody() << " let batch = i32(global_id.z);\n"
- << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n"
+ << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
<< " var kStart = 0;\n"
<< " var acc: array, rowPerThread>;\n";
- shader.MainFunctionBody()
- << "let tileRow = i32(local_id.y) * rowPerThread;\n"
- << "let tileCol = i32(local_id.x) * colPerThread;\n"
- << "let globalRow = i32(global_id.y) * rowPerThread;\n"
- << "let globalCol = i32(global_id.x) * colPerThread;\n"
- << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
- << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n"
- << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n"
- << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n";
-
- // Loop over shared dimension.
- shader.MainFunctionBody()
- << "for (var t = 0; t < num_tiles; t = t + 1) {\n";
-
- // Load one tile of A into local memory.
- shader.MainFunctionBody()
- << " for (var innerRow = 0; innerRow < " << row_per_thread_a << "; innerRow = innerRow + 1) {\n"
- << " for (var innerCol = 0; innerCol < " << col_per_thread_a << "; innerCol = innerCol + 1) {\n"
- << " let inputRow = tileRowA + innerRow;\n"
- << " let inputCol = tileColA + innerCol;\n"
- << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol, batchIndices);\n"
- << " }\n"
- << " }\n";
-
- // Load one tile of B into local memory.
- shader.MainFunctionBody()
- << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n"
- << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
- << " let inputRow = tileRowB + innerRow;\n"
- << " let inputCol = tileCol + innerCol;\n"
- << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol, batchIndices);\n"
- << " }\n"
- << " }\n"
- << " kStart = kStart + tileInner;\n"
- << " workgroupBarrier();\n";
-
- // Compute acc values for a single thread.
- shader.MainFunctionBody()
- << "var BCached: array<" << data_type << ", colPerThread>;\n"
- << " for (var k = 0; k < tileInner; k = k + 1) {\n"
- << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n"
- << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n"
- << " }\n"
- << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
- << " let ACached = mm_Asub[tileRow + innerRow][k];\n"
- << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
- << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n"
- << " }\n"
- << " }\n"
- << " }\n"
- << " workgroupBarrier();\n"
- << "}\n";
-
- // Write the results to the output buffer
- shader.MainFunctionBody()
- << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
- << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
- << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n"
- << " }\n"
- << "}\n";
-
+ if (sequentially_access_by_threads) {
+ shader.MainFunctionBody() << "let localRow = i32(local_id.y);\n"
+ << "let localCol = i32(local_id.x);\n"
+ << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
+ << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n"
+ << "\n"
+ << "// Loop over shared dimension.\n"
+ << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n"
+ << " // Load one tile of A into local memory.\n"
+ << " for (var inputRow = localRow; inputRow < " << tile_a_height << "; inputRow = inputRow + " << workgroup_size_y << ") {\n"
+ << " for (var inputCol = localCol; inputCol < " << tile_a_width << "; inputCol = inputCol + " << workgroup_size_x << ") {\n"
+ << " " << write_data_to_sub_a_snippet << "\n"
+ << " }\n"
+ << " }\n"
+ << " // Load one tile of B into local memory.\n"
+ << " for (var inputRow = localRow; inputRow < " << tile_inner << "; inputRow = inputRow + " << workgroup_size_y << ") {\n"
+ << " for (var inputCol = localCol; inputCol < " << tile_b_outer << "; inputCol = inputCol + " << workgroup_size_x << ") {\n"
+ << " mm_Bsub[inputRow][inputCol] = mm_readB(batch,\n"
+ << " kStart + inputRow,\n"
+ << " globalColStart + inputCol" << (batch_dims ? ", batchIndices" : "") << ");\n "
+ << " }\n"
+ << " }\n"
+ << " kStart = kStart + tileInner;\n"
+ << " workgroupBarrier();\n"
+ << "\n"
+ << " // Compute acc values for a single thread.\n"
+ << " var BCached : array<" << data_type << ", colPerThread>;\n"
+ << " for (var k = 0; k < tileInner; k = k + 1) {\n"
+ << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n"
+ << " BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << "];\n"
+ << " }\n"
+ << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
+ << " let ACached = " << (transpose_a ? "mm_Asub[k][localCol + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n"
+ << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
+ << " acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n"
+ << " ACached * BCached[innerCol];\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << " workgroupBarrier();\n"
+ << "}\n"
+ << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
+ << " let gRow = globalRowStart + localRow + innerRow * " << workgroup_size_y << ";\n"
+ << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
+ << " let gCol = globalColStart + localCol + innerCol * " << workgroup_size_x << ";\n"
+ << " mm_write(batch, gRow, gCol, acc[innerRow][innerCol]);\n"
+ << " }\n"
+ << "}\n";
+ } else {
+ shader.MainFunctionBody()
+ << "let tileRow = i32(local_id.y) * rowPerThread;\n"
+ << "let tileCol = i32(local_id.x) * colPerThread;\n"
+ << "let globalRow = i32(global_id.y) * rowPerThread;\n"
+ << "let globalCol = i32(global_id.x) * colPerThread;\n"
+ << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n"
+ << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n"
+ << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n"
+ << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n";
+
+ // Loop over shared dimension.
+ shader.MainFunctionBody()
+ << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n";
+
+ // Load one tile of A into local memory.
+ shader.MainFunctionBody()
+ << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_a << "); innerRow = innerRow + 1) {\n"
+ << " for (var innerCol = 0; innerCol < i32(" << col_per_thread_a << "); innerCol = innerCol + 1) {\n"
+ << " let inputRow = tileRowA + innerRow;\n"
+ << " let inputCol = tileColA + innerCol;\n"
+ << " " << write_data_to_sub_a_snippet << "\n"
+ << " }\n"
+ << " }\n";
+
+ // Load one tile of B into local memory.
+ shader.MainFunctionBody()
+ << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_b << "); innerRow = innerRow + 1) {\n"
+ << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n"
+ << " let inputRow = tileRowB + innerRow;\n"
+ << " let inputCol = tileCol + innerCol;\n"
+ << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol" << (nullptr != batch_dims ? ", batchIndices" : "") << ");\n"
+ << " }\n"
+ << " }\n"
+ << " kStart = kStart + tileInner;\n"
+ << " workgroupBarrier();\n";
+
+ // Compute acc values for a single thread.
+ shader.MainFunctionBody()
+ << "var BCached: array<" << data_type << ", colPerThread>;\n"
+ << " for (var k = 0; k < tileInner; k = k + 1) {\n"
+ << " for (var inner = 0; inner < i32(colPerThread); inner = inner + 1) {\n"
+ << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n"
+ << " }\n"
+ << " for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n"
+ << " let ACached = mm_Asub[tileRow + innerRow][k];\n"
+ << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n"
+ << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n"
+ << " }\n"
+ << " }\n"
+ << " }\n"
+ << " workgroupBarrier();\n"
+ << "}\n";
+
+ // Write the results to the output buffer
+ shader.MainFunctionBody()
+ << "for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n"
+ << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n"
+ << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n"
+ << " }\n"
+ << "}\n";
+ }
return Status::OK();
}
Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
- const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+ const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
-
+ std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
// declare the read and write functions
- MatMulReadWriteFnSource(shader, a, b, output, batch_dims);
-
+ MatMulReadWriteFnSource(shader, a, b, output, batch_dims, apply_activation);
+ std::string data_type = "a_element_t";
// generate the main function
if (is_vec4_) {
- ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY()));
+ ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims));
} else {
- ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY()));
+ ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims));
}
return Status::OK();
}
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
index ea76468944066..d3a68ff8a57fa 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
@@ -7,38 +7,53 @@
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/math/matmul_utils.h"
+#include "core/providers/webgpu/nn/fuse_utils.h"
namespace onnxruntime {
namespace webgpu {
class MatMulProgram final : public Program<MatMulProgram> {
public:
- MatMulProgram(bool bias, bool is_vec4, const gsl::span<const int64_t>& elements_per_thread) : Program{"MatMul"},
- has_bias_{bias},
- is_vec4_{is_vec4},
- elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {}
+ MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span<const int64_t>& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"},
+ activation_(activation),
+ has_bias_{bias},
+ is_vec4_{is_vec4},
+ elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()),
+ is_channels_last_(is_channels_last) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
- WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Int32},
- {"dim_b_outer", ProgramUniformVariableDataType::Int32},
- {"dim_inner", ProgramUniformVariableDataType::Int32});
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
+ {"dim_b_outer", ProgramUniformVariableDataType::Uint32},
+ {"dim_inner", ProgramUniformVariableDataType::Uint32});
static Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
- const ShaderIndicesHelper& batch_dims,
const InlinedVector& elements_per_thread,
uint32_t workgroup_size_x,
- uint32_t workgroup_size_y);
+ uint32_t workgroup_size_y,
+ const std::string& data_type,
+ const ShaderIndicesHelper* batch_dims,
+ bool transpose_a = false,
+ uint32_t tile_inner = 32,
+ bool split_k = false,
+ uint32_t splitted_dim_inner = 32);
static Status MakeMatMulPackedSource(ShaderHelper& shader,
- const ShaderIndicesHelper& batch_dims,
const InlinedVector& elements_per_thread,
uint32_t workgroup_size_x,
- uint32_t workgroup_size_y);
+ uint32_t workgroup_size_y,
+ const std::string& data_type,
+ const ShaderIndicesHelper* batch_dims,
+ bool transpose_a = false,
+ uint32_t tile_inner = 32,
+ bool split_k = false,
+ uint32_t splitted_dim_inner = 32,
+ bool sequentially_access_by_threads = false);
private:
+ const Activation& activation_;
const bool has_bias_;
const bool is_vec4_;
const InlinedVector elements_per_thread_;
-
- void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims) const;
+ bool is_channels_last_ = false;
+ void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims, std::string apply_activation) const;
};
} // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.cc b/onnxruntime/core/providers/webgpu/nn/activation_util.cc
new file mode 100644
index 0000000000000..b5c31d98cda93
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/activation_util.cc
@@ -0,0 +1,25 @@
+#include "core/providers/webgpu/nn/activation_util.h"
+#include "core/common/common.h"
+namespace onnxruntime {
+namespace webgpu {
+std::string TypeSnippet(uint32_t component, std::string data_type) {
+ switch (component) {
+ case 1:
+ return data_type;
+ case 2:
+ return "vec2<" + data_type + ">";
+ case 3:
+ return "vec3<" + data_type + ">";
+ case 4:
+ return "vec4<" + data_type + ">";
+ default:
+ ORT_THROW("Component ", component, " is not supported.");
+ }
+}
+
+std::string BiasSnippet(bool has_bias) {
+ return has_bias ? "value = value + getBiasByOutputCoords(coords);" : "";
+}
+
+} // namespace webgpu
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.h b/onnxruntime/core/providers/webgpu/nn/activation_util.h
new file mode 100644
index 0000000000000..1c9fd93e35384
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/activation_util.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+
+namespace onnxruntime {
+namespace webgpu {
+
+extern std::string TypeSnippet(uint32_t component, std::string data_type);
+extern std::string BiasSnippet(bool has_bias);
+
+} // namespace webgpu
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc
new file mode 100644
index 0000000000000..0edad3eebe2ea
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -0,0 +1,273 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "core/providers/webgpu/nn/conv.h"
+#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/tensor/transpose.h"
+#include "core/providers/webgpu/nn/grouped_conv.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/math/matmul.h"
+namespace onnxruntime {
+namespace webgpu {
+
+Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector<size_t>& perm) {
+ // Transpose weights
+ auto rank = kernel_shape.NumDimensions();
+ TensorShapeVector transposed_kernel_shape_vector(rank);
+ for (size_t i = 0; i < rank; ++i) {
+ transposed_kernel_shape_vector[i] = kernel_shape[perm[i]];
+ }
+ uint32_t output_size = onnxruntime::narrow<uint32_t>(kernel_shape.Size());
+ TensorShape transposed_kernel_shape(transposed_kernel_shape_vector);
+ *transposed_kernel = context.CreateGPUTensor(kernel->DataType(), transposed_kernel_shape);
+ bool use_shared = false;
+ TransposeProgram program{perm, use_shared};
+ program
+ .CacheHint(absl::StrJoin(perm, "-"))
+ .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1})
+ .AddOutput({transposed_kernel, ProgramTensorMetadataDependency::TypeAndRank})
+ .AddUniformVariable({output_size})
+ .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+ return context.RunProgram(program);
+}
+
+template <bool is_channels_last, bool is_fused>
+Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
+ bool has_bias = context.InputCount() > 2;
+ const auto* input = context.Input<Tensor>(0);
+ const auto* kernel = context.Input<Tensor>(1);
+ const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
+ TensorShape input_shape = input->Shape();
+ TensorShape kernel_shape = kernel->Shape();
+ ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end());
+ TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end());
+ TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end());
+ TensorShapeVector kernel_spacial_shape_vector;
+ ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(kernel_shape, kernel_spacial_shape_vector, false));
+ if (local_pads.empty()) {
+ local_pads.resize(kernel_spacial_shape_vector.size() * 2, 0);
+ }
+ if (local_dilations.empty()) {
+ local_dilations.resize(kernel_spacial_shape_vector.size(), 1);
+ }
+ if (local_strides.empty()) {
+ local_strides.resize(kernel_spacial_shape_vector.size(), 1);
+ }
+ TensorShapeVector input_shape_vector = input_shape.AsShapeVector();
+ auto batch = input_shape[0];
+ TensorShapeVector output_shape_vector = {batch};
+ TensorShape input_spacial_shape = is_channels_last ? TensorShape(TensorShapeVector(std::next(input_shape_vector.begin()), std::prev(input_shape_vector.end()))) : input_shape.Slice(2);
+ ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_spacial_shape, kernel_spacial_shape_vector, local_strides, local_dilations, local_pads, output_shape_vector));
+ auto output_channels = kernel_shape[0];
+ if (is_channels_last) {
+ output_shape_vector.push_back(output_channels);
+ } else {
+ output_shape_vector.insert(output_shape_vector.begin() + 1, output_channels);
+ }
+ auto output_shape = TensorShape(output_shape_vector);
+ auto* output = context.Output(0, output_shape);
+ std::vector<uint32_t> strides;
+ std::vector<uint32_t> pads;
+ std::vector<uint32_t> dilations;
+ auto transform_dim = [](int64_t dim) { return static_cast<uint32_t>(dim); };
+ std::transform(local_pads.begin(), local_pads.end(), std::back_inserter(pads), transform_dim);
+ std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim);
+ std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim);
+ auto rank = input_shape.NumDimensions();
+ const InlinedVector<size_t> perm = {2, 3, 1, 0};
+ if (rank > 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported.");
+ } else if (rank == 4) {
+ // Conv2D
+ } else if (rank == 3) {
+ // Conv1D
+ TensorShapeVector kernel_shape_vector = kernel_shape.AsShapeVector();
+ input_shape_vector.insert(input_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1);
+ output_shape_vector.insert(output_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1);
+ kernel_shape_vector.insert(kernel_shape_vector.begin() + 2, 1);
+ input_shape = TensorShape(input_shape_vector);
+ kernel_shape = TensorShape(kernel_shape_vector);
+ pads.insert(pads.begin(), 0);
+ pads.insert(pads.begin() + 2, 0);
+ strides.insert(strides.begin(), 1);
+ dilations.insert(dilations.begin(), 1);
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input and kernel tensors must have at least 3 dimensions");
+ }
+ std::vector<const Tensor*> inputs(has_bias ? 3 : 2);
+ inputs[0] = input;
+ inputs[1] = kernel;
+ if (has_bias) {
+ inputs[2] = bias;
+ }
+ std::vector<TensorShape> modified_input_output_shapes = {input_shape, kernel_shape};
+ if (has_bias) {
+ modified_input_output_shapes.push_back(bias->Shape());
+ }
+ modified_input_output_shapes.push_back(TensorShape(output_shape_vector));
+ uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0;
+ auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2;
+ auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2;
+ std::vector updated_pads{pad0, pad1};
+ if (conv_attrs_.group > 1) {
+ Tensor transposed_kernel;
+ if (is_channels_last) {
+ ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+ inputs[1] = &transposed_kernel;
+ modified_input_output_shapes[1] = transposed_kernel.Shape();
+ }
+ auto output_channels_per_group = output_channels / conv_attrs_.group;
+ auto components = static_cast(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
+ auto output_size = output_shape.Size() / components;
+ GroupedConvProgram program(activation_, has_bias, is_channels_last);
+ auto reduced_kernel_shape = ReduceShapeByComponents(modified_input_output_shapes[1], components);
+ auto reduced_output_shape = ReduceShapeByComponents(modified_input_output_shapes[has_bias ? 3 : 2], components);
+ program.CacheHint(activation_.ToString(), std::to_string(components), std::to_string(is_channels_last))
+ .AddInput({inputs[0], ProgramTensorMetadataDependency::TypeAndRank, modified_input_output_shapes[0], 1})
+ .AddInput({inputs[1], ProgramTensorMetadataDependency::TypeAndRank, reduced_kernel_shape, components})
+ .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components})
+ .AddUniformVariables({{static_cast<uint32_t>(output_size)}, {dilations}, {strides}, {updated_pads}, {static_cast<uint32_t>(output_channels_per_group)}, {static_cast<uint32_t>(components)}})
+ .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
+ if (has_bias) {
+ auto reduced_bias_shape = ReduceShapeByComponents(modified_input_output_shapes[2], components);
+ program.AddInput({inputs[2], ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components});
+ }
+ return context.RunProgram(program);
+ }
+ const auto input_height = input_shape[is_channels_last ? 1 : 2];
+ const auto input_width = input_shape[is_channels_last ? 2 : 3];
+ const auto input_channels = input_shape[is_channels_last ? 3 : 1];
+ const auto kernel_height = kernel_shape[2];
+ const auto kernel_width = kernel_shape[3];
+ const auto output_height = output_shape_vector[is_channels_last ? 1 : 2];
+ const auto output_width = output_shape_vector[is_channels_last ? 2 : 3];
+
+ const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0;
+ if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) {
+ Tensor transposed_kernel;
+ TensorShape input_reshape;
+ TensorShape kernel_reshape;
+ TensorShape matmul_output_shape;
+ std::vector<const Tensor*> matmul_inputs;
+ std::vector<TensorShape> matmul_input_reshapes;
+ if (is_channels_last) {
+ // Transpose weights
+
+ ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+ inputs[1] = &transposed_kernel;
+ if (same_size) {
+ const auto shared_dim = input_height * input_width * input_channels;
+ input_reshape = TensorShape({1, batch, shared_dim});
+ kernel_reshape = TensorShape({1, shared_dim, output_channels});
+ matmul_output_shape = TensorShape({1, batch, output_channels});
+ } else {
+ input_reshape = TensorShape({batch, input_height * input_width, input_channels});
+ kernel_reshape = TensorShape({1, input_channels, output_channels});
+ matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
+ }
+ matmul_inputs.push_back(input);
+ matmul_inputs.push_back(&transposed_kernel);
+ matmul_input_reshapes.push_back(input_reshape);
+ matmul_input_reshapes.push_back(kernel_reshape);
+ } else {
+ input_reshape = TensorShape({batch, input_channels, input_height * input_width});
+ kernel_reshape = TensorShape({1, output_channels, input_channels});
+ matmul_output_shape = TensorShape({batch, output_channels, output_height * output_width});
+ matmul_inputs.push_back(kernel);
+ matmul_inputs.push_back(input);
+ matmul_input_reshapes.push_back(kernel_reshape);
+ matmul_input_reshapes.push_back(input_reshape);
+ }
+ if (has_bias) {
+ matmul_inputs.push_back(bias);
+ }
+ auto N = matmul_output_shape[2];
+ auto matmul_first_input_numdims = matmul_input_reshapes[0].NumDimensions();
+ auto K = matmul_input_reshapes[0].GetDims()[matmul_first_input_numdims - 1];
+ if (N < 8 && K < 8) {
+ const auto components = GetMaxComponents(N);
+ const auto a_components = GetMaxComponents(K);
+ const auto output_number = GetMaxComponents(output_shape[1]);
+ uint32_t output_size = static_cast<uint32_t>(output_shape.Size() / components / output_number);
+ const size_t output_rank = matmul_output_shape.NumDimensions();
+ TensorShape outer_dims = output_rank > 2 ? matmul_output_shape.Slice(0, output_rank - 2) : TensorShape({});
+ MatMulNaiveProgram program(activation_, output_rank, output_number, has_bias);
+ program
+ .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
+ .AddInputs({{matmul_inputs[0], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[0], a_components), int(a_components)},
+ {matmul_inputs[1], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[1], components), int(components)}});
+ if (has_bias) {
+ program.AddInput({bias, ProgramTensorMetadataDependency::Rank, bias->Shape(), components});
+ }
+ program
+ .AddOutputs({{output, ProgramTensorMetadataDependency::None, ReduceShapeByComponents(matmul_output_shape, components), int(components)}})
+ .SetDispatchGroupSize(static_cast<uint32_t>((output_size + 63) / 64))
+ .AddIndices(outer_dims)
+ .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast