diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 01d8ea6eac81c..d61235445a3e6 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -17,6 +17,7 @@ //! Factory for creating ListingTables with default options +use std::path::Path; use std::str::FromStr; use std::sync::Arc; @@ -66,8 +67,7 @@ impl TableProviderFactory for ListingTableFactory { DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) })?; - let file_extension = - file_type.get_ext_with_compression(file_compression_type.to_owned())?; + let file_extension = get_extension(cmd.location.as_str()); let file_format: Arc = match file_type { FileType::CSV => Arc::new( @@ -164,3 +164,58 @@ impl TableProviderFactory for ListingTableFactory { Ok(Arc::new(table)) } } + +// Get file extension from path +fn get_extension(path: &str) -> String { + let res = Path::new(path).extension().and_then(|ext| ext.to_str()); + match res { + Some(ext) => format!(".{}", ext), + None => "".to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::collections::HashMap; + + use crate::execution::context::SessionContext; + use datafusion_common::parsers::CompressionTypeVariant; + use datafusion_common::{DFSchema, OwnedTableReference}; + + #[tokio::test] + async fn test_create_using_non_std_file_ext() { + let csv_file = tempfile::Builder::new() + .prefix("foo") + .suffix(".tbl") + .tempfile() + .unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = OwnedTableReference::bare("foo".to_string()); + let cmd = CreateExternalTable { + name, + location: csv_file.path().to_str().unwrap().to_string(), + file_type: "csv".to_string(), + has_header: true, + delimiter: ',', + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + file_compression_type: CompressionTypeVariant::UNCOMPRESSED, + definition: None, + order_exprs: vec![], + options: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + let listing_options = listing_table.options(); + assert_eq!(".tbl", listing_options.file_extension); + } +}