diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 539b8d23d08..16a353a4202 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -66,6 +66,10 @@ struct BenchmarkOpt { /// Load the data into a MemTable before executing the query #[structopt(short = "m", long = "mem-table")] mem_table: bool, + + /// Number of partitions to create when using MemTable as input + #[structopt(short = "n", long = "partitions", default_value = "8")] + partitions: usize, } #[derive(Debug, StructOpt)] @@ -134,8 +138,12 @@ async fn benchmark(opt: BenchmarkOpt) -> Result Arc> { let ctx_holder: Arc>>>> = Arc::new(Mutex::new(vec![])); + + let partitions = 16; + rt.block_on(async { - let mem_table = MemTable::load(&csv, 16 * 1024).await.unwrap(); + let mem_table = MemTable::load(&csv, 16 * 1024, Some(partitions)) + .await + .unwrap(); // create local execution context let mut ctx = ExecutionContext::new(); diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index a3d7b0f1ac8..44d7c6ae573 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -19,6 +19,7 @@ //! queried by DataFusion. This allows data to be pre-loaded into memory and then //! repeatedly queried without incurring additional file I/O overhead. +use futures::StreamExt; use log::debug; use std::any::Any; use std::sync::Arc; @@ -26,13 +27,16 @@ use std::sync::Arc; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use crate::datasource::datasource::Statistics; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::ExecutionPlan; +use crate::{ + datasource::datasource::Statistics, + physical_plan::{repartition::RepartitionExec, Partitioning}, +}; use super::datasource::ColumnStatistics; @@ -102,7 +106,11 @@ impl MemTable { } /// Create a mem table by reading from another data source - pub async fn load(t: &dyn TableProvider, batch_size: usize) -> Result { + pub async fn load( + t: &dyn TableProvider, + batch_size: usize, + output_partitions: Option, + ) -> Result { let schema = t.schema(); let exec = t.scan(&None, batch_size, &[])?; let partition_count = exec.output_partitioning().partition_count(); @@ -126,6 +134,28 @@ impl MemTable { data.push(result); } + let exec = MemoryExec::try_new(&data, schema.clone(), None)?; + + if let Some(num_partitions) = output_partitions { + let exec = RepartitionExec::try_new( + Arc::new(exec), + Partitioning::RoundRobinBatch(num_partitions), + )?; + + // execute and collect results + let mut output_partitions = vec![]; + for i in 0..exec.output_partitioning().partition_count() { + // execute this *output* partition and collect all batches + let mut stream = exec.execute(i).await?; + let mut batches = vec![]; + while let Some(result) = stream.next().await { + batches.push(result?); + } + output_partitions.push(batches); + } + + return MemTable::try_new(schema.clone(), output_partitions); + } MemTable::try_new(schema.clone(), data) } }