diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 7b40ed11e91..f0f2e7cf57c 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -17,8 +17,11 @@ //! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. -use std::path::{Path, PathBuf}; use std::time::Instant; +use std::{ + path::{Path, PathBuf}, + sync::Arc, +}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::util::pretty; @@ -150,7 +153,7 @@ async fn benchmark(opt: BenchmarkOpt) -> Result Result> { +) -> Result> { match table_format { // dbgen creates .tbl ('|' delimited) files without header "tbl" => { @@ -1110,18 +1113,18 @@ fn get_table( .has_header(false) .file_extension(".tbl"); - Ok(Box::new(CsvFile::try_new(&path, options)?)) + Ok(Arc::new(CsvFile::try_new(&path, options)?)) } "csv" => { let path = format!("{}/{}", path, table); let schema = get_schema(table); let options = CsvReadOptions::new().schema(&schema).has_header(true); - Ok(Box::new(CsvFile::try_new(&path, options)?)) + Ok(Arc::new(CsvFile::try_new(&path, options)?)) } "parquet" => { let path = format!("{}/{}", path, table); - Ok(Box::new(ParquetTable::try_new(&path, max_concurrency)?)) + Ok(Arc::new(ParquetTable::try_new(&path, max_concurrency)?)) } other => { unimplemented!("Invalid file format '{}'", other); @@ -1607,7 +1610,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; - ctx.register_table(table, Box::new(provider)); + ctx.register_table(table, Arc::new(provider)); } let plan = create_logical_plan(&mut ctx, n)?; diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs index c3baa6416dc..75d9d3432ba 100644 --- a/rust/datafusion/benches/aggregate_query_sql.rs +++ b/rust/datafusion/benches/aggregate_query_sql.rs @@ -150,7 +150,7 @@ fn create_context( // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, partitions)?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); Ok(Arc::new(Mutex::new(ctx))) } diff --git a/rust/datafusion/benches/filter_query_sql.rs b/rust/datafusion/benches/filter_query_sql.rs index d6d61e4f51e..363ae416f67 100644 --- a/rust/datafusion/benches/filter_query_sql.rs +++ b/rust/datafusion/benches/filter_query_sql.rs @@ -62,7 +62,7 @@ fn create_context(array_len: usize, batch_size: usize) -> Result Arc> { let partitions = 16; rt.block_on(async { - let mem_table = MemTable::load(Box::new(csv), 16 * 1024, Some(partitions)) + let mem_table = MemTable::load(Arc::new(csv), 16 * 1024, Some(partitions)) .await .unwrap(); // create local execution context let mut ctx = ExecutionContext::new(); ctx.state.lock().unwrap().config.concurrency = 1; - ctx.register_table("aggregate_test_100", Box::new(mem_table)); + ctx.register_table("aggregate_test_100", Arc::new(mem_table)); ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) }); diff --git a/rust/datafusion/examples/dataframe_in_memory.rs b/rust/datafusion/examples/dataframe_in_memory.rs index ff352669338..28414bf8700 100644 --- a/rust/datafusion/examples/dataframe_in_memory.rs +++ b/rust/datafusion/examples/dataframe_in_memory.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::boxed::Box; use std::sync::Arc; use arrow::array::{Int32Array, StringArray}; @@ -50,7 +49,7 @@ async fn main() -> Result<()> { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); let df = ctx.table("t")?; // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs index 41ad59b7ee5..a36d200235a 100644 --- a/rust/datafusion/examples/simple_udaf.rs +++ b/rust/datafusion/examples/simple_udaf.rs @@ -48,7 +48,7 @@ fn create_context() -> Result { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); Ok(ctx) } diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs index c37cc9cc331..d49aac48527 100644 --- a/rust/datafusion/examples/simple_udf.rs +++ b/rust/datafusion/examples/simple_udf.rs @@ -50,7 +50,7 @@ fn create_context() -> Result { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); Ok(ctx) } diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index eab89305091..1fc0eaabc6c 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -107,7 +107,7 @@ impl MemTable { /// Create a mem table by reading from another data source pub async fn load( - t: Box, + t: Arc, batch_size: usize, output_partitions: Option, ) -> Result { diff --git a/rust/datafusion/src/datasource/parquet.rs b/rust/datafusion/src/datasource/parquet.rs index 1cd4765c22c..888103e6db7 100644 --- a/rust/datafusion/src/datasource/parquet.rs +++ b/rust/datafusion/src/datasource/parquet.rs @@ -326,15 +326,15 @@ mod tests { Ok(()) } - fn load_table(name: &str) -> Result> { + fn load_table(name: &str) -> Result> { let testdata = arrow::util::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, name); let table = ParquetTable::try_new(&filename, 2)?; - Ok(Box::new(table)) + Ok(Arc::new(table)) } async fn get_first_batch( - table: Box, + table: Arc, projection: &Option>, ) -> Result { let exec = table.scan(projection, 1024, &[])?; diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index ea79acdbc66..26aaecd42a8 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -258,7 +258,7 @@ impl ExecutionContext { filename: &str, options: CsvReadOptions, ) -> Result<()> { - self.register_table(name, Box::new(CsvFile::try_new(filename, options)?)); + self.register_table(name, Arc::new(CsvFile::try_new(filename, options)?)); Ok(()) } @@ -269,34 +269,36 @@ impl ExecutionContext { &filename, self.state.lock().unwrap().config.concurrency, )?; - self.register_table(name, Box::new(table)); + self.register_table(name, Arc::new(table)); Ok(()) } - /// Registers a table using a custom TableProvider so that it can be referenced from SQL - /// statements executed against this context. + /// Registers a named table using a custom `TableProvider` so that + /// it can be referenced from SQL statements executed against this + /// context. + /// + /// Returns the `TableProvider` previously registered for this + /// name, if any pub fn register_table( &mut self, name: &str, - provider: Box, - ) { + provider: Arc, + ) -> Option> { self.state .lock() .unwrap() .datasources - .insert(name.to_string(), provider.into()); + .insert(name.to_string(), provider) } /// Deregisters the named table. /// - /// Returns true if the table was successfully de-reregistered. - pub fn deregister_table(&mut self, name: &str) -> bool { - self.state - .lock() - .unwrap() - .datasources - .remove(&name.to_string()) - .is_some() + /// Returns the registered provider, if any + pub fn deregister_table( + &mut self, + name: &str, + ) -> Option> { + self.state.lock().unwrap().datasources.remove(name) } /// Retrieves a DataFrame representing a table previously registered by calling the @@ -744,8 +746,8 @@ mod tests { let provider = test::create_table_dual(); ctx.register_table("dual", provider); - assert_eq!(ctx.deregister_table("dual"), true); - assert_eq!(ctx.deregister_table("dual"), false); + assert!(ctx.deregister_table("dual").is_some()); + assert!(ctx.deregister_table("dual").is_none()); Ok(()) } @@ -1616,7 +1618,7 @@ mod tests { let mut ctx = ExecutionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); let myfunc = |args: &[ArrayRef]| { let l = &args[0] @@ -1718,7 +1720,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); let result = plan_and_collect(&mut ctx, "SELECT AVG(a) FROM t").await?; @@ -1755,7 +1757,7 @@ mod tests { let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Box::new(provider)); + ctx.register_table("t", Arc::new(provider)); // define a udaf, using a DataFusion's accumulator let my_avg = create_udaf( diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 7628e9f57e7..75a956f1cf4 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -29,7 +29,7 @@ use std::io::{BufReader, BufWriter}; use std::sync::Arc; use tempfile::TempDir; -pub fn create_table_dual() -> Box { +pub fn create_table_dual() -> Arc { let dual_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, false), @@ -43,7 +43,7 @@ pub fn create_table_dual() -> Box { ) .unwrap(); let provider = MemTable::try_new(dual_schema, vec![vec![batch]]).unwrap(); - Box::new(provider) + Arc::new(provider) } /// Generated partitioned copy of a CSV file diff --git a/rust/datafusion/tests/dataframe.rs b/rust/datafusion/tests/dataframe.rs index 0f3996803a4..e0c698ed5fb 100644 --- a/rust/datafusion/tests/dataframe.rs +++ b/rust/datafusion/tests/dataframe.rs @@ -61,11 +61,11 @@ async fn join() -> Result<()> { let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?; let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?; - ctx.register_table("aa", Box::new(table1)); + ctx.register_table("aa", Arc::new(table1)); let df1 = ctx.table("aa")?; - ctx.register_table("aaa", Box::new(table2)); + ctx.register_table("aaa", Arc::new(table2)); let df2 = ctx.table("aaa")?; diff --git a/rust/datafusion/tests/provider_filter_pushdown.rs b/rust/datafusion/tests/provider_filter_pushdown.rs index d4f8a6b678f..fe648bd3a10 100644 --- a/rust/datafusion/tests/provider_filter_pushdown.rs +++ b/rust/datafusion/tests/provider_filter_pushdown.rs @@ -155,7 +155,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() let result_col: &UInt64Array = as_primitive_array(results[0].column(0)); assert_eq!(result_col.value(0), expected_count); - ctx.register_table("data", Box::new(provider)); + ctx.register_table("data", Arc::new(provider)); let sql_results = ctx .sql(&format!("select count(*) from data where flag = {}", value))? .collect() diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index be6a1235089..d5a278d9301 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1147,7 +1147,7 @@ fn create_case_context() -> Result { ]))], )?; let table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("t1", Box::new(table)); + ctx.register_table("t1", Arc::new(table)); Ok(ctx) } @@ -1296,7 +1296,7 @@ fn create_join_context( ], )?; let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; - ctx.register_table("t1", Box::new(t1_table)); + ctx.register_table("t1", Arc::new(t1_table)); let t2_schema = Arc::new(Schema::new(vec![ Field::new(column_right, DataType::UInt32, true), @@ -1315,7 +1315,7 @@ fn create_join_context( ], )?; let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; - ctx.register_table("t2", Box::new(t2_table)); + ctx.register_table("t2", Arc::new(t2_table)); Ok(ctx) } @@ -1535,7 +1535,7 @@ async fn generic_query_length>>( let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT length(c1) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["1"], vec!["2"], vec!["3"]]; @@ -1569,7 +1569,7 @@ async fn query_not() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT NOT c1 FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["true"], vec!["NULL"], vec!["false"]]; @@ -1595,7 +1595,7 @@ async fn query_concat() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![ @@ -1626,7 +1626,7 @@ async fn query_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![ @@ -1693,7 +1693,7 @@ async fn like() -> Result<()> { Ok(()) } -fn make_timestamp_nano_table() -> Result> { +fn make_timestamp_nano_table() -> Result> { let schema = Arc::new(Schema::new(vec![ Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false), Field::new("value", DataType::Int32, true), @@ -1713,7 +1713,7 @@ fn make_timestamp_nano_table() -> Result> { ], )?; let table = MemTable::try_new(schema, vec![vec![data]])?; - Ok(Box::new(table)) + Ok(Arc::new(table)) } #[tokio::test] @@ -1745,7 +1745,7 @@ async fn query_is_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT c1 IS NULL FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["false"], vec!["true"], vec!["false"]]; @@ -1769,7 +1769,7 @@ async fn query_is_not_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT c1 IS NOT NULL FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["true"], vec!["false"], vec!["true"]]; @@ -1796,7 +1796,7 @@ async fn query_count_distinct() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT COUNT(DISTINCT c1) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["3".to_string()]]; @@ -1825,7 +1825,7 @@ async fn query_on_string_dictionary() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); // Basic SELECT let sql = "SELECT * FROM test"; @@ -1896,7 +1896,7 @@ async fn query_scalar_minus_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = "SELECT 4 - c1 FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["4"], vec!["3"], vec!["NULL"], vec!["1"]]; @@ -1975,7 +1975,7 @@ async fn csv_group_by_date() -> Result<()> { )?; let table = MemTable::try_new(schema, vec![vec![data]])?; - ctx.register_table("dates", Box::new(table)); + ctx.register_table("dates", Arc::new(table)); let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; let actual = execute(&mut ctx, sql).await; let mut actual: Vec = actual.iter().flatten().cloned().collect();