diff --git a/README.md b/README.md index f77dde105..053453c73 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,8 @@ The `--no-interactive` flag allows the CLI to accept commands via stdin and exit - `CUDA`: Only supports >= CUDAv12 and might need to `sudo apt install libcudnn9-dev-cuda-12` - `CoreML (Apple)`: Not advised for NLP models due to often large dimensionality. +- `ROCm` (AMD GPUs, Linux): Requires a ROCm runtime on the host (`rocm-hip-runtime`, `rocm-libs`) matching the ONNX Runtime build. Supported AMD Instinct accelerators are listed in the [ROCm compatibility matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html); select Radeon cards are also supported. When unsupported, ORT falls back to CPU. **Note:** upstream ONNX Runtime removed the ROCm execution provider in release 1.23. Newer ahnlich builds that bump the ORT pin should use `MIGraphX` instead. +- `MIGraphX` (AMD GPUs, Linux): AMD's recommended replacement for the ROCm provider in `onnxruntime >= 1.23`. Requires the `migraphx` runtime to be installed (ships in AMD's ROCm apt repository). ### Contributing diff --git a/ahnlich/ai/src/engine/ai/providers/ort/mod.rs b/ahnlich/ai/src/engine/ai/providers/ort/mod.rs index 932ff05b1..47ff2d1f4 100644 --- a/ahnlich/ai/src/engine/ai/providers/ort/mod.rs +++ b/ahnlich/ai/src/engine/ai/providers/ort/mod.rs @@ -18,7 +18,8 @@ use executor::ExecutorWithSessionCache; use hf_hub::{Cache, api::sync::ApiBuilder}; use ort::{ CUDAExecutionProvider, CoreMLExecutionProvider, DirectMLExecutionProvider, ExecutionProvider, - SessionBuilder, SessionOutputs, TensorRTExecutionProvider, + MIGraphXExecutionProvider, ROCmExecutionProvider, SessionBuilder, SessionOutputs, + TensorRTExecutionProvider, }; use strum::EnumIter; @@ -49,6 +50,8 @@ pub(crate) enum InnerAIExecutionProvider { CUDA, DirectML, CoreML, + ROCm, + MIGraphX, #[default] CPU, } @@ -60,6 +63,8 @@ impl From for InnerAIExecutionProvider { AIExecutionProvider::Cuda => InnerAIExecutionProvider::CUDA, AIExecutionProvider::DirectMl => InnerAIExecutionProvider::DirectML, AIExecutionProvider::CoreMl => InnerAIExecutionProvider::CoreML, + AIExecutionProvider::Rocm => InnerAIExecutionProvider::ROCm, + AIExecutionProvider::Migraphx => InnerAIExecutionProvider::MIGraphX, } } } @@ -77,6 +82,10 @@ fn register_provider( DirectMLExecutionProvider::default().register(builder)? } InnerAIExecutionProvider::CoreML => CoreMLExecutionProvider::default().register(builder)?, + InnerAIExecutionProvider::ROCm => ROCmExecutionProvider::default().register(builder)?, + InnerAIExecutionProvider::MIGraphX => { + MIGraphXExecutionProvider::default().register(builder)? + } InnerAIExecutionProvider::CPU => (), }; Ok(()) diff --git a/ahnlich/dsl/src/ai.rs b/ahnlich/dsl/src/ai.rs index b08a13c21..41e7d317f 100644 --- a/ahnlich/dsl/src/ai.rs +++ b/ahnlich/dsl/src/ai.rs @@ -39,6 +39,8 @@ fn parse_to_execution_provider(input: &str) -> Result Ok(ExecutionProvider::CoreMl), "directml" => Ok(ExecutionProvider::DirectMl), "tensorrt" => Ok(ExecutionProvider::TensorRt), + "rocm" => Ok(ExecutionProvider::Rocm), + "migraphx" => Ok(ExecutionProvider::Migraphx), a => Err(DslError::UnsupportedPreprocessingMode(a.to_string())), } } diff --git a/ahnlich/dsl/src/syntax/syntax.pest b/ahnlich/dsl/src/syntax/syntax.pest index 03ce875fe..9d8ce9015 100644 --- a/ahnlich/dsl/src/syntax/syntax.pest +++ b/ahnlich/dsl/src/syntax/syntax.pest @@ -91,10 +91,12 @@ algorithm = { ^"cosinesimilarity" | "dotproductsimilarity" } -execution_provider = { +execution_provider = { ^"coreml" | ^"tensorrt" | ^"directml" | + ^"migraphx" | + ^"rocm" | "cuda" } non_linear_algorithms = { non_linear_algorithm ~ (whitespace* ~ "," ~ whitespace* ~ non_linear_algorithm)* } diff --git a/ahnlich/dsl/src/tests/ai.rs b/ahnlich/dsl/src/tests/ai.rs index 3e9980cf2..653d4a2cc 100644 --- a/ahnlich/dsl/src/tests/ai.rs +++ b/ahnlich/dsl/src/tests/ai.rs @@ -567,3 +567,44 @@ fn test_set_in_store_parse() { })] ); } + +#[test] +fn test_get_sim_n_parse_rocm_execution_provider() { + let input = r#"GETSIMN 3 with [find me] using cosinesimilarity executionprovider rocm in store1"#; + assert_eq!( + parse_ai_query(input).expect("Could not parse query input"), + vec![AiQuery::GetSimN(GetSimN { + store: "store1".to_string(), + search_input: Some(StoreInput { + value: Some(StoreValue::RawString("find me".to_string())) + }), + closest_n: 3, + algorithm: Algorithm::CosineSimilarity as i32, + condition: None, + preprocess_action: PreprocessAction::NoPreprocessing as i32, + execution_provider: Some(ExecutionProvider::Rocm as i32), + model_params: HashMap::new(), + })] + ); +} + +#[test] +fn test_get_sim_n_parse_migraphx_execution_provider() { + let input = + r#"GETSIMN 3 with [find me] using cosinesimilarity executionprovider migraphx in store1"#; + assert_eq!( + parse_ai_query(input).expect("Could not parse query input"), + vec![AiQuery::GetSimN(GetSimN { + store: "store1".to_string(), + search_input: Some(StoreInput { + value: Some(StoreValue::RawString("find me".to_string())) + }), + closest_n: 3, + algorithm: Algorithm::CosineSimilarity as i32, + condition: None, + preprocess_action: PreprocessAction::NoPreprocessing as i32, + execution_provider: Some(ExecutionProvider::Migraphx as i32), + model_params: HashMap::new(), + })] + ); +} diff --git a/ahnlich/types/src/ai/execution_provider.rs b/ahnlich/types/src/ai/execution_provider.rs index 8ec499968..d7aaf0530 100644 --- a/ahnlich/types/src/ai/execution_provider.rs +++ b/ahnlich/types/src/ai/execution_provider.rs @@ -7,6 +7,17 @@ pub enum ExecutionProvider { Cuda = 1, DirectMl = 2, CoreMl = 3, + /// ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + /// Radeon hardware). Requires the host to have a matching ROCm runtime + /// installed and the ort/onnxruntime build configured with ROCm support. + /// Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + /// ORT builds prefer MIGRAPHX. + Rocm = 4, + /// MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + /// hardware). AMD's recommended replacement for the ROCm provider in + /// onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + /// installed (ships in AMD's ROCm apt repository). + Migraphx = 5, } impl ExecutionProvider { /// String value of the enum field names used in the ProtoBuf definition. @@ -19,6 +30,8 @@ impl ExecutionProvider { Self::Cuda => "CUDA", Self::DirectMl => "DIRECT_ML", Self::CoreMl => "CORE_ML", + Self::Rocm => "ROCM", + Self::Migraphx => "MIGRAPHX", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -28,6 +41,8 @@ impl ExecutionProvider { "CUDA" => Some(Self::Cuda), "DIRECT_ML" => Some(Self::DirectMl), "CORE_ML" => Some(Self::CoreMl), + "ROCM" => Some(Self::Rocm), + "MIGRAPHX" => Some(Self::Migraphx), _ => None, } } diff --git a/protos/README.md b/protos/README.md index 5380e0b5c..4e61100d2 100644 --- a/protos/README.md +++ b/protos/README.md @@ -46,7 +46,7 @@ Defines server types and client connections. ### **AI & Algorithm Definitions** - `algorithm.proto`: Defines similarity algorithms (e.g., `CosineSimilarity`). - `ai/models.proto`: Defines AI models available for use. -- `ai/execution_provider.proto`: Specifies execution providers (e.g., `CUDA`, `TENSOR_RT`). +- `ai/execution_provider.proto`: Specifies execution providers (e.g., `CUDA`, `TENSOR_RT`, `ROCM`, `MIGRAPHX`). ## **Usage** To use these protofiles, generate language-specific gRPC stubs: diff --git a/protos/ai/execution_provider.proto b/protos/ai/execution_provider.proto index 18b52fe29..ccbe09a0b 100644 --- a/protos/ai/execution_provider.proto +++ b/protos/ai/execution_provider.proto @@ -10,4 +10,15 @@ enum ExecutionProvider { CUDA = 1; DIRECT_ML = 2; CORE_ML = 3; + // ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + // Radeon hardware). Requires the host to have a matching ROCm runtime + // installed and the ort/onnxruntime build configured with ROCm support. + // Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + // ORT builds prefer MIGRAPHX. + ROCM = 4; + // MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + // hardware). AMD's recommended replacement for the ROCm provider in + // onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + // installed (ships in AMD's ROCm apt repository). + MIGRAPHX = 5; } diff --git a/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go b/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go index 407d18760..7d1dc1fdc 100644 --- a/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go +++ b/sdk/ahnlich-client-go/grpc/ai/execution_provider/execution_provider.pb.go @@ -7,11 +7,10 @@ package execution_provider import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" ) const ( @@ -29,6 +28,17 @@ const ( ExecutionProvider_CUDA ExecutionProvider = 1 ExecutionProvider_DIRECT_ML ExecutionProvider = 2 ExecutionProvider_CORE_ML ExecutionProvider = 3 + // ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + // Radeon hardware). Requires the host to have a matching ROCm runtime + // installed and the ort/onnxruntime build configured with ROCm support. + // Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + // ORT builds prefer MIGRAPHX. + ExecutionProvider_ROCM ExecutionProvider = 4 + // MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + // hardware). AMD's recommended replacement for the ROCm provider in + // onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + // installed (ships in AMD's ROCm apt repository). + ExecutionProvider_MIGRAPHX ExecutionProvider = 5 ) // Enum value maps for ExecutionProvider. @@ -38,12 +48,16 @@ var ( 1: "CUDA", 2: "DIRECT_ML", 3: "CORE_ML", + 4: "ROCM", + 5: "MIGRAPHX", } ExecutionProvider_value = map[string]int32{ "TENSOR_RT": 0, "CUDA": 1, "DIRECT_ML": 2, "CORE_ML": 3, + "ROCM": 4, + "MIGRAPHX": 5, } ) @@ -80,18 +94,19 @@ var file_ai_execution_provider_proto_rawDesc = []byte{ 0x0a, 0x1b, 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x15, 0x61, 0x69, 0x2e, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x2a, 0x48, 0x0a, 0x11, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, + 0x69, 0x64, 0x65, 0x72, 0x2a, 0x60, 0x0a, 0x11, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x45, 0x4e, 0x53, 0x4f, 0x52, 0x5f, 0x52, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x43, 0x55, 0x44, 0x41, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x44, 0x49, 0x52, 0x45, 0x43, 0x54, 0x5f, 0x4d, 0x4c, 0x10, - 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x52, 0x45, 0x5f, 0x4d, 0x4c, 0x10, 0x03, 0x42, 0x60, - 0x5a, 0x5e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x76, - 0x65, 0x6e, 0x39, 0x36, 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, 0x2f, 0x73, 0x64, 0x6b, - 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, - 0x67, 0x6f, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, - 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x3b, 0x65, 0x78, - 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x52, 0x45, 0x5f, 0x4d, 0x4c, 0x10, 0x03, 0x12, 0x08, + 0x0a, 0x04, 0x52, 0x4f, 0x43, 0x4d, 0x10, 0x04, 0x12, 0x0c, 0x0a, 0x08, 0x4d, 0x49, 0x47, 0x52, + 0x41, 0x50, 0x48, 0x58, 0x10, 0x05, 0x42, 0x60, 0x5a, 0x5e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x65, 0x76, 0x65, 0x6e, 0x39, 0x36, 0x2f, 0x61, 0x68, 0x6e, + 0x6c, 0x69, 0x63, 0x68, 0x2f, 0x73, 0x64, 0x6b, 0x2f, 0x61, 0x68, 0x6e, 0x6c, 0x69, 0x63, 0x68, + 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, + 0x61, 0x69, 0x2f, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x70, 0x72, 0x6f, + 0x76, 0x69, 0x64, 0x65, 0x72, 0x3b, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x5f, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts b/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts index 43df2d1ed..658adc374 100644 --- a/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts +++ b/sdk/ahnlich-client-node/grpc/ai/execution_provider_pb.ts @@ -30,6 +30,27 @@ export enum ExecutionProvider { * @generated from enum value: CORE_ML = 3; */ CORE_ML = 3, + + /** + * ROCm execution provider for AMD GPUs (Linux + supported AMD Instinct / + * Radeon hardware). Requires the host to have a matching ROCm runtime + * installed and the ort/onnxruntime build configured with ROCm support. + * Note: upstream onnxruntime removed the ROCm provider in 1.23; for newer + * ORT builds prefer MIGRAPHX. + * + * @generated from enum value: ROCM = 4; + */ + ROCM = 4, + + /** + * MIGraphX execution provider for AMD GPUs (Linux + supported AMD Instinct + * hardware). AMD's recommended replacement for the ROCm provider in + * onnxruntime >= 1.23. Requires the host to have the MIGraphX runtime + * installed (ships in AMD's ROCm apt repository). + * + * @generated from enum value: MIGRAPHX = 5; + */ + MIGRAPHX = 5, } // Retrieve enum metadata with: proto3.getEnumType(ExecutionProvider) proto3.util.setEnumType(ExecutionProvider, "ai.execution_provider.ExecutionProvider", [ @@ -37,4 +58,7 @@ proto3.util.setEnumType(ExecutionProvider, "ai.execution_provider.ExecutionProvi { no: 1, name: "CUDA" }, { no: 2, name: "DIRECT_ML" }, { no: 3, name: "CORE_ML" }, + { no: 4, name: "ROCM" }, + { no: 5, name: "MIGRAPHX" }, ]); + diff --git a/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py b/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py index b9bb0d12e..90e8fdeba 100644 --- a/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py +++ b/sdk/ahnlich-client-py/ahnlich_client_py/grpc/ai/execution_provider/__init__.py @@ -15,3 +15,5 @@ class ExecutionProvider(betterproto.Enum): CUDA = 1 DIRECT_ML = 2 CORE_ML = 3 + ROCM = 4 + MIGRAPHX = 5