Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

415 changes: 404 additions & 11 deletions codex-rs/Cargo.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions codex-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = [
"aws-auth",
"analytics",
"backend-client",
"ansi-escape",
Expand Down Expand Up @@ -112,6 +113,7 @@ app_test_support = { path = "app-server/tests/common" }
codex-analytics = { path = "analytics" }
codex-ansi-escape = { path = "ansi-escape" }
codex-api = { path = "codex-api" }
codex-aws-auth = { path = "aws-auth" }
codex-app-server = { path = "app-server" }
codex-app-server-client = { path = "app-server-client" }
codex-app-server-protocol = { path = "app-server-protocol" }
Expand Down Expand Up @@ -217,6 +219,10 @@ async-channel = "2.3.1"
async-io = "2.6.0"
async-stream = "0.3.6"
async-trait = "0.1.89"
aws-config = "1"
aws-credential-types = "1"
aws-sigv4 = "1"
aws-types = "1"
axum = { version = "0.8", default-features = false }
base64 = "0.22.1"
bm25 = "2.3.2"
Expand Down
6 changes: 6 additions & 0 deletions codex-rs/aws-auth/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
load("//:defs.bzl", "codex_rust_crate")

codex_rust_crate(
name = "aws-auth",
crate_name = "codex_aws_auth",
)
26 changes: 26 additions & 0 deletions codex-rs/aws-auth/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-aws-auth"
version.workspace = true

[lib]
doctest = false
name = "codex_aws_auth"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
aws-config = { workspace = true }
aws-credential-types = { workspace = true }
aws-sigv4 = { workspace = true }
aws-types = { workspace = true }
bytes = { workspace = true }
http = { workspace = true }
thiserror = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
38 changes: 38 additions & 0 deletions codex-rs/aws-auth/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use aws_config::BehaviorVersion;
use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_types::region::Region;

use crate::AwsAuthConfig;
use crate::AwsAuthError;

pub(crate) async fn load_sdk_config(config: &AwsAuthConfig) -> Result<SdkConfig, AwsAuthError> {
if config.service.trim().is_empty() {
return Err(AwsAuthError::EmptyService);
}

let mut loader = aws_config::defaults(BehaviorVersion::latest());
if let Some(profile) = config.profile.as_ref() {
loader = loader.profile_name(profile);
}
if let Some(region) = config.region.as_ref() {
loader = loader.region(Region::new(region.clone()));
}

Ok(loader.load().await)
}

pub(crate) fn credentials_provider(
sdk_config: &SdkConfig,
) -> Result<SharedCredentialsProvider, AwsAuthError> {
sdk_config
.credentials_provider()
.ok_or(AwsAuthError::MissingCredentialsProvider)
}

pub(crate) fn resolved_region(sdk_config: &SdkConfig) -> Result<String, AwsAuthError> {
sdk_config
.region()
.map(ToString::to_string)
.ok_or(AwsAuthError::MissingRegion)
}
261 changes: 261 additions & 0 deletions codex-rs/aws-auth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
mod config;
mod signing;

use std::time::SystemTime;

use aws_credential_types::provider::ProvideCredentials;
use aws_credential_types::provider::SharedCredentialsProvider;
use bytes::Bytes;
use http::HeaderMap;
use http::Method;
use thiserror::Error;

/// AWS auth configuration used to resolve credentials and sign requests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsAuthConfig {
pub profile: Option<String>,
pub region: Option<String>,
pub service: String,
}

/// Generic HTTP request shape consumed by SigV4 signing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsRequestToSign {
pub method: Method,
pub url: String,
pub headers: HeaderMap,
pub body: Bytes,
}

/// Signed request parts returned to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsSignedRequest {
pub url: String,
pub headers: HeaderMap,
}

/// Errors returned by credential loading or SigV4 signing.
#[derive(Debug, Error)]
pub enum AwsAuthError {
#[error("AWS service name must not be empty")]
EmptyService,
#[error("AWS SDK config did not resolve a credentials provider")]
MissingCredentialsProvider,
#[error("AWS SDK config did not resolve a region")]
MissingRegion,
#[error("failed to load AWS credentials: {0}")]
Credentials(#[from] aws_credential_types::provider::error::CredentialsError),
#[error("request URL is not a valid URI: {0}")]
InvalidUri(#[source] http::uri::InvalidUri),
#[error("failed to construct HTTP request for signing: {0}")]
BuildHttpRequest(#[source] http::Error),
#[error("request contains a non-UTF8 header value: {0}")]
InvalidHeaderValue(#[source] http::header::ToStrError),
#[error("failed to build signable request: {0}")]
SigningRequest(#[source] aws_sigv4::http_request::SigningError),
#[error("failed to build SigV4 signing params: {0}")]
SigningParams(String),
#[error("SigV4 signing failed: {0}")]
SigningFailure(#[source] aws_sigv4::http_request::SigningError),
}

/// Loaded AWS auth context that can sign outbound HTTP requests.
#[derive(Clone)]
pub struct AwsAuthContext {
credentials_provider: SharedCredentialsProvider,
region: String,
service: String,
}

impl std::fmt::Debug for AwsAuthContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AwsAuthContext")
.field("region", &self.region)
.field("service", &self.service)
.finish_non_exhaustive()
}
}

impl AwsAuthContext {
pub async fn load(config: AwsAuthConfig) -> Result<Self, AwsAuthError> {
let sdk_config = config::load_sdk_config(&config).await?;
let credentials_provider = config::credentials_provider(&sdk_config)?;
let region = config::resolved_region(&sdk_config)?;

Ok(Self {
credentials_provider,
region,
service: config.service.trim().to_string(),
})
}

pub fn region(&self) -> &str {
&self.region
}

pub fn service(&self) -> &str {
&self.service
}

pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
self.sign_at(request, SystemTime::now()).await
}

async fn sign_at(
&self,
request: AwsRequestToSign,
time: SystemTime,
) -> Result<AwsSignedRequest, AwsAuthError> {
let credentials = self.credentials_provider.provide_credentials().await?;
signing::sign_request(&credentials, &self.region, &self.service, request, time)
}
}

impl AwsAuthError {
/// Returns whether retrying the outbound request can reasonably recover from this auth error.
pub fn is_retryable(&self) -> bool {
match self {
AwsAuthError::Credentials(error) => matches!(
Comment thread
celia-oai marked this conversation as resolved.
error,
aws_credential_types::provider::error::CredentialsError::ProviderTimedOut(_)
| aws_credential_types::provider::error::CredentialsError::ProviderError(_)
),
AwsAuthError::EmptyService
| AwsAuthError::MissingCredentialsProvider
| AwsAuthError::MissingRegion
| AwsAuthError::InvalidUri(_)
| AwsAuthError::BuildHttpRequest(_)
| AwsAuthError::InvalidHeaderValue(_)
| AwsAuthError::SigningRequest(_)
| AwsAuthError::SigningParams(_)
| AwsAuthError::SigningFailure(_) => false,
}
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;
use std::time::UNIX_EPOCH;

use aws_credential_types::Credentials;
use aws_credential_types::provider::error::CredentialsError;
use pretty_assertions::assert_eq;

use super::*;

fn test_context(session_token: Option<&str>) -> AwsAuthContext {
AwsAuthContext {
credentials_provider: SharedCredentialsProvider::new(Credentials::new(
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
session_token.map(str::to_string),
/*expires_after*/ None,
"unit-test",
)),
region: "us-east-1".to_string(),
service: "bedrock".to_string(),
}
}

fn test_request() -> AwsRequestToSign {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to add integration tests for AWS model provider? Set up some basic fake auth that AWS SDK will pick up and assert that the request is being built correctly?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do it as a follow-up - seems doable. for now, I have manual tests locally that do exactly this.

let mut headers = HeaderMap::new();
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
headers.insert("x-test-header", http::HeaderValue::from_static("present"));
AwsRequestToSign {
method: Method::POST,
url: "https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses".to_string(),
headers,
body: Bytes::from_static(br#"{"model":"openai.gpt-oss-120b-1:0"}"#),
}
}

#[tokio::test]
async fn sign_adds_sigv4_headers_and_preserves_existing_headers() {
let signed = test_context(/*session_token*/ None)
.sign_at(
test_request(),
UNIX_EPOCH + Duration::from_secs(1_700_000_000),
)
.await
.expect("request should sign");

assert_eq!(
signing::header_value(&signed.headers, http::header::CONTENT_TYPE.as_str()),
Some("application/json".to_string())
);
assert_eq!(
signing::header_value(&signed.headers, "x-test-header"),
Some("present".to_string())
);
assert_eq!(
signed.url,
"https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses"
);
assert!(
signing::header_value(&signed.headers, http::header::AUTHORIZATION.as_str())
.is_some_and(|value| value.starts_with("AWS4-HMAC-SHA256 "))
);
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
}

#[test]
fn credentials_provider_failures_are_retryable() {
assert!(
AwsAuthError::Credentials(CredentialsError::provider_error("temporarily unavailable"))
.is_retryable()
);
assert!(
AwsAuthError::Credentials(CredentialsError::provider_timed_out(Duration::from_secs(1)))
.is_retryable()
);
}

#[test]
fn deterministic_aws_auth_errors_are_not_retryable() {
assert!(!AwsAuthError::EmptyService.is_retryable());
assert!(
!AwsAuthError::Credentials(CredentialsError::not_loaded_no_source()).is_retryable()
);
assert!(
!AwsAuthError::Credentials(CredentialsError::invalid_configuration("bad profile"))
.is_retryable()
);
assert!(
!AwsAuthError::Credentials(CredentialsError::unhandled("unexpected response"))
.is_retryable()
);
}

#[tokio::test]
async fn sign_includes_session_token_when_credentials_have_one() {
let signed = test_context(Some("session-token"))
.sign_at(
test_request(),
UNIX_EPOCH + Duration::from_secs(1_700_000_000),
)
.await
.expect("request should sign");

assert_eq!(
signing::header_value(&signed.headers, "x-amz-security-token"),
Some("session-token".to_string())
);
}

#[tokio::test]
async fn load_rejects_empty_service_name() {
let err = AwsAuthContext::load(AwsAuthConfig {
profile: None,
region: None,
service: " ".to_string(),
})
.await
.expect_err("empty service should be rejected");

assert_eq!(err.to_string(), "AWS service name must not be empty");
}
}
Loading
Loading