-
Notifications
You must be signed in to change notification settings - Fork 11.1k
feat: add AWS SigV4 auth for OpenAI-compatible model providers #17820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dc75dc1
a37ddcb
7e1074a
05673f5
36d8e60
2714225
245c651
26fa21e
2d220ab
c11484f
98df4aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| load("//:defs.bzl", "codex_rust_crate") | ||
|
|
||
| codex_rust_crate( | ||
| name = "aws-auth", | ||
| crate_name = "codex_aws_auth", | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| [package] | ||
| edition.workspace = true | ||
| license.workspace = true | ||
| name = "codex-aws-auth" | ||
| version.workspace = true | ||
|
|
||
| [lib] | ||
| doctest = false | ||
| name = "codex_aws_auth" | ||
| path = "src/lib.rs" | ||
|
|
||
| [lints] | ||
| workspace = true | ||
|
|
||
| [dependencies] | ||
| aws-config = { workspace = true } | ||
| aws-credential-types = { workspace = true } | ||
| aws-sigv4 = { workspace = true } | ||
| aws-types = { workspace = true } | ||
| bytes = { workspace = true } | ||
| http = { workspace = true } | ||
| thiserror = { workspace = true } | ||
|
|
||
| [dev-dependencies] | ||
| pretty_assertions = { workspace = true } | ||
| tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| use aws_config::BehaviorVersion; | ||
| use aws_config::SdkConfig; | ||
| use aws_credential_types::provider::SharedCredentialsProvider; | ||
| use aws_types::region::Region; | ||
|
|
||
| use crate::AwsAuthConfig; | ||
| use crate::AwsAuthError; | ||
|
|
||
| pub(crate) async fn load_sdk_config(config: &AwsAuthConfig) -> Result<SdkConfig, AwsAuthError> { | ||
| if config.service.trim().is_empty() { | ||
| return Err(AwsAuthError::EmptyService); | ||
| } | ||
|
|
||
| let mut loader = aws_config::defaults(BehaviorVersion::latest()); | ||
| if let Some(profile) = config.profile.as_ref() { | ||
| loader = loader.profile_name(profile); | ||
| } | ||
| if let Some(region) = config.region.as_ref() { | ||
| loader = loader.region(Region::new(region.clone())); | ||
| } | ||
|
|
||
| Ok(loader.load().await) | ||
| } | ||
|
|
||
| pub(crate) fn credentials_provider( | ||
| sdk_config: &SdkConfig, | ||
| ) -> Result<SharedCredentialsProvider, AwsAuthError> { | ||
| sdk_config | ||
| .credentials_provider() | ||
| .ok_or(AwsAuthError::MissingCredentialsProvider) | ||
| } | ||
|
|
||
| pub(crate) fn resolved_region(sdk_config: &SdkConfig) -> Result<String, AwsAuthError> { | ||
| sdk_config | ||
| .region() | ||
| .map(ToString::to_string) | ||
| .ok_or(AwsAuthError::MissingRegion) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,261 @@ | ||
| mod config; | ||
| mod signing; | ||
|
|
||
| use std::time::SystemTime; | ||
|
|
||
| use aws_credential_types::provider::ProvideCredentials; | ||
| use aws_credential_types::provider::SharedCredentialsProvider; | ||
| use bytes::Bytes; | ||
| use http::HeaderMap; | ||
| use http::Method; | ||
| use thiserror::Error; | ||
|
|
||
| /// AWS auth configuration used to resolve credentials and sign requests. | ||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||
| pub struct AwsAuthConfig { | ||
| pub profile: Option<String>, | ||
| pub region: Option<String>, | ||
| pub service: String, | ||
| } | ||
|
|
||
| /// Generic HTTP request shape consumed by SigV4 signing. | ||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||
| pub struct AwsRequestToSign { | ||
| pub method: Method, | ||
| pub url: String, | ||
| pub headers: HeaderMap, | ||
| pub body: Bytes, | ||
| } | ||
|
|
||
| /// Signed request parts returned to the caller. | ||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||
| pub struct AwsSignedRequest { | ||
| pub url: String, | ||
| pub headers: HeaderMap, | ||
| } | ||
|
|
||
| /// Errors returned by credential loading or SigV4 signing. | ||
| #[derive(Debug, Error)] | ||
| pub enum AwsAuthError { | ||
| #[error("AWS service name must not be empty")] | ||
| EmptyService, | ||
| #[error("AWS SDK config did not resolve a credentials provider")] | ||
| MissingCredentialsProvider, | ||
| #[error("AWS SDK config did not resolve a region")] | ||
| MissingRegion, | ||
| #[error("failed to load AWS credentials: {0}")] | ||
| Credentials(#[from] aws_credential_types::provider::error::CredentialsError), | ||
| #[error("request URL is not a valid URI: {0}")] | ||
| InvalidUri(#[source] http::uri::InvalidUri), | ||
| #[error("failed to construct HTTP request for signing: {0}")] | ||
| BuildHttpRequest(#[source] http::Error), | ||
| #[error("request contains a non-UTF8 header value: {0}")] | ||
| InvalidHeaderValue(#[source] http::header::ToStrError), | ||
| #[error("failed to build signable request: {0}")] | ||
| SigningRequest(#[source] aws_sigv4::http_request::SigningError), | ||
| #[error("failed to build SigV4 signing params: {0}")] | ||
| SigningParams(String), | ||
| #[error("SigV4 signing failed: {0}")] | ||
| SigningFailure(#[source] aws_sigv4::http_request::SigningError), | ||
| } | ||
|
|
||
| /// Loaded AWS auth context that can sign outbound HTTP requests. | ||
| #[derive(Clone)] | ||
| pub struct AwsAuthContext { | ||
| credentials_provider: SharedCredentialsProvider, | ||
| region: String, | ||
| service: String, | ||
| } | ||
|
|
||
| impl std::fmt::Debug for AwsAuthContext { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.debug_struct("AwsAuthContext") | ||
| .field("region", &self.region) | ||
| .field("service", &self.service) | ||
| .finish_non_exhaustive() | ||
| } | ||
| } | ||
|
|
||
| impl AwsAuthContext { | ||
| pub async fn load(config: AwsAuthConfig) -> Result<Self, AwsAuthError> { | ||
| let sdk_config = config::load_sdk_config(&config).await?; | ||
| let credentials_provider = config::credentials_provider(&sdk_config)?; | ||
| let region = config::resolved_region(&sdk_config)?; | ||
|
|
||
| Ok(Self { | ||
| credentials_provider, | ||
| region, | ||
| service: config.service.trim().to_string(), | ||
| }) | ||
| } | ||
|
|
||
| pub fn region(&self) -> &str { | ||
| &self.region | ||
| } | ||
|
|
||
| pub fn service(&self) -> &str { | ||
| &self.service | ||
| } | ||
|
|
||
| pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> { | ||
| self.sign_at(request, SystemTime::now()).await | ||
| } | ||
|
|
||
| async fn sign_at( | ||
| &self, | ||
| request: AwsRequestToSign, | ||
| time: SystemTime, | ||
| ) -> Result<AwsSignedRequest, AwsAuthError> { | ||
| let credentials = self.credentials_provider.provide_credentials().await?; | ||
| signing::sign_request(&credentials, &self.region, &self.service, request, time) | ||
| } | ||
| } | ||
|
|
||
| impl AwsAuthError { | ||
| /// Returns whether retrying the outbound request can reasonably recover from this auth error. | ||
| pub fn is_retryable(&self) -> bool { | ||
| match self { | ||
| AwsAuthError::Credentials(error) => matches!( | ||
| error, | ||
| aws_credential_types::provider::error::CredentialsError::ProviderTimedOut(_) | ||
| | aws_credential_types::provider::error::CredentialsError::ProviderError(_) | ||
| ), | ||
| AwsAuthError::EmptyService | ||
| | AwsAuthError::MissingCredentialsProvider | ||
| | AwsAuthError::MissingRegion | ||
| | AwsAuthError::InvalidUri(_) | ||
| | AwsAuthError::BuildHttpRequest(_) | ||
| | AwsAuthError::InvalidHeaderValue(_) | ||
| | AwsAuthError::SigningRequest(_) | ||
| | AwsAuthError::SigningParams(_) | ||
| | AwsAuthError::SigningFailure(_) => false, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::time::Duration; | ||
| use std::time::UNIX_EPOCH; | ||
|
|
||
| use aws_credential_types::Credentials; | ||
| use aws_credential_types::provider::error::CredentialsError; | ||
| use pretty_assertions::assert_eq; | ||
|
|
||
| use super::*; | ||
|
|
||
| fn test_context(session_token: Option<&str>) -> AwsAuthContext { | ||
| AwsAuthContext { | ||
| credentials_provider: SharedCredentialsProvider::new(Credentials::new( | ||
| "AKIDEXAMPLE", | ||
| "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", | ||
| session_token.map(str::to_string), | ||
| /*expires_after*/ None, | ||
| "unit-test", | ||
| )), | ||
| region: "us-east-1".to_string(), | ||
| service: "bedrock".to_string(), | ||
| } | ||
| } | ||
|
|
||
| fn test_request() -> AwsRequestToSign { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it possible to add integration tests for AWS model provider? Set up some basic fake auth that AWS SDK will pick up and assert that the request is being built correctly?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can do it as a follow-up - seems doable. for now, I have manual tests locally that do exactly this. |
||
| let mut headers = HeaderMap::new(); | ||
| headers.insert( | ||
| http::header::CONTENT_TYPE, | ||
| http::HeaderValue::from_static("application/json"), | ||
| ); | ||
| headers.insert("x-test-header", http::HeaderValue::from_static("present")); | ||
| AwsRequestToSign { | ||
| method: Method::POST, | ||
| url: "https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses".to_string(), | ||
| headers, | ||
| body: Bytes::from_static(br#"{"model":"openai.gpt-oss-120b-1:0"}"#), | ||
| } | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn sign_adds_sigv4_headers_and_preserves_existing_headers() { | ||
| let signed = test_context(/*session_token*/ None) | ||
| .sign_at( | ||
| test_request(), | ||
| UNIX_EPOCH + Duration::from_secs(1_700_000_000), | ||
| ) | ||
| .await | ||
| .expect("request should sign"); | ||
|
|
||
| assert_eq!( | ||
| signing::header_value(&signed.headers, http::header::CONTENT_TYPE.as_str()), | ||
| Some("application/json".to_string()) | ||
| ); | ||
| assert_eq!( | ||
| signing::header_value(&signed.headers, "x-test-header"), | ||
| Some("present".to_string()) | ||
| ); | ||
| assert_eq!( | ||
| signed.url, | ||
| "https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses" | ||
| ); | ||
| assert!( | ||
| signing::header_value(&signed.headers, http::header::AUTHORIZATION.as_str()) | ||
| .is_some_and(|value| value.starts_with("AWS4-HMAC-SHA256 ")) | ||
| ); | ||
| assert!(signing::header_value(&signed.headers, "x-amz-date").is_some()); | ||
| } | ||
|
|
||
| #[test] | ||
| fn credentials_provider_failures_are_retryable() { | ||
| assert!( | ||
| AwsAuthError::Credentials(CredentialsError::provider_error("temporarily unavailable")) | ||
| .is_retryable() | ||
| ); | ||
| assert!( | ||
| AwsAuthError::Credentials(CredentialsError::provider_timed_out(Duration::from_secs(1))) | ||
| .is_retryable() | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn deterministic_aws_auth_errors_are_not_retryable() { | ||
| assert!(!AwsAuthError::EmptyService.is_retryable()); | ||
| assert!( | ||
| !AwsAuthError::Credentials(CredentialsError::not_loaded_no_source()).is_retryable() | ||
| ); | ||
| assert!( | ||
| !AwsAuthError::Credentials(CredentialsError::invalid_configuration("bad profile")) | ||
| .is_retryable() | ||
| ); | ||
| assert!( | ||
| !AwsAuthError::Credentials(CredentialsError::unhandled("unexpected response")) | ||
| .is_retryable() | ||
| ); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn sign_includes_session_token_when_credentials_have_one() { | ||
| let signed = test_context(Some("session-token")) | ||
| .sign_at( | ||
| test_request(), | ||
| UNIX_EPOCH + Duration::from_secs(1_700_000_000), | ||
| ) | ||
| .await | ||
| .expect("request should sign"); | ||
|
|
||
| assert_eq!( | ||
| signing::header_value(&signed.headers, "x-amz-security-token"), | ||
| Some("session-token".to_string()) | ||
| ); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn load_rejects_empty_service_name() { | ||
| let err = AwsAuthContext::load(AwsAuthConfig { | ||
| profile: None, | ||
| region: None, | ||
| service: " ".to_string(), | ||
| }) | ||
| .await | ||
| .expect_err("empty service should be rejected"); | ||
|
|
||
| assert_eq!(err.to_string(), "AWS service name must not be empty"); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.