Skip to content
35 changes: 29 additions & 6 deletions bindings/python/src/datafusion_table_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::Arc;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
use iceberg::TableIdent;
use iceberg::io::FileIO;
use iceberg::io::{FileIOBuilder, OpenDalStorageFactory, StorageFactory};
use iceberg::table::StaticTable;
use iceberg_datafusion::table::IcebergStaticTableProvider;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
Expand All @@ -31,6 +31,30 @@ use pyo3::types::{PyAny, PyCapsule};

use crate::runtime::runtime;

/// Parse the scheme from a URL and return the appropriate StorageFactory.
fn storage_factory_from_path(path: &str) -> PyResult<Arc<dyn StorageFactory>> {
Comment thread
blackmwk marked this conversation as resolved.
let scheme = path
.split("://")
.next()
.ok_or_else(|| PyRuntimeError::new_err(format!("Invalid path, missing scheme: {path}")))?;

let factory: Arc<dyn StorageFactory> = match scheme {
"file" | "" => Arc::new(OpenDalStorageFactory::Fs),
"s3" | "s3a" => Arc::new(OpenDalStorageFactory::S3 {
configured_scheme: scheme.to_string(),
customized_credential_load: None,
}),
"memory" => Arc::new(OpenDalStorageFactory::Memory),
_ => {
return Err(PyRuntimeError::new_err(format!(
"Unsupported storage scheme: {scheme}"
)));
}
};

Ok(factory)
}

pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
let capsule_name = capsule.name()?;
if capsule_name.is_none() {
Expand Down Expand Up @@ -85,16 +109,15 @@ impl PyIcebergDataFusionTable {
let table_ident = TableIdent::from_strs(identifier)
.map_err(|e| PyRuntimeError::new_err(format!("Invalid table identifier: {e}")))?;

let mut builder = FileIO::from_path(&metadata_location)
.map_err(|e| PyRuntimeError::new_err(format!("Failed to init FileIO: {e}")))?;
let factory = storage_factory_from_path(&metadata_location)?;

let mut builder = FileIOBuilder::new(factory);

if let Some(props) = file_io_properties {
builder = builder.with_props(props);
}

let file_io = builder
.build()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to build FileIO: {e}")))?;
let file_io = builder.build();

let static_table =
StaticTable::from_metadata_file(&metadata_location, table_ident, file_io)
Expand Down
63 changes: 43 additions & 20 deletions crates/catalog/glue/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use aws_sdk_glue::operation::create_table::CreateTableError;
use aws_sdk_glue::operation::update_table::UpdateTableError;
use aws_sdk_glue::types::TableInput;
use iceberg::io::{
FileIO, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN,
FileIO, FileIOBuilder, OpenDalStorageFactory, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION,
S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN, StorageFactory,
};
use iceberg::spec::{TableMetadata, TableMetadataBuilder};
use iceberg::table::Table;
Expand All @@ -51,47 +53,58 @@ pub const GLUE_CATALOG_PROP_WAREHOUSE: &str = "warehouse";

/// Builder for [`GlueCatalog`].
///
/// Holds the catalog configuration plus an optional [`StorageFactory`]
/// override used when constructing the catalog's `FileIO`.
#[derive(Debug)]
pub struct GlueCatalogBuilder {
    config: GlueCatalogConfig,
    // None means "use the default S3 storage backend" at build time.
    storage_factory: Option<Arc<dyn StorageFactory>>,
}

impl Default for GlueCatalogBuilder {
    /// Start from an empty configuration; all required fields (name,
    /// warehouse) are validated later in `load`.
    fn default() -> Self {
        Self {
            config: GlueCatalogConfig {
                name: None,
                uri: None,
                catalog_id: None,
                warehouse: "".to_string(),
                props: HashMap::new(),
            },
            storage_factory: None,
        }
    }
}

impl CatalogBuilder for GlueCatalogBuilder {
type C = GlueCatalog;

fn with_storage_factory(mut self, storage_factory: Arc<dyn StorageFactory>) -> Self {
self.storage_factory = Some(storage_factory);
self
}

fn load(
mut self,
name: impl Into<String>,
props: HashMap<String, String>,
) -> impl Future<Output = Result<Self::C>> + Send {
self.0.name = Some(name.into());
self.config.name = Some(name.into());

if props.contains_key(GLUE_CATALOG_PROP_URI) {
self.0.uri = props.get(GLUE_CATALOG_PROP_URI).cloned()
self.config.uri = props.get(GLUE_CATALOG_PROP_URI).cloned()
}

if props.contains_key(GLUE_CATALOG_PROP_CATALOG_ID) {
self.0.catalog_id = props.get(GLUE_CATALOG_PROP_CATALOG_ID).cloned()
self.config.catalog_id = props.get(GLUE_CATALOG_PROP_CATALOG_ID).cloned()
}

if props.contains_key(GLUE_CATALOG_PROP_WAREHOUSE) {
self.0.warehouse = props
self.config.warehouse = props
.get(GLUE_CATALOG_PROP_WAREHOUSE)
.cloned()
.unwrap_or_default();
}

// Collect other remaining properties
self.0.props = props
self.config.props = props
.into_iter()
.filter(|(k, _)| {
k != GLUE_CATALOG_PROP_URI
Expand All @@ -101,20 +114,20 @@ impl CatalogBuilder for GlueCatalogBuilder {
.collect();

async move {
if self.0.name.is_none() {
if self.config.name.is_none() {
return Err(Error::new(
ErrorKind::DataInvalid,
"Catalog name is required",
));
}
if self.0.warehouse.is_empty() {
if self.config.warehouse.is_empty() {
return Err(Error::new(
ErrorKind::DataInvalid,
"Catalog warehouse is required",
));
}

GlueCatalog::new(self.0).await
GlueCatalog::new(self.config, self.storage_factory).await
}
}
}
Expand Down Expand Up @@ -148,7 +161,10 @@ impl Debug for GlueCatalog {

impl GlueCatalog {
/// Create a new glue catalog
async fn new(config: GlueCatalogConfig) -> Result<Self> {
async fn new(
config: GlueCatalogConfig,
storage_factory: Option<Arc<dyn StorageFactory>>,
) -> Result<Self> {
let sdk_config = create_sdk_config(&config.props, config.uri.as_ref()).await;
let mut file_io_props = config.props.clone();
if !file_io_props.contains_key(S3_ACCESS_KEY_ID)
Expand Down Expand Up @@ -182,9 +198,16 @@ impl GlueCatalog {

let client = aws_sdk_glue::Client::new(&sdk_config);

let file_io = FileIO::from_path(&config.warehouse)?
// Use provided factory or default to OpenDalStorageFactory::S3
let factory = storage_factory.unwrap_or_else(|| {
Arc::new(OpenDalStorageFactory::S3 {
configured_scheme: "s3a".to_string(),
customized_credential_load: None,
})
});
let file_io = FileIOBuilder::new(factory)
.with_props(file_io_props)
.build()?;
.build();

Ok(GlueCatalog {
config,
Expand Down
17 changes: 11 additions & 6 deletions crates/catalog/glue/tests/glue_catalog_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
//! Each test uses unique namespaces based on module path to avoid conflicts.

use std::collections::HashMap;
use std::sync::Arc;

use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY};
use iceberg::io::{
FileIOBuilder, OpenDalStorageFactory, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION,
S3_SECRET_ACCESS_KEY,
};
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
use iceberg::transaction::{ApplyTransactionAction, Transaction};
use iceberg::{
Expand Down Expand Up @@ -59,11 +63,12 @@ async fn get_catalog() -> GlueCatalog {
]);

// Wait for bucket to actually exist
let file_io = iceberg::io::FileIO::from_path("s3a://")
.unwrap()
.with_props(props.clone())
.build()
.unwrap();
let file_io = FileIOBuilder::new(Arc::new(OpenDalStorageFactory::S3 {
configured_scheme: "s3a".to_string(),
customized_credential_load: None,
}))
.with_props(props.clone())
.build();

let mut retries = 0;
while retries < 30 {
Expand Down
65 changes: 43 additions & 22 deletions crates/catalog/hms/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::net::ToSocketAddrs;
use std::sync::Arc;

use anyhow::anyhow;
use async_trait::async_trait;
use hive_metastore::{
ThriftHiveMetastoreClient, ThriftHiveMetastoreClientBuilder,
ThriftHiveMetastoreGetDatabaseException, ThriftHiveMetastoreGetTableException,
};
use iceberg::io::FileIO;
use iceberg::io::{FileIO, FileIOBuilder, StorageFactory};
use iceberg::spec::{TableMetadata, TableMetadataBuilder};
use iceberg::table::Table;
use iceberg::{
Expand All @@ -50,52 +51,63 @@ pub const THRIFT_TRANSPORT_BUFFERED: &str = "buffered";
/// HMS Catalog warehouse location
pub const HMS_CATALOG_PROP_WAREHOUSE: &str = "warehouse";

/// Builder for [`HmsCatalog`].
///
/// Holds the catalog configuration plus an optional [`StorageFactory`]
/// used when constructing the catalog's `FileIO`.
#[derive(Debug)]
pub struct HmsCatalogBuilder {
    config: HmsCatalogConfig,
    // Required at build time; `HmsCatalog::new` rejects `None`.
    storage_factory: Option<Arc<dyn StorageFactory>>,
}

impl Default for HmsCatalogBuilder {
    /// Start from an empty configuration; required fields (name, address,
    /// warehouse) are validated later in `load`.
    fn default() -> Self {
        Self {
            config: HmsCatalogConfig {
                name: None,
                address: "".to_string(),
                thrift_transport: HmsThriftTransport::default(),
                warehouse: "".to_string(),
                props: HashMap::new(),
            },
            storage_factory: None,
        }
    }
}

impl CatalogBuilder for HmsCatalogBuilder {
type C = HmsCatalog;

fn with_storage_factory(mut self, storage_factory: Arc<dyn StorageFactory>) -> Self {
self.storage_factory = Some(storage_factory);
self
}

fn load(
mut self,
name: impl Into<String>,
props: HashMap<String, String>,
) -> impl Future<Output = Result<Self::C>> + Send {
self.0.name = Some(name.into());
self.config.name = Some(name.into());

if props.contains_key(HMS_CATALOG_PROP_URI) {
self.0.address = props.get(HMS_CATALOG_PROP_URI).cloned().unwrap_or_default();
self.config.address = props.get(HMS_CATALOG_PROP_URI).cloned().unwrap_or_default();
}

if let Some(tt) = props.get(HMS_CATALOG_PROP_THRIFT_TRANSPORT) {
self.0.thrift_transport = match tt.to_lowercase().as_str() {
self.config.thrift_transport = match tt.to_lowercase().as_str() {
THRIFT_TRANSPORT_FRAMED => HmsThriftTransport::Framed,
THRIFT_TRANSPORT_BUFFERED => HmsThriftTransport::Buffered,
_ => HmsThriftTransport::default(),
};
}

if props.contains_key(HMS_CATALOG_PROP_WAREHOUSE) {
self.0.warehouse = props
self.config.warehouse = props
.get(HMS_CATALOG_PROP_WAREHOUSE)
.cloned()
.unwrap_or_default();
}

self.0.props = props
self.config.props = props
.into_iter()
.filter(|(k, _)| {
k != HMS_CATALOG_PROP_URI
Expand All @@ -105,23 +117,23 @@ impl CatalogBuilder for HmsCatalogBuilder {
.collect();

let result = {
if self.0.name.is_none() {
if self.config.name.is_none() {
Err(Error::new(
ErrorKind::DataInvalid,
"Catalog name is required",
))
} else if self.0.address.is_empty() {
} else if self.config.address.is_empty() {
Err(Error::new(
ErrorKind::DataInvalid,
"Catalog address is required",
))
} else if self.0.warehouse.is_empty() {
} else if self.config.warehouse.is_empty() {
Err(Error::new(
ErrorKind::DataInvalid,
"Catalog warehouse is required",
))
} else {
HmsCatalog::new(self.0)
HmsCatalog::new(self.config, self.storage_factory)
}
};

Expand Down Expand Up @@ -169,7 +181,10 @@ impl Debug for HmsCatalog {

impl HmsCatalog {
/// Create a new hms catalog.
fn new(config: HmsCatalogConfig) -> Result<Self> {
fn new(
config: HmsCatalogConfig,
storage_factory: Option<Arc<dyn StorageFactory>>,
) -> Result<Self> {
let address = config
.address
.as_str()
Expand All @@ -194,9 +209,15 @@ impl HmsCatalog {
.build(),
};

let file_io = FileIO::from_path(&config.warehouse)?
let factory = storage_factory.ok_or_else(|| {
Error::new(
ErrorKind::Unexpected,
"StorageFactory must be provided for HmsCatalog. Use `with_storage_factory` to configure it.",
)
})?;
let file_io = FileIOBuilder::new(factory)
.with_props(&config.props)
.build()?;
.build();

Ok(Self {
config,
Expand Down
Loading
Loading