From e63d12e619b566f666e7af4c0666ede07a240970 Mon Sep 17 00:00:00 2001 From: Alano Terblanche <18033717+Benehiko@users.noreply.github.com> Date: Mon, 18 May 2026 09:51:35 +0200 Subject: [PATCH] feat(ai): add ROCm and MIGraphX execution providers for AMD GPUs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wires ort's ROCmExecutionProvider and MIGraphXExecutionProvider through the same path as the existing CUDA / TensorRT / DirectML / CoreML providers so ahnlich-ai can target AMD GPUs on Linux when the host has a matching ROCm or MIGraphX runtime and an ORT build that supports either provider. Two variants are included because upstream onnxruntime removed the ROCm execution provider in release 1.23 and recommends MIGraphX as its replacement. ahnlich currently pins ort to 2.0.0-rc.5 (against ORT 1.19, which still ships ROCm), so both variants stay useful until the ORT pin moves past 1.23. - protos/ai/execution_provider.proto: add ROCM = 4 and MIGRAPHX = 5 - ahnlich/types/src/ai/execution_provider.rs: regenerated Rust enum - ahnlich/ai/src/engine/ai/providers/ort/mod.rs: register ROCmExecutionProvider and MIGraphXExecutionProvider via InnerAIExecutionProvider::ROCm and ::MIGraphX - ahnlich/dsl/src/ai.rs: accept "rocm" and "migraphx" in parse_to_execution_provider - ahnlich/dsl/src/syntax/syntax.pest: extend the execution_provider rule to tokenise "rocm" and "migraphx" (without this the DSL rejects the new keywords at the parser layer before parse_to_execution_provider runs) - ahnlich/dsl/src/tests/ai.rs: add round-trip tests for "rocm" and "migraphx" through parse_ai_query (mirrors the existing TensorRT and CUDA tests) - sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go: regenerate via `buf generate` - sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts: regenerate via `buf generate` - sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py: add ROCM and MIGRAPHX variants (matches betterproto's emit; full `make grpc-update-python` regen left to the maintainer's environment) - README.md / protos/README.md: document ROCm and MIGraphX prerequisites and the ORT 1.23 ROCm removal Generated as a reference patch with Claude — not validated against AMD hardware. Verified locally with: - `cargo test -p dsl` (30 / 30 passing, +2 new tests) - `cargo check -p ai` against the actual libonnxruntime.so 1.19.0 bundled in the official ahnlich-ai image - end-to-end DSL smoke test via ahnlich-cli: the patched binary boots, accepts `executionprovider rocm` and `executionprovider migraphx` through the gRPC layer, and rejects unknown tokens at the parser - proto regen confirmed reproducible (build.rs round-trips execution_provider.rs cleanly, `buf generate` produces the same Go/Node stubs committed here) --- README.md | 2 + ahnlich/ai/src/engine/ai/providers/ort/mod.rs | 11 ++++- ahnlich/dsl/src/ai.rs | 2 + ahnlich/dsl/src/syntax/syntax.pest | 4 +- ahnlich/dsl/src/tests/ai.rs | 41 +++++++++++++++++++ ahnlich/types/src/ai/execution_provider.rs | 15 +++++++ protos/README.md | 2 +- protos/ai/execution_provider.proto | 11 +++++ .../execution_provider.pb.go | 39 ++++++++++++------ .../grpc/ai/execution_provider_pb.ts | 24 +++++++++++ .../grpc/ai/execution_provider/__init__.py | 2 + 11 files changed, 138 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index f77dde105..053453c73 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,8 @@ The `--no-interactive` flag allows the CLI to accept commands via stdin and exit - `CUDA`: Only supports >= CUDAv12 and might need to `sudo apt install libcudnn9-dev-cuda-12` - `CoreML (Apple)`: Not advised for NLP models due to often large dimensionality. +- `ROCm` (AMD GPUs, Linux): Requires a ROCm runtime on the host (`rocm-hip-runtime`, `rocm-libs`) matching the ONNX Runtime build. Supported AMD Instinct accelerators are listed in the [ROCm compatibility matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html); select Radeon cards are also supported. When unsupported, ORT falls back to CPU. **Note:** upstream ONNX Runtime removed the ROCm execution provider in release 1.23. Newer ahnlich builds that bump the ORT pin should use `MIGraphX` instead. +- `MIGraphX` (AMD GPUs, Linux): AMD's recommended replacement for the ROCm provider in `onnxruntime >= 1.23`. Requires the `migraphx` runtime to be installed (ships in AMD's ROCm apt repository). ### Contributing diff --git a/ahnlich/ai/src/engine/ai/providers/ort/mod.rs b/ahnlich/ai/src/engine/ai/providers/ort/mod.rs index 932ff05b1..47ff2d1f4 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort/mod.rs @@ -18,7 +18,8 @@ use executor::ExecutorWithSessionCache; use hf_hub::{Cache, api::sync::ApiBuilder}; use ort::{ CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider, ExecutionProvider, - SessionBuilder, SessionOutputs, TensorRTExecutionProvider, + MIGraphXExecutionProvider, ROCmExecutionProvider, SessionBuilder, SessionOutputs, + TensorRTExecutionProvider, }; use strum::EnumIter; @@ -49,6 +50,8 @@ pub(crate) enum InnerAIExecutionProvider { CUDA, DirectML, CoreML, + ROCm, + MIGraphX, #[default] CPU, } @@ -60,6 +63,8 @@ impl From for InnerAIExecutionProvider { AIExecutionProvider::Cuda => InnerAIExecutionProvider::CUDA, AIExecutionProvider::DirectMl => InnerAIExecutionProvider::DirectML, AIExecutionProvider::CoreMl => InnerAIExecutionProvider::CoreML, + AIExecutionProvider::Rocm => InnerAIExecutionProvider::ROCm, + AIExecutionProvider::Migraphx => InnerAIExecutionProvider::MIGraphX, } } } @@ -77,6 +82,10 @@ fn register_provider( DirectMLExecutionProvider::default().register(builder)? } InnerAIExecutionProvider::CoreML => CoreMLExecutionProvider::default().register(builder)?, + InnerAIExecutionProvider::ROCm => ROCmExecutionProvider::default().register(builder)?, + InnerAIExecutionProvider::MIGraphX => { + MIGraphXExecutionProvider::default().register(builder)? + } InnerAIExecutionProvider::CPU => (), }; Ok(()) diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index b08a13c21..41e7d317f 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -39,6 +39,8 @@ fn parse_to_execution_provider(input: &str) -> Result Ok(ExecutionProvider::CoreMl), "directml" => Ok(ExecutionProvider::DirectMl), "tensorrt" => Ok(ExecutionProvider::TensorRt), + "rocm" => Ok(ExecutionProvider::Rocm), + "migraphx" => Ok(ExecutionProvider::Migraphx), a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())), } } diff --git a/ahnlich/dsl/src/syntax/syntax.pest b/ahnlich/dsl/src/syntax/syntax.pest index 03ce875fe..9d8ce9015 100644 --- a/ahnlich/dsl/src/syntax/syntax.pest +++ b/ahnlich/dsl/src/syntax/syntax.pest @@ -91,10 +91,12 @@ algorithm = { ^"cosinesimilarity" | "dotproductsimilarity" } -execution_provider = { +execution_provider = { ^"coreml" | ^"tensorrt" | ^"directml" | + ^"migraphx" | + ^"rocm" | "cuda" } non_linear_algorithms = { non_linear_algorithm ~ (whitespace* ~ "," ~ whitespace* ~ non_linear_algorithm)* } diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index 3e9980cf2..653d4a2cc 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -567,3 +567,44 @@ fn test_set_in_store_parse() { })] ); } + +#[test] +fn test_get_sim_n_parse_rocm_execution_provider() { + let input = r#"GETSIMN 3 with [find me] using cosinesimilarity executionprovider rocm in store1"#; + assert_eq!( + parse_ai_query(input).expect("Could not parse query input"), + vec![AiQuery::GetSimN(GetSimN { + store: "store1".to_string(), + search_input: Some(StoreInput { + value: Some(StoreValue::RawString("find me".to_string())) + }), + closest_n: 3, + algorithm: Algorithm::CosineSimilarity as i32, + condition: None, + preprocess_action: PreprocessAction::NoPreprocessing as i32, + execution_provider: Some(ExecutionProvider::Rocm as i32), + model_params: HashMap::new(), + })] + ); +} + +#[test] +fn test_get_sim_n_parse_migraphx_execution_provider() { + let input = + r#"GETSIMN 3 with [find me] using cosinesimilarity executionprovider migraphx in store1"#; + assert_eq!( + parse_ai_query(input).expect("Could not parse query input"), + vec![AiQuery::GetSimN(GetSimN { + store: "store1".to_string(), + search_input: Some(StoreInput { + value: Some(StoreValue::RawString("find me".to_string())) + }), + closest_n: 3, + algorithm: Algorithm::CosineSimilarity as i32, + condition: None, + preprocess_action: PreprocessAction::NoPreprocessing as i32, + execution_provider: Some(ExecutionProvider::Migraphx as i32), + model_params: HashMap::new(), + })] + ); +} diff --git a/ahnlich/types/src/ai/execution_provider.rs b/ahnlich/types/src/ai/execution_provider.rs index 8ec499968..d7aaf0530 100644 --- a/ahnlich/types/src/ai/execution_provider.rs +++ b/ahnlich/types/src/ai/execution_provider.rs @@ -7,6 +7,17 @@ pub enum ExecutionProvider { Cuda = 1, DirectMl = 2, CoreMl = 3, + /// ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + /// Radeon hardware). Requires the host to have a matching ROCm runtime + /// installed and the ort/onnxruntime build configured with ROCm support. + /// Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + /// ORT builds prefer MIGRAPHX. + Rocm = 4, + /// MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + /// hardware). AMD's recommended replacement for the ROCm provider in + /// onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + /// installed (ships in AMD's ROCm apt repository). + Migraphx = 5, } impl ExecutionProvider { /// String value of the enum field names used in the ProtoBuf definition. @@ -19,6 +30,8 @@ impl ExecutionProvider { Self::Cuda => "CUDA", Self::DirectMl => "DIRECT_ML", Self::CoreMl => "CORE_ML", + Self::Rocm => "ROCM", + Self::Migraphx => "MIGRAPHX", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -28,6 +41,8 @@ impl ExecutionProvider { "CUDA" => Some(Self::Cuda), "DIRECT_ML" => Some(Self::DirectMl), "CORE_ML" => Some(Self::CoreMl), + "ROCM" => Some(Self::Rocm), + "MIGRAPHX" => Some(Self::Migraphx), _ => None, } } diff --git a/protos/README.md b/protos/README.md index 5380e0b5c..4e61100d2 100644 --- a/protos/README.md +++ b/protos/README.md @@ -46,7 +46,7 @@ Defines server types and client connections. ### **AI & Algorithm Definitions** - `algorithm.proto`: Defines similarity algorithms (e.g., `CosineSimilarity`). - `ai/models.proto`: Defines AI models available for use. -- `ai/execution_provider.proto`: Specifies execution providers (e.g., `CUDA`, `TENSOR_RT`). +- `ai/execution_provider.proto`: Specifies execution providers (e.g., `CUDA`, `TENSOR_RT`, `ROCM`, `MIGRAPHX`). ## **Usage** To use these protofiles, generate language-specific gRPC stubs: diff --git a/protos/ai/execution_provider.proto b/protos/ai/execution_provider.proto index 18b52fe29..ccbe09a0b 100644 --- a/protos/ai/execution_provider.proto +++ b/protos/ai/execution_provider.proto @@ -10,4 +10,15 @@ enum ExecutionProvider { CUDA = 1; DIRECT_ML = 2; CORE_ML = 3; + // ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + // Radeon hardware). Requires the host to have a matching ROCm runtime + // installed and the ort/onnxruntime build configured with ROCm support. + // Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + // ORT builds prefer MIGRAPHX. + ROCM = 4; + // MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + // hardware). AMD's recommended replacement for the ROCm provider in + // onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + // installed (ships in AMD's ROCm apt repository). + MIGRAPHX = 5; } diff --git a/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go b/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go index 407d18760..7d1dc1fdc 100644 --- a/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go +++ b/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go @@ -7,11 +7,10 @@ package execution_provider import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" ) const ( @@ -29,6 +28,17 @@ const ( ExecutionProvider_CUDA ExecutionProvider = 1 ExecutionProvider_DIRECT_ML ExecutionProvider = 2 ExecutionProvider_CORE_ML ExecutionProvider = 3 + // ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + // Radeon hardware). Requires the host to have a matching ROCm runtime + // installed and the ort/onnxruntime build configured with ROCm support. + // Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + // ORT builds prefer MIGRAPHX. + ExecutionProvider_ROCM ExecutionProvider = 4 + // MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + // hardware). AMD's recommended replacement for the ROCm provider in + // onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + // installed (ships in AMD's ROCm apt repository). + ExecutionProvider_MIGRAPHX ExecutionProvider = 5 ) // Enum value maps for ExecutionProvider. @@ -38,12 +48,16 @@ var ( 1: "CUDA", 2: "DIRECT_ML", 3: "CORE_ML", + 4: "ROCM", + 5: "MIGRAPHX", } ExecutionProvider_value = map[string]int32{ "TENSOR_RT": 0, "CUDA": 1, "DIRECT_ML": 2, "CORE_ML": 3, + "ROCM": 4, + "MIGRAPHX": 5, } ) @@ -80,18 +94,19 @@ var file_ai_execution_provider_proto_rawDesc = []byte{ 0x0a, 0x1b, 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x15, 0x61, 0x69, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x2a, 0x48, 0x0a, 0x11, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, + 0x69, 0x64, 0x65, 0x72, 0x2a, 0x60, 0x0a, 0x11, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x45, 0x4e, 0x53, 0x4f, 0x52, 0x5f, 0x52, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x43, 0x55, 0x44, 0x41, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x5f, 0x4d, 0x4c, 0x10, - 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x52, 0x45, 0x5f, 0x4d, 0x4c, 0x10, 0x03, 0x42, 0x60, - 0x5a, 0x5e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x76, - 0x65, 0x6e, 0x39, 0x36, 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, 0x2f, 0x73, 0x64, 0x6b, - 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, - 0x67, 0x6f, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, - 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x3b, 0x65, 0x78, - 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x52, 0x45, 0x5f, 0x4d, 0x4c, 0x10, 0x03, 0x12, 0x08, + 0x0a, 0x04, 0x52, 0x4f, 0x43, 0x4d, 0x10, 0x04, 0x12, 0x0c, 0x0a, 0x08, 0x4d, 0x49, 0x47, 0x52, + 0x41, 0x50, 0x48, 0x58, 0x10, 0x05, 0x42, 0x60, 0x5a, 0x5e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x76, 0x65, 0x6e, 0x39, 0x36, 0x2f, 0x61, 0x68, 0x6e, + 0x6c, 0x69, 0x63, 0x68, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, + 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, + 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, + 0x76, 0x69, 0x64, 0x65, 0x72, 0x3b, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts b/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts index 43df2d1ed..658adc374 100644 --- a/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts +++ b/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts @@ -30,6 +30,27 @@ export enum ExecutionProvider { * @generated from enum value: CORE_ML = 3; */ CORE_ML = 3, + + /** + * ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + * Radeon hardware). Requires the host to have a matching ROCm runtime + * installed and the ort/onnxruntime build configured with ROCm support. + * Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + * ORT builds prefer MIGRAPHX. + * + * @generated from enum value: ROCM = 4; + */ + ROCM = 4, + + /** + * MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + * hardware). AMD's recommended replacement for the ROCm provider in + * onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + * installed (ships in AMD's ROCm apt repository). + * + * @generated from enum value: MIGRAPHX = 5; + */ + MIGRAPHX = 5, } // Retrieve enum metadata with: proto3.getEnumType(ExecutionProvider) proto3.util.setEnumType(ExecutionProvider, "ai.execution_provider.ExecutionProvider", [ @@ -37,4 +58,7 @@ proto3.util.setEnumType(ExecutionProvider, "ai.execution_provider.ExecutionProvi { no: 1, name: "CUDA" }, { no: 2, name: "DIRECT_ML" }, { no: 3, name: "CORE_ML" }, + { no: 4, name: "ROCM" }, + { no: 5, name: "MIGRAPHX" }, ]); + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py index b9bb0d12e..90e8fdeba 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py @@ -15,3 +15,5 @@ class ExecutionProvider(betterproto.Enum): CUDA = 1 DIRECT_ML = 2 CORE_ML = 3 + ROCM = 4 + MIGRAPHX = 5