diff --git a/rust/datafusion/src/datasource/memory.rs b/rust/datafusion/src/datasource/memory.rs index de1c31918549..b454315ad906 100644 --- a/rust/datafusion/src/datasource/memory.rs +++ b/rust/datafusion/src/datasource/memory.rs @@ -30,6 +30,8 @@ use crate::error::{ExecutionError, Result}; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::ExecutionPlan; +use tokio::task::{self, JoinHandle}; + /// In-memory table pub struct MemTable { schema: SchemaRef, @@ -59,13 +61,24 @@ impl MemTable { pub async fn load(t: &dyn TableProvider, batch_size: usize) -> Result { let schema = t.schema(); let exec = t.scan(&None, batch_size)?; + let partition_count = exec.output_partitioning().partition_count(); + + let mut tasks = Vec::with_capacity(partition_count); + for partition in 0..partition_count { + let exec = exec.clone(); + let task: JoinHandle>> = task::spawn(async move { + let it = exec.execute(partition).await?; + it.into_iter() + .collect::>>() + .map_err(ExecutionError::from) + }); + tasks.push(task) + } - let mut data: Vec> = - Vec::with_capacity(exec.output_partitioning().partition_count()); - for partition in 0..exec.output_partitioning().partition_count() { - let it = exec.execute(partition).await?; - let partition_batches = it.into_iter().collect::>>()?; - data.push(partition_batches); + let mut data: Vec> = Vec::with_capacity(partition_count); + for task in tasks { + let result = task.await.expect("MemTable::load could not join task")?; + data.push(result); } MemTable::new(schema.clone(), data)