From 7471630cb093d8c112d536dbe7314479ebab07ed Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Wed, 8 Mar 2023 17:06:08 +0300 Subject: [PATCH 1/4] Insert into memory table --- datafusion/core/src/datasource/datasource.rs | 11 +- .../core/src/datasource/listing/table.rs | 3 +- datafusion/core/src/datasource/memory.rs | 205 +++++++++++++++++- datafusion/core/src/execution/context.rs | 56 ++++- datafusion/expr/src/logical_plan/builder.rs | 24 +- 5 files changed, 288 insertions(+), 11 deletions(-) diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs index 6277ce146adfb..60f3113dfc54a 100644 --- a/datafusion/core/src/datasource/datasource.rs +++ b/datafusion/core/src/datasource/datasource.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::Statistics; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; @@ -97,6 +97,15 @@ pub trait TableProvider: Sync + Send { fn statistics(&self) -> Option { None } + + /// Insert API + async fn insert_into_table( + &self, + _state: &SessionState, + _input: &LogicalPlan, + ) -> Result<()> { + Err(DataFusionError::Internal("Not implemented".to_owned())) + } } /// A factory which creates [`TableProvider`]s at runtime given a URL. diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index f6d9c959eb1ac..0e9c0346d76ca 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -32,6 +32,7 @@ use futures::{future, stream, StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::ObjectMeta; +use super::PartitionedFile; use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; use crate::datasource::{ file_format::{ @@ -55,8 +56,6 @@ use crate::{ }, }; -use super::PartitionedFile; - use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; /// Configuration for creating a [`ListingTable`] diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index ac1f4947f87d5..71351d99b2a6e 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -19,18 +19,23 @@ //! queried by DataFusion. This allows data to be pre-loaded into memory and then //! repeatedly queried without incurring additional file I/O overhead. 
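The default implementation of `insert_into_table` added above returns an error, so existing `TableProvider` implementations keep compiling unchanged and only providers that opt in gain insert support. As a rough sketch (not part of this patch; `MyTable` and its `append` helper are hypothetical, and the method is renamed to `insert_into` in the follow-up commit), a custom provider might override it like so:

use std::sync::Arc;
use async_trait::async_trait;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::physical_plan::collect;
use datafusion_expr::LogicalPlan;

#[async_trait]
impl TableProvider for MyTable {
    // ... as_any(), schema(), table_type() and scan() elided ...

    async fn insert_into_table(
        &self,
        state: &SessionState,
        input: &LogicalPlan,
    ) -> Result<()> {
        // Plan and run the input query, gather all output partitions into
        // one Vec<RecordBatch>, then hand the batches to the table's storage.
        let plan = state.create_physical_plan(input).await?;
        let batches = collect(plan, state.task_ctx()).await?;
        self.append(batches) // hypothetical storage hook
    }
}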
-use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use std::any::Any; use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use datafusion_expr::LogicalPlan; +use tokio::sync::RwLock; +use tokio::task; +use tokio::task::JoinHandle; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; use crate::execution::context::SessionState; use crate::logical_expr::Expr; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::memory::MemoryExec; @@ -41,7 +46,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; #[derive(Debug)] pub struct MemTable { schema: SchemaRef, - batches: Vec>, + batches: Arc>>>, } impl MemTable { @@ -54,7 +59,7 @@ impl MemTable { { Ok(Self { schema, - batches: partitions, + batches: Arc::new(RwLock::new(partitions)), }) } else { Err(DataFusionError::Plan( @@ -143,22 +148,103 @@ impl TableProvider for MemTable { _filters: &[Expr], _limit: Option, ) -> Result> { + let batches = &self.batches.read().await; Ok(Arc::new(MemoryExec::try_new( - &self.batches.clone(), + batches, self.schema(), projection.cloned(), )?)) } + + /// Inserts the results of executing a given `LogicalPlan` into this `Table`. + /// The `LogicalPlan` must have the same schema as this `Table`. + /// + /// # Arguments + /// + /// * `state` - The `SessionState` containing the context for executing the plan. + /// * `input` - The `LogicalPlan` to execute and insert. + /// + /// # Returns + /// + /// * A `Result` indicating success or failure. + async fn insert_into_table( + &self, + state: &SessionState, + input: &LogicalPlan, + ) -> Result<()> { + // Create a physical plan from the logical plan. + let plan = state.create_physical_plan(input).await?; + + // Check that the schema of the plan matches the schema of this table. + if !plan.schema().eq(&self.schema) { + return Err(DataFusionError::Plan( + "Inserting query must have same schema with the table.".to_string(), + )); + } + + // Get the number of partitions in the plan and the table. + let plan_partition_count = plan.output_partitioning().partition_count(); + let table_partition_count = self.batches.read().await.len(); + + // Adjust the plan as necessary to match the number of partitions in the table. + let plan: Arc = + if plan_partition_count == table_partition_count { + plan + } else if table_partition_count == 1 { + // If the table has only one partition, coalesce the partitions in the plan. + Arc::new(CoalescePartitionsExec::new(plan)) + } else { + // Otherwise, repartition the plan using a round-robin partitioning scheme. + Arc::new(RepartitionExec::try_new( + plan, + Partitioning::RoundRobinBatch(table_partition_count), + )?) + }; + + // Get the task context from the session state. + let task_ctx = state.task_ctx(); + + // Execute the plan and collect the results into batches. + let mut tasks = vec![]; + for i in 0..plan.output_partitioning().partition_count() { + let plan = plan.clone(); + let stream = plan.execute(i, task_ctx.clone())?; + let handle: JoinHandle>> = task::spawn(async move { + stream.try_collect().await.map_err(DataFusionError::from) + }); + tasks.push(AbortOnDropSingle::new(handle)); + } + let results: Result>> = futures::future::join_all(tasks) + .await + .into_iter() + .map(|result| { + result.map_err(|e| DataFusionError::Execution(format!("{e}")))? 
+ }) + .collect::>() + .into_iter() + .collect::>>(); + + // Write the results into the table. + let mut all_batches = self.batches.write().await; + for (i, result) in results?.into_iter().enumerate() { + let batches = all_batches.get_mut(i).unwrap(); + batches.extend(result); + } + + Ok(()) + } } #[cfg(test)] mod tests { use super::*; + use crate::datasource::provider_as_source; use crate::from_slice::FromSlice; use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; + use datafusion_expr::LogicalPlanBuilder; use futures::StreamExt; use std::collections::HashMap; @@ -388,4 +474,115 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_insert_into_single_partition() -> Result<()> { + // Create a new session context + let session_ctx = SessionContext::new(); + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], + )?; + // Create a new table with one partition that contains the batch of data + let initial_table = Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + )?); + // Convert the table into a provider so that it can be used in a query + let provider = provider_as_source(Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + )?)); + // Create a table scan logical plan to read from the table + let table_scan = + Arc::new(LogicalPlanBuilder::scan("source", provider, None)?.build()?); + // Insert the data from the provider into the table + initial_table + .insert_into_table(&session_ctx.state(), &table_scan) + .await?; + // Ensure that the table now contains two batches of data in the same partition + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + + // Create a new provider with 2 partitions + let multi_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?)); + // Create a new table scan logical plan to read from the provider + let table_scan = Arc::new( + LogicalPlanBuilder::scan("source", multi_partition_provider, None)? + .build()?, + ); + // Insert the data from the provider into the table. We expect coalescing partitions. 
+ initial_table + .insert_into_table(&session_ctx.state(), &table_scan) + .await?; + // Ensure that the table now contains 4 batches of data with only 1 partition + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); + assert_eq!(initial_table.batches.read().await.len(), 1); + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_multiple_partition() -> Result<()> { + let session_ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // create a record batch with values 1, 2, 3 in a column named "a" + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], + )?; + + // create a memory table with two partitions, each having one batch with the same data + let initial_table = Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()], vec![batch.clone()]], + )?); + + // create a data source provider from a memory table with a single partition + let single_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone(), batch.clone()]], + )?)); + + // create a logical plan for scanning the data source provider + let table_scan = Arc::new( + LogicalPlanBuilder::scan("source", single_partition_provider, None)? + .build()?, + ); + + // insert the data from the 2 partitions data source provider into the initial table + initial_table + .insert_into_table(&session_ctx.state(), &table_scan) + .await?; + + // We expect one-to-one partition mapping, each partition gets 1 batch. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2); + + // create a data source provider from a memory table with with only 1 partition + let multi_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?)); + // create a logical plan for scanning the data source provider + let table_scan = Arc::new( + LogicalPlanBuilder::scan("source", multi_partition_provider, None)? + .build()?, + ); + // insert the data from the 1 partition data source provider into the initial table. + // We expect round robin repartition here. + initial_table + .insert_into_table(&session_ctx.state(), &table_scan) + .await?; + // Ensure that the table now contains 3 batches of data with 2 partitions. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3); + assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3); + Ok(()) + } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 0340b4761bc76..876b789c331f3 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -31,7 +31,7 @@ use crate::{ optimizer::PhysicalOptimizerRule, }, }; -use datafusion_expr::{DescribeTable, StringifiedPlan}; +use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp}; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; use parking_lot::RwLock; @@ -318,6 +318,20 @@ impl SessionContext { let plan = self.state().create_logical_plan(sql).await?; match plan { + LogicalPlan::Dml(DmlStatement { + table_name, + op: WriteOp::Insert, + input, + .. 
+ }) => { + let exist = self.table_exist(&table_name)?; + if exist { + let name = table_name.table(); + let provider = self.table_provider(name).await?; + provider.insert_into_table(&self.state(), &input).await?; + } + self.return_empty_dataframe() + } LogicalPlan::CreateExternalTable(cmd) => { self.create_external_table(&cmd).await } @@ -2714,6 +2728,46 @@ mod tests { Ok(()) } + #[tokio::test] + async fn sql_table_insert() -> Result<()> { + let session_ctx = SessionContext::with_config(SessionConfig::new()); + + session_ctx + .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") + .await? + .collect() + .await?; + session_ctx + .sql("CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6)") + .await? + .collect() + .await?; + + let sql = "INSERT INTO abc SELECT * FROM xyz"; + session_ctx.sql(sql).await?.collect().await?; + + let results = session_ctx + .sql("SELECT * FROM abc") + .await? + .collect() + .await?; + + let expected = vec![ + "+---------+---------+---------+", + "| column1 | column2 | column3 |", + "+---------+---------+---------+", + "| 1 | 2 | 3 |", + "| 4 | 5 | 6 |", + "| 1 | 3 | 3 |", + "| 5 | 5 | 6 |", + "+---------+---------+---------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) + } + #[tokio::test] async fn sql_create_catalog() -> Result<()> { // the information schema used to introduce cyclic Arcs diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 0461a1d858888..97344143f6f86 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -26,7 +26,7 @@ use crate::utils::{ columnize_expr, compare_sort_expr, ensure_any_column_reference_is_unambiguous, exprlist_to_fields, from_plan, }; -use crate::{and, binary_expr, Operator}; +use crate::{and, binary_expr, DmlStatement, Operator, WriteOp}; use crate::{ logical_plan::{ Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, @@ -42,8 +42,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - ToDFSchema, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, + ScalarValue, ToDFSchema, }; use std::any::Any; use std::cmp::Ordering; @@ -203,6 +203,24 @@ impl LogicalPlanBuilder { Self::scan_with_filters(table_name, table_source, projection, vec![]) } + /// Convert a logical plan into a builder with a DmlStatement + pub fn insert_into( + input: LogicalPlan, + table_name: impl Into, + table_schema: &Schema, + ) -> Result { + let table_name = OwnedTableReference::Bare { + table: table_name.into(), + }; + let table_schema = table_schema.clone().to_dfschema_ref()?; + Ok(Self::from(LogicalPlan::Dml(DmlStatement { + table_name, + table_schema, + op: WriteOp::Insert, + input: Arc::new(input), + }))) + } + /// Convert a table provider into a builder with a TableScan pub fn scan_with_filters( table_name: impl Into, From d8d767ba5c2fc26f2c82489080dec9052c4b9825 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 8 Mar 2023 22:44:34 -0600 Subject: [PATCH 2/4] Code simplifications --- datafusion/core/src/datasource/datasource.rs | 7 ++-- .../core/src/datasource/listing/table.rs | 3 +- datafusion/core/src/datasource/memory.rs | 39 +++++++------------ datafusion/core/src/execution/context.rs | 8 ++-- datafusion/expr/src/logical_plan/builder.rs | 2 +- 5 files changed, 26 insertions(+), 33 deletions(-) diff --git 
a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs index 60f3113dfc54a..8db075a30a79c 100644 --- a/datafusion/core/src/datasource/datasource.rs +++ b/datafusion/core/src/datasource/datasource.rs @@ -98,13 +98,14 @@ pub trait TableProvider: Sync + Send { None } - /// Insert API - async fn insert_into_table( + /// Insert into this table + async fn insert_into( &self, _state: &SessionState, _input: &LogicalPlan, ) -> Result<()> { - Err(DataFusionError::Internal("Not implemented".to_owned())) + let msg = "Insertion not implemented for this table".to_owned(); + Err(DataFusionError::NotImplemented(msg)) } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0e9c0346d76ca..f6d9c959eb1ac 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -32,7 +32,6 @@ use futures::{future, stream, StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::ObjectMeta; -use super::PartitionedFile; use crate::datasource::file_format::file_type::{FileCompressionType, FileType}; use crate::datasource::{ file_format::{ @@ -56,6 +55,8 @@ use crate::{ }, }; +use super::PartitionedFile; + use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; /// Configuration for creating a [`ListingTable`] diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 71351d99b2a6e..69019b0d9b565 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -29,7 +29,6 @@ use async_trait::async_trait; use datafusion_expr::LogicalPlan; use tokio::sync::RwLock; use tokio::task; -use tokio::task::JoinHandle; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; @@ -156,29 +155,25 @@ impl TableProvider for MemTable { )?)) } - /// Inserts the results of executing a given `LogicalPlan` into this `Table`. - /// The `LogicalPlan` must have the same schema as this `Table`. + /// Inserts the execution results of a given [LogicalPlan] into this [MemTable]. + /// The `LogicalPlan` must have the same schema as this `MemTable`. /// /// # Arguments /// - /// * `state` - The `SessionState` containing the context for executing the plan. + /// * `state` - The [SessionState] containing the context for executing the plan. /// * `input` - The `LogicalPlan` to execute and insert. /// /// # Returns /// /// * A `Result` indicating success or failure. - async fn insert_into_table( - &self, - state: &SessionState, - input: &LogicalPlan, - ) -> Result<()> { + async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> { // Create a physical plan from the logical plan. let plan = state.create_physical_plan(input).await?; // Check that the schema of the plan matches the schema of this table. if !plan.schema().eq(&self.schema) { return Err(DataFusionError::Plan( - "Inserting query must have same schema with the table.".to_string(), + "Inserting query must have the same schema with the table.".to_string(), )); } @@ -206,28 +201,24 @@ impl TableProvider for MemTable { // Execute the plan and collect the results into batches. 
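Pulled out of the diff noise, the execution pattern this patch settles on — one Tokio task per output partition, each draining its stream, with the results joined at the end — reads as follows in consolidated form (a sketch assuming `plan: Arc<dyn ExecutionPlan>` and `task_ctx: Arc<TaskContext>` are in scope):

let mut tasks = vec![];
for idx in 0..plan.output_partitioning().partition_count() {
    // Each partition produces an independent RecordBatch stream.
    let stream = plan.execute(idx, task_ctx.clone())?;
    // Drain the stream on its own task; AbortOnDropSingle aborts the
    // spawned task if the surrounding future is dropped early.
    tasks.push(AbortOnDropSingle::new(task::spawn(async move {
        stream.try_collect::<Vec<_>>().await.map_err(DataFusionError::from)
    })));
}
// join_all yields a Result<Vec<RecordBatch>, JoinError> per task; flatten
// both error layers into DataFusionError.
let results = futures::future::join_all(tasks)
    .await
    .into_iter()
    .map(|res| res.map_err(|e| DataFusionError::Execution(e.to_string()))?)
    .collect::<Result<Vec<Vec<RecordBatch>>>>()?;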
let mut tasks = vec![]; - for i in 0..plan.output_partitioning().partition_count() { - let plan = plan.clone(); - let stream = plan.execute(i, task_ctx.clone())?; - let handle: JoinHandle>> = task::spawn(async move { + for idx in 0..table_partition_count { + let stream = plan.execute(idx, task_ctx.clone())?; + let handle = task::spawn(async move { stream.try_collect().await.map_err(DataFusionError::from) }); tasks.push(AbortOnDropSingle::new(handle)); } - let results: Result>> = futures::future::join_all(tasks) + let results = futures::future::join_all(tasks) .await .into_iter() .map(|result| { result.map_err(|e| DataFusionError::Execution(format!("{e}")))? }) - .collect::>() - .into_iter() - .collect::>>(); + .collect::>>>()?; // Write the results into the table. let mut all_batches = self.batches.write().await; - for (i, result) in results?.into_iter().enumerate() { - let batches = all_batches.get_mut(i).unwrap(); + for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) { batches.extend(result); } @@ -502,7 +493,7 @@ mod tests { Arc::new(LogicalPlanBuilder::scan("source", provider, None)?.build()?); // Insert the data from the provider into the table initial_table - .insert_into_table(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &table_scan) .await?; // Ensure that the table now contains two batches of data in the same partition assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); @@ -519,7 +510,7 @@ mod tests { ); // Insert the data from the provider into the table. We expect coalescing partitions. initial_table - .insert_into_table(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &table_scan) .await?; // Ensure that the table now contains 4 batches of data with only 1 partition assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); @@ -558,7 +549,7 @@ mod tests { // insert the data from the 2 partitions data source provider into the initial table initial_table - .insert_into_table(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &table_scan) .await?; // We expect one-to-one partition mapping, each partition gets 1 batch. @@ -578,7 +569,7 @@ mod tests { // insert the data from the 1 partition data source provider into the initial table. // We expect round robin repartition here. initial_table - .insert_into_table(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &table_scan) .await?; // Ensure that the table now contains 3 batches of data with 2 partitions. assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3); diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 876b789c331f3..3c43aa7b98c5b 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -308,7 +308,8 @@ impl SessionContext { /// Creates a [`DataFrame`] that will execute a SQL query. /// - /// Note: This API implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in-memory + /// Note: This API implements DDL statements such as `CREATE TABLE` and + /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory /// default implementations. /// /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which @@ -324,11 +325,10 @@ impl SessionContext { input, .. }) => { - let exist = self.table_exist(&table_name)?; - if exist { + if self.table_exist(&table_name)? 
{ let name = table_name.table(); let provider = self.table_provider(name).await?; - provider.insert_into_table(&self.state(), &input).await?; + provider.insert_into(&self.state(), &input).await?; } self.return_empty_dataframe() } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 97344143f6f86..2188126630226 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -203,7 +203,7 @@ impl LogicalPlanBuilder { Self::scan_with_filters(table_name, table_source, projection, vec![]) } - /// Convert a logical plan into a builder with a DmlStatement + /// Convert a logical plan into a builder with a [DmlStatement] pub fn insert_into( input: LogicalPlan, table_name: impl Into, From 3cccce161faf79b03a7c944aebb87762557fcbc9 Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Thu, 9 Mar 2023 10:49:58 +0300 Subject: [PATCH 3/4] Minor comment refactor --- datafusion/core/src/datasource/memory.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 69019b0d9b565..683a49e28e25d 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -161,7 +161,7 @@ impl TableProvider for MemTable { /// # Arguments /// /// * `state` - The [SessionState] containing the context for executing the plan. - /// * `input` - The `LogicalPlan` to execute and insert. + /// * `input` - The [LogicalPlan] to execute and insert. /// /// # Returns /// From a7f79436dd21bf4472a80116979e15d8af9d6e7e Mon Sep 17 00:00:00 2001 From: metesynnada <100111937+metesynnada@users.noreply.github.com> Date: Mon, 13 Mar 2023 14:32:34 +0300 Subject: [PATCH 4/4] Revamping tests and refactor code --- datafusion/core/src/datasource/memory.rs | 161 ++++++++++-------- datafusion/core/src/execution/context.rs | 45 +---- .../src/engines/datafusion/insert.rs | 93 ---------- .../src/engines/datafusion/mod.rs | 3 - .../tests/sqllogictests/test_files/ddl.slt | 28 ++- datafusion/expr/src/logical_plan/builder.rs | 9 +- 6 files changed, 125 insertions(+), 214 deletions(-) delete mode 100644 datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 683a49e28e25d..b5fa33e38827d 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -182,26 +182,28 @@ impl TableProvider for MemTable { let table_partition_count = self.batches.read().await.len(); // Adjust the plan as necessary to match the number of partitions in the table. - let plan: Arc = - if plan_partition_count == table_partition_count { - plan - } else if table_partition_count == 1 { - // If the table has only one partition, coalesce the partitions in the plan. - Arc::new(CoalescePartitionsExec::new(plan)) - } else { - // Otherwise, repartition the plan using a round-robin partitioning scheme. - Arc::new(RepartitionExec::try_new( - plan, - Partitioning::RoundRobinBatch(table_partition_count), - )?) - }; + let plan: Arc = if plan_partition_count + == table_partition_count + || table_partition_count == 0 + { + plan + } else if table_partition_count == 1 { + // If the table has only one partition, coalesce the partitions in the plan. 
+ Arc::new(CoalescePartitionsExec::new(plan)) + } else { + // Otherwise, repartition the plan using a round-robin partitioning scheme. + Arc::new(RepartitionExec::try_new( + plan, + Partitioning::RoundRobinBatch(table_partition_count), + )?) + }; // Get the task context from the session state. let task_ctx = state.task_ctx(); // Execute the plan and collect the results into batches. let mut tasks = vec![]; - for idx in 0..table_partition_count { + for idx in 0..plan.output_partitioning().partition_count() { let stream = plan.execute(idx, task_ctx.clone())?; let handle = task::spawn(async move { stream.try_collect().await.map_err(DataFusionError::from) @@ -218,8 +220,13 @@ impl TableProvider for MemTable { // Write the results into the table. let mut all_batches = self.batches.write().await; - for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) { - batches.extend(result); + + if all_batches.is_empty() { + *all_batches = results + } else { + for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) { + batches.extend(result); + } } Ok(()) @@ -466,8 +473,19 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_single_partition() -> Result<()> { + fn create_mem_table_scan( + schema: SchemaRef, + data: Vec>, + ) -> Result> { + // Convert the table into a provider so that it can be used in a query + let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?)); + // Create a table scan logical plan to read from the table + Ok(Arc::new( + LogicalPlanBuilder::scan("source", provider, None)?.build()?, + )) + } + + fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> { // Create a new session context let session_ctx = SessionContext::new(); // Create a new schema with one field called "a" of type Int32 @@ -478,39 +496,35 @@ mod tests { schema.clone(), vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], )?; - // Create a new table with one partition that contains the batch of data + Ok((session_ctx, schema, batch)) + } + + #[tokio::test] + async fn test_insert_into_single_partition() -> Result<()> { + let (session_ctx, schema, batch) = create_initial_ctx()?; let initial_table = Arc::new(MemTable::try_new( schema.clone(), vec![vec![batch.clone()]], )?); - // Convert the table into a provider so that it can be used in a query - let provider = provider_as_source(Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone()]], - )?)); // Create a table scan logical plan to read from the table - let table_scan = - Arc::new(LogicalPlanBuilder::scan("source", provider, None)?.build()?); + let single_partition_table_scan = + create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?; // Insert the data from the provider into the table initial_table - .insert_into(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &single_partition_table_scan) .await?; // Ensure that the table now contains two batches of data in the same partition assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); // Create a new provider with 2 partitions - let multi_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + let multi_partition_table_scan = create_mem_table_scan( schema.clone(), vec![vec![batch.clone()], vec![batch]], - )?)); - // Create a new table scan logical plan to read from the provider - let table_scan = Arc::new( - LogicalPlanBuilder::scan("source", multi_partition_provider, None)? - .build()?, - ); + )?; + // Insert the data from the provider into the table. 
We expect coalescing partitions. initial_table - .insert_into(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &multi_partition_table_scan) .await?; // Ensure that the table now contains 4 batches of data with only 1 partition assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); @@ -520,60 +534,73 @@ mod tests { #[tokio::test] async fn test_insert_into_multiple_partition() -> Result<()> { - let session_ctx = SessionContext::new(); - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // create a record batch with values 1, 2, 3 in a column named "a" - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], - )?; - + let (session_ctx, schema, batch) = create_initial_ctx()?; // create a memory table with two partitions, each having one batch with the same data let initial_table = Arc::new(MemTable::try_new( schema.clone(), vec![vec![batch.clone()], vec![batch.clone()]], )?); - // create a data source provider from a memory table with a single partition - let single_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + // scan a data source provider from a memory table with a single partition + let single_partition_table_scan = create_mem_table_scan( schema.clone(), vec![vec![batch.clone(), batch.clone()]], - )?)); - - // create a logical plan for scanning the data source provider - let table_scan = Arc::new( - LogicalPlanBuilder::scan("source", single_partition_provider, None)? - .build()?, - ); + )?; - // insert the data from the 2 partitions data source provider into the initial table + // insert the data from the 1 partition data source provider into the initial table initial_table - .insert_into(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &single_partition_table_scan) .await?; - // We expect one-to-one partition mapping, each partition gets 1 batch. + // We expect round robin repartition here, each partition gets 1 batch. assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2); - // create a data source provider from a memory table with with only 1 partition - let multi_partition_provider = provider_as_source(Arc::new(MemTable::try_new( + // scan a data source provider from a memory table with 2 partition + let multi_partition_table_scan = create_mem_table_scan( schema.clone(), vec![vec![batch.clone()], vec![batch]], - )?)); - // create a logical plan for scanning the data source provider - let table_scan = Arc::new( - LogicalPlanBuilder::scan("source", multi_partition_provider, None)? - .build()?, - ); - // insert the data from the 1 partition data source provider into the initial table. - // We expect round robin repartition here. + )?; + // We expect one-to-one partition mapping. initial_table - .insert_into(&session_ctx.state(), &table_scan) + .insert_into(&session_ctx.state(), &multi_partition_table_scan) .await?; // Ensure that the table now contains 3 batches of data with 2 partitions. 
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3); assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3); Ok(()) } + + #[tokio::test] + async fn test_insert_into_empty_table() -> Result<()> { + let (session_ctx, schema, batch) = create_initial_ctx()?; + // create empty memory table + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?); + + // scan a data source provider from a memory table with a single partition + let single_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone(), batch.clone()]], + )?; + + // insert the data from the 1 partition data source provider into the initial table + initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + + // scan a data source provider from a memory table with 2 partition + let single_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?; + // We expect coalesce partitions here. + initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + // Ensure that the table now contains 3 batches of data with 2 partitions. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); + Ok(()) + } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 3c43aa7b98c5b..06f4e9270cba6 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -329,6 +329,11 @@ impl SessionContext { let name = table_name.table(); let provider = self.table_provider(name).await?; provider.insert_into(&self.state(), &input).await?; + } else { + return Err(DataFusionError::Execution(format!( + "Table '{}' does not exist", + table_name + ))); } self.return_empty_dataframe() } @@ -2728,46 +2733,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn sql_table_insert() -> Result<()> { - let session_ctx = SessionContext::with_config(SessionConfig::new()); - - session_ctx - .sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)") - .await? - .collect() - .await?; - session_ctx - .sql("CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6)") - .await? - .collect() - .await?; - - let sql = "INSERT INTO abc SELECT * FROM xyz"; - session_ctx.sql(sql).await?.collect().await?; - - let results = session_ctx - .sql("SELECT * FROM abc") - .await? - .collect() - .await?; - - let expected = vec![ - "+---------+---------+---------+", - "| column1 | column2 | column3 |", - "+---------+---------+---------+", - "| 1 | 2 | 3 |", - "| 4 | 5 | 6 |", - "| 1 | 3 | 3 |", - "| 5 | 5 | 6 |", - "+---------+---------+---------+", - ]; - - assert_batches_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn sql_create_catalog() -> Result<()> { // the information schema used to introduce cyclic Arcs diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs deleted file mode 100644 index a8fca3b16c06d..0000000000000 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::error::Result; -use crate::engines::datafusion::util::LogicTestContextProvider; -use crate::engines::output::DFOutput; -use arrow::record_batch::RecordBatch; -use datafusion::datasource::MemTable; -use datafusion::prelude::SessionContext; -use datafusion_common::{DFSchema, DataFusionError}; -use datafusion_expr::Expr as DFExpr; -use datafusion_sql::planner::{object_name_to_table_reference, PlannerContext, SqlToRel}; -use sqllogictest::DBOutput; -use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; -use std::sync::Arc; - -pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result { - // First, use sqlparser to get table name and insert values - let table_reference; - let insert_values: Vec>; - match insert_stmt { - SQLStatement::Insert { - table_name, source, .. - } => { - table_reference = object_name_to_table_reference( - table_name, - ctx.enable_ident_normalization(), - )?; - - // Todo: check columns match table schema - match *source.body { - SetExpr::Values(values) => { - insert_values = values.rows; - } - _ => { - // Directly panic: make it easy to find the location of the error. - panic!() - } - } - } - _ => unreachable!(), - } - - // Second, get batches in table and destroy the old table - let mut origin_batches = ctx.table(&table_reference).await?.collect().await?; - let schema = ctx.table_provider(&table_reference).await?.schema(); - ctx.deregister_table(&table_reference)?; - - // Third, transfer insert values to `RecordBatch` - // Attention: schema info can be ignored. (insert values don't contain schema info) - let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {}); - let num_rows = insert_values.len(); - for row in insert_values.into_iter() { - let logical_exprs = row - .into_iter() - .map(|expr| { - sql_to_rel.sql_to_expr( - expr, - &DFSchema::empty(), - &mut PlannerContext::new(), - ) - }) - .collect::, DataFusionError>>()?; - // Directly use `select` to get `RecordBatch` - let dataframe = ctx.read_empty()?; - origin_batches.extend(dataframe.select(logical_exprs)?.collect().await?) - } - - // Replace new batches schema to old schema - for batch in origin_batches.iter_mut() { - *batch = RecordBatch::try_new(schema.clone(), batch.columns().to_vec())?; - } - - // Final, create new memtable with same schema. 
- let new_provider = MemTable::try_new(schema, vec![origin_batches])?; - ctx.register_table(&table_reference, Arc::new(new_provider))?; - - Ok(DBOutput::StatementComplete(num_rows as u64)) -} diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs index 1f8f7feb36e5b..cdd6663a5e0bb 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs @@ -26,13 +26,11 @@ use create_table::create_table; use datafusion::arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; use datafusion_sql::parser::{DFParser, Statement}; -use insert::insert; use sqllogictest::DBOutput; use sqlparser::ast::Statement as SQLStatement; mod create_table; mod error; -mod insert; mod normalize; mod util; @@ -85,7 +83,6 @@ async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result return insert(ctx, statement).await, SQLStatement::CreateTable { query, constraints, diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt index 642093c364f3d..59bfc91b541f2 100644 --- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt +++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt @@ -63,7 +63,7 @@ statement error Table 'user' doesn't exist. DROP TABLE user; # Can not insert into a undefined table -statement error No table named 'user' +statement error DataFusion error: Error during planning: table 'datafusion.public.user' not found insert into user values(1, 20); ########## @@ -421,9 +421,27 @@ statement ok DROP TABLE aggregate_simple +# sql_table_insert +statement ok +CREATE TABLE abc AS VALUES (1,2,3), (4,5,6); + +statement ok +CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6); + +statement ok +INSERT INTO abc SELECT * FROM xyz; + +query III +SELECT * FROM abc +---- +1 2 3 +4 5 6 +1 3 3 +5 5 6 + # Should create an empty table statement ok -CREATE TABLE table_without_values(field1 BIGINT, field2 BIGINT); +CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); # Should skip existing table @@ -444,8 +462,8 @@ CREATE OR REPLACE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2 statement ok insert into table_without_values values (1, 2), (2, 3), (2, 4); -query II rowsort -select * from table_without_values; +query II +select * from table_without_values ---- 1 2 2 3 @@ -454,7 +472,7 @@ select * from table_without_values; # Should recreate existing table statement ok -CREATE OR REPLACE TABLE table_without_values(field1 BIGINT, field2 BIGINT); +CREATE OR REPLACE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); # Should insert into a recreated table diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2188126630226..91e7ad5e3ee62 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -203,18 +203,15 @@ impl LogicalPlanBuilder { Self::scan_with_filters(table_name, table_source, projection, vec![]) } - /// Convert a logical plan into a builder with a [DmlStatement] + /// Create a [DmlStatement] for inserting the contents of this builder into the named table pub fn insert_into( input: LogicalPlan, - table_name: impl Into, + table_name: impl Into, table_schema: &Schema, ) -> Result { - let table_name = OwnedTableReference::Bare { - table: table_name.into(), - }; let table_schema = 
table_schema.clone().to_dfschema_ref()?; Ok(Self::from(LogicalPlan::Dml(DmlStatement { - table_name, + table_name: table_name.into(), table_schema, op: WriteOp::Insert, input: Arc::new(input),
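
Putting the final state of the series together, a minimal sketch of the programmatic path (mirroring the tests above; the table and source names are illustrative). Note that when the source plan's partition count differs from the table's, `insert_into` coalesces into one partition or round-robin repartitions as needed:

use std::sync::Arc;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::{provider_as_source, MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_expr::LogicalPlanBuilder;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Target: a MemTable with a single partition holding one batch.
    let table = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?;

    // Source: a scan over another single-partition MemTable.
    let source = provider_as_source(Arc::new(MemTable::try_new(
        schema.clone(),
        vec![vec![batch]],
    )?));
    let scan = LogicalPlanBuilder::scan("source", source, None)?.build()?;

    // Execute the scan and append its output to the target table.
    table.insert_into(&ctx.state(), &scan).await?;
    Ok(())
}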
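
At the SQL level — the path now covered by the sqllogictest cases in ddl.slt rather than the removed `sql_table_insert` unit test — the same feature is driven like so (a sketch mirroring the deleted test):

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE abc AS VALUES (1,2,3), (4,5,6)").await?.collect().await?;
    ctx.sql("CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6)").await?.collect().await?;

    // Routed through LogicalPlan::Dml { op: WriteOp::Insert, .. } and then
    // TableProvider::insert_into on the registered MemTable.
    ctx.sql("INSERT INTO abc SELECT * FROM xyz").await?.collect().await?;

    // `abc` now holds its original two rows plus both rows from `xyz`.
    let results = ctx.sql("SELECT * FROM abc").await?.collect().await?;
    assert_eq!(results.iter().map(|b| b.num_rows()).sum::<usize>(), 4);
    Ok(())
}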