diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index a98153fa37..85996f8047 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -18,6 +18,7 @@ //! This module contains rest catalog implementation. use std::collections::HashMap; +use std::str::FromStr; use async_trait::async_trait; use itertools::Itertools; @@ -103,8 +104,7 @@ impl RestCatalogConfig { ]) } - fn try_create_rest_client(&self) -> Result { - // TODO: We will add ssl config, sigv4 later + fn http_headers(&self) -> Result { let mut headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, @@ -133,6 +133,36 @@ impl RestCatalogConfig { ); } + for (key, value) in self.props.iter() { + if let Some(stripped_key) = key.strip_prefix("header.") { + // Avoid overwriting default headers + if !headers.contains_key(stripped_key) { + headers.insert( + HeaderName::from_str(stripped_key).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid header name: {stripped_key}!"), + ) + .with_source(e) + })?, + HeaderValue::from_str(value).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid header value: {value}!"), + ) + .with_source(e) + })?, + ); + } + } + } + Ok(headers) + } + + fn try_create_rest_client(&self) -> Result { + // TODO: We will add ssl config, sigv4 later + let headers = self.http_headers()?; + Ok(HttpClient( Client::builder().default_headers(headers).build()?, )) @@ -963,6 +993,76 @@ mod tests { ); } + #[tokio::test] + async fn test_http_headers() { + let server = Server::new_async().await; + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let config = RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(); + let headers: HeaderMap = config.http_headers().unwrap(); + + let expected_headers = HeaderMap::from_iter([ + ( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ), + ( + HeaderName::from_static("x-client-version"), + HeaderValue::from_static(ICEBERG_REST_SPEC_VERSION), + ), + ( + header::USER_AGENT, + HeaderValue::from_str(&format!("iceberg-rs/{}", CARGO_PKG_VERSION)).unwrap(), + ), + ]); + assert_eq!(headers, expected_headers); + } + + #[tokio::test] + async fn test_http_headers_with_custom_headers() { + let server = Server::new_async().await; + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + props.insert( + "header.content-type".to_string(), + "application/yaml".to_string(), + ); + props.insert( + "header.customized-header".to_string(), + "some/value".to_string(), + ); + + let config = RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(); + let headers: HeaderMap = config.http_headers().unwrap(); + + let expected_headers = HeaderMap::from_iter([ + ( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ), + ( + HeaderName::from_static("x-client-version"), + HeaderValue::from_static(ICEBERG_REST_SPEC_VERSION), + ), + ( + header::USER_AGENT, + HeaderValue::from_str(&format!("iceberg-rs/{}", CARGO_PKG_VERSION)).unwrap(), + ), + ( + HeaderName::from_static("customized-header"), + HeaderValue::from_static("some/value"), + ), + ]); + assert_eq!(headers, expected_headers); + } + #[tokio::test] async fn test_oauth_with_auth_url() { let mut server = Server::new_async().await;