diff --git a/crates/executor/src/datafusion/logical_plan/merge.rs b/crates/executor/src/datafusion/logical_plan/merge.rs
index 8de396997..9417cf57d 100644
--- a/crates/executor/src/datafusion/logical_plan/merge.rs
+++ b/crates/executor/src/datafusion/logical_plan/merge.rs
@@ -25,12 +25,31 @@ impl MergeIntoCOWSink {
     pub fn new(
         input: Arc<LogicalPlan>,
         target: DataFusionTable,
+        has_insert: bool,
+        has_update: bool,
+        has_delete: bool,
     ) -> datafusion_common::Result<Self> {
-        let field = Field::new("number of rows updated", DataType::Int64, false);
-        let schema = DFSchema::new_with_metadata(
-            vec![(None, Arc::new(field))],
-            std::collections::HashMap::new(),
-        )?;
+        let inserted = Arc::new(Field::new(
+            "number of rows inserted",
+            DataType::Int64,
+            false,
+        ));
+        let updated = Arc::new(Field::new("number of rows updated", DataType::Int64, false));
+        let deleted = Arc::new(Field::new("number of rows deleted", DataType::Int64, false));
+        let mut fields: Vec<(Option<TableReference>, Arc<Field>)> = Vec::new();
+        if has_insert {
+            fields.push((None, inserted));
+        }
+        if has_update {
+            fields.push((None, updated.clone()));
+        }
+        if has_delete {
+            fields.push((None, deleted));
+        }
+        if fields.is_empty() {
+            fields.push((None, updated));
+        }
+        let schema = DFSchema::new_with_metadata(fields, std::collections::HashMap::new())?;
 
         Ok(Self {
             input,
diff --git a/crates/executor/src/datafusion/physical_plan/merge.rs b/crates/executor/src/datafusion/physical_plan/merge.rs
index 08024402f..22b63bc75 100644
--- a/crates/executor/src/datafusion/physical_plan/merge.rs
+++ b/crates/executor/src/datafusion/physical_plan/merge.rs
@@ -1,6 +1,6 @@
 use datafusion::{
     arrow::{
-        array::{Array, BooleanArray, RecordBatch, StringArray, downcast_array},
+        array::{Array, ArrayRef, BooleanArray, RecordBatch, StringArray, downcast_array},
         compute::{
             filter, filter_record_batch,
             kernels::cmp::{distinct, eq},
@@ -15,12 +15,10 @@ use datafusion_iceberg::{
     DataFusionTable, error::Error as DataFusionIcebergError, table::write_parquet_data_files,
 };
 use datafusion_physical_plan::{
-    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
-    RecordBatchStream, SendableRecordBatchStream,
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream,
+    SendableRecordBatchStream,
     coalesce_partitions::CoalescePartitionsExec,
     execution_plan::{Boundedness, EmissionType},
-    expressions::Column,
-    projection::ProjectionExec,
     stream::RecordBatchStreamAdapter,
 };
 use futures::{Stream, StreamExt};
@@ -31,6 +29,8 @@ use snafu::ResultExt;
 use std::{
     collections::{HashMap, HashSet},
     num::NonZeroUsize,
+    ops::BitAnd,
+    sync::atomic::{AtomicI64, Ordering},
     sync::{Arc, Mutex},
     task::Poll,
     thread::available_parallelism,
@@ -42,6 +42,8 @@ pub(crate) static TARGET_EXISTS_COLUMN: &str = "__target_exists";
 pub(crate) static SOURCE_EXISTS_COLUMN: &str = "__source_exists";
 pub(crate) static DATA_FILE_PATH_COLUMN: &str = "__data_file_path";
 pub(crate) static MANIFEST_FILE_PATH_COLUMN: &str = "__manifest_file_path";
+pub(crate) static MERGE_UPDATED_COLUMN: &str = "__merge_row_updated";
+pub(crate) static MERGE_INSERTED_COLUMN: &str = "__merge_row_inserted";
 static THREAD_FILE_RATIO: usize = 4;
 
 #[derive(Debug)]
@@ -137,6 +139,8 @@ impl ExecutionPlan for MergeIntoCOWSinkExec {
         let schema = Arc::new(self.schema.as_arrow().clone());
 
         let matching_files: Arc>> = Arc::default();
+        let updated_rows: Arc<AtomicI64> = Arc::new(AtomicI64::new(0));
+        let inserted_rows: Arc<AtomicI64> = Arc::new(AtomicI64::new(0));
 
         let coalesce = CoalescePartitionsExec::new(self.input.clone());
@@ -146,16 +150,24 @@ impl ExecutionPlan for MergeIntoCOWSinkExec {
             matching_files.clone(),
         ));
 
-        // Remove auxiliary columns
-        let projection =
-            ProjectionExec::try_new(schema_projection(&self.input.schema()), filtered)?;
-
-        let batches = projection.execute(partition, context.clone())?;
+        let input_batches = filtered.execute(partition, context.clone())?;
+        let count_and_project_stream = MergeCOWCountAndProjectStream::new(
+            input_batches,
+            updated_rows.clone(),
+            inserted_rows.clone(),
+        );
 
         let stream = futures::stream::once({
             let tabular = self.target.tabular.clone();
             let branch = self.target.branch.clone();
             let schema = schema.clone();
+            let updated_rows = Arc::clone(&updated_rows);
+            let inserted_rows = Arc::clone(&inserted_rows);
+            let projected_schema = count_and_project_stream.projected_schema();
+            let batches: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
+                projected_schema,
+                count_and_project_stream,
+            ));
             async move {
                 #[allow(clippy::unwrap_used)]
                 let value = tabular.read().unwrap().clone();
@@ -204,7 +216,37 @@
                 #[allow(clippy::unwrap_used)]
                 let mut lock = tabular.write().unwrap();
                 *lock = Tabular::Table(table);
-                Ok(RecordBatch::new_empty(schema))
+                // Return a one-row result for DML, so clients don't render "No data result" on success.
+                let updated = updated_rows.load(Ordering::Relaxed);
+                let inserted = inserted_rows.load(Ordering::Relaxed);
+                // MERGE DELETE is not supported yet
+                let deleted = 0i64;
+
+                let arrays = schema
+                    .fields()
+                    .iter()
+                    .map(|f| {
+                        let v = match f.name().as_str() {
+                            "number of rows inserted" => inserted,
+                            "number of rows updated" => updated,
+                            "number of rows deleted" => deleted,
+                            other => {
+                                return Err(DataFusionError::Internal(format!(
+                                    "Unexpected MERGE result column: {other}"
+                                )));
+                            }
+                        };
+                        let a: ArrayRef =
+                            Arc::new(datafusion::arrow::array::Int64Array::from(vec![v]));
+                        Ok(a)
+                    })
+                    .collect::<Result<Vec<_>, DataFusionError>>()?;
+
+                RecordBatch::try_new(schema.clone(), arrays).map_err(|e| {
+                    DataFusionError::Internal(format!(
+                        "Failed to build MERGE result record batch: {e}"
+                    ))
+                })
             }
         })
         .boxed();
@@ -213,6 +255,142 @@
     }
 }
+pin_project! {
+    /// Stream wrapper that counts per-action MERGE rows (insert/update markers) and projects away
+    /// auxiliary merge columns before writing to data files.
+    pub struct MergeCOWCountAndProjectStream {
+        projection_indices: Vec<usize>,
+        projected_schema: Arc<Schema>,
+        updated_idx: Option<usize>,
+        inserted_idx: Option<usize>,
+        updated_rows: Arc<AtomicI64>,
+        inserted_rows: Arc<AtomicI64>,
+
+        #[pin]
+        input: SendableRecordBatchStream,
+    }
+}
+
+impl MergeCOWCountAndProjectStream {
+    fn new(
+        input: SendableRecordBatchStream,
+        updated_rows: Arc<AtomicI64>,
+        inserted_rows: Arc<AtomicI64>,
+    ) -> Self {
+        let input_schema = input.schema();
+
+        let updated_idx = input_schema.index_of(MERGE_UPDATED_COLUMN).ok();
+        let inserted_idx = input_schema.index_of(MERGE_INSERTED_COLUMN).ok();
+
+        // Drop auxiliary columns so we only write table columns to parquet
+        let projection_indices: Vec<usize> = input_schema
+            .fields()
+            .iter()
+            .enumerate()
+            .filter_map(|(i, f)| {
+                let name = f.name();
+                if name != SOURCE_EXISTS_COLUMN
+                    && name != DATA_FILE_PATH_COLUMN
+                    && name != MANIFEST_FILE_PATH_COLUMN
+                    && name != MERGE_UPDATED_COLUMN
+                    && name != MERGE_INSERTED_COLUMN
+                {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        let projected_fields = projection_indices
+            .iter()
+            .map(|i| input_schema.field(*i).clone())
+            .collect::<Vec<_>>();
+
+        let projected_schema = Arc::new(Schema::new(projected_fields));
+
+        Self {
+            projection_indices,
+            projected_schema,
+            updated_idx,
+            inserted_idx,
+            updated_rows,
+            inserted_rows,
+            input,
+        }
+    }
+
+    fn projected_schema(&self) -> Arc<Schema> {
+        self.projected_schema.clone()
+    }
+}
+
+impl Stream for MergeCOWCountAndProjectStream {
+    type Item = Result<RecordBatch, DataFusionError>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let mut project = self.project();
+        match project.input.as_mut().poll_next(cx) {
+            Poll::Ready(Some(Ok(batch))) => {
+                if let Some(updated_idx) = *project.updated_idx
+                    && let Some(col) = batch.columns().get(updated_idx)
+                {
+                    let updated = downcast_array::<BooleanArray>(col.as_ref());
+                    let n = usize_to_i64_saturating(count_true_and_valid(&updated));
+                    project.updated_rows.fetch_add(n, Ordering::Relaxed);
+                }
+                if let Some(inserted_idx) = *project.inserted_idx
+                    && let Some(col) = batch.columns().get(inserted_idx)
+                {
+                    let inserted = downcast_array::<BooleanArray>(col.as_ref());
+                    let n = usize_to_i64_saturating(count_true_and_valid(&inserted));
+                    project.inserted_rows.fetch_add(n, Ordering::Relaxed);
+                }
+
+                let cols = project
+                    .projection_indices
+                    .iter()
+                    .map(|i| batch.column(*i).clone())
+                    .collect::<Vec<_>>();
+
+                let projected = RecordBatch::try_new(project.projected_schema.clone(), cols)
+                    .map_err(|e| {
+                        DataFusionError::Internal(format!(
+                            "Failed to project MERGE record batch: {e}"
+                        ))
+                    })?;
+                Poll::Ready(Some(Ok(projected)))
+            }
+            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+/// Fast count of `true` values, treating NULL as false, using Arrow bitmaps.
+#[inline]
+fn count_true_and_valid(arr: &BooleanArray) -> usize {
+    if arr.null_count() == 0 {
+        return arr.values().count_set_bits();
+    }
+
+    if let Some(nulls) = arr.logical_nulls() {
+        let valid = nulls.inner();
+        return arr.values().bitand(valid).count_set_bits();
+    }
+
+    arr.values().count_set_bits()
+}
+
+#[inline]
+fn usize_to_i64_saturating(v: usize) -> i64 {
+    i64::try_from(v).unwrap_or(i64::MAX)
+}
+
 
 #[derive(Debug)]
 struct MergeCOWFilterExec {
     input: Arc<dyn ExecutionPlan>,
@@ -672,39 +850,6 @@ fn unique_files_and_manifests(
     Ok(result)
 }
 
-/// Creates a projection expression list from a schema by filtering out auxiliary columns.
-///
-/// This function builds a vector of physical expressions and column names from the given schema,
-/// excluding internal auxiliary columns used for merge operations. The auxiliary columns that
-/// are filtered out are:
-/// - `__source_exists`: Indicates if the source record exists
-/// - `__data_file_path`: Path to the data file
-/// - `__manifest_file_path`: Path to the manifest file
-///
-/// # Arguments
-/// * `schema` - The schema to create projections from
-///
-/// # Returns
-/// * `Vec<(Arc<dyn PhysicalExpr>, String)>` - Vector of tuples containing physical expressions and column names
-fn schema_projection(schema: &Schema) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
-    schema
-        .fields()
-        .iter()
-        .enumerate()
-        .filter_map(|(i, field)| -> Option<(Arc<dyn PhysicalExpr>, String)> {
-            let name = field.name();
-            if name != SOURCE_EXISTS_COLUMN
-                && name != DATA_FILE_PATH_COLUMN
-                && name != MANIFEST_FILE_PATH_COLUMN
-            {
-                Some((Arc::new(Column::new(name, i)), name.to_owned()))
-            } else {
-                None
-            }
-        })
-        .collect()
-}
-
 #[cfg(test)]
 mod tests {
     #![allow(clippy::unwrap_used)]
diff --git a/crates/executor/src/query.rs b/crates/executor/src/query.rs
index a5420a417..44732206a 100644
--- a/crates/executor/src/query.rs
+++ b/crates/executor/src/query.rs
@@ -12,7 +12,8 @@ use super::utils::{NormalizedIdent, is_logical_plan_effectively_empty};
 use crate::datafusion::logical_plan::merge::MergeIntoCOWSink;
 use crate::datafusion::physical_optimizer::runtime_physical_optimizer_rules;
 use crate::datafusion::physical_plan::merge::{
-    DATA_FILE_PATH_COLUMN, MANIFEST_FILE_PATH_COLUMN, SOURCE_EXISTS_COLUMN, TARGET_EXISTS_COLUMN,
+    DATA_FILE_PATH_COLUMN, MANIFEST_FILE_PATH_COLUMN, MERGE_INSERTED_COLUMN, MERGE_UPDATED_COLUMN,
+    SOURCE_EXISTS_COLUMN, TARGET_EXISTS_COLUMN,
 };
 use crate::datafusion::rewriters::session_context::SessionContextExprRewriter;
 use crate::error::{OperationOn, OperationType};
@@ -63,7 +64,7 @@ use datafusion_expr::planner::ContextProvider;
 use datafusion_expr::{
     BinaryExpr, CreateMemoryTable, DdlStatement, Expr as DFExpr, ExprSchemable, Extension,
     JoinType, LogicalPlanBuilder, Operator, Projection, SubqueryAlias, TryCast, and,
-    build_join_schema, is_null, lit, when,
+    build_join_schema, is_null, lit, or, when,
 };
 use datafusion_iceberg::DataFusionTable;
 use datafusion_iceberg::table::DataFusionTableConfigBuilder;
@@ -1380,6 +1381,16 @@ impl UserQuery {
             .sql_to_expr((*on).clone(), &schema, &mut planner_context)
             .context(ex_error::DataFusionLogicalPlanMergeJoinSnafu)?;
 
+        let has_insert = clauses
+            .iter()
+            .any(|c| matches!(c.action, MergeAction::Insert(_)));
+        let has_update = clauses
+            .iter()
+            .any(|c| matches!(c.action, MergeAction::Update { .. }));
+        let has_delete = clauses
+            .iter()
+            .any(|c| matches!(c.action, MergeAction::Delete));
+
         let merge_clause_projection = merge_clause_projection(
             &sql_planner,
             &schema,
@@ -1396,8 +1407,14 @@
             .build()
             .context(ex_error::DataFusionLogicalPlanMergeJoinSnafu)?;
 
-        let merge_into_plan = MergeIntoCOWSink::new(Arc::new(join_plan), target_table)
-            .context(ex_error::DataFusionSnafu)?;
+        let merge_into_plan = MergeIntoCOWSink::new(
+            Arc::new(join_plan),
+            target_table,
+            has_insert,
+            has_update,
+            has_delete,
+        )
+        .context(ex_error::DataFusionSnafu)?;
 
         self.execute_logical_plan(LogicalPlan::Extension(Extension {
             node: Arc::new(merge_into_plan),
@@ -2952,6 +2969,8 @@ pub fn merge_clause_projection(
         HashMap::new();
     let mut inserts: HashMap> = HashMap::new();
+    let mut updated_ops: Vec<DFExpr> = Vec::new();
+    let mut inserted_ops: Vec<DFExpr> = Vec::new();
 
     let mut planner_context = datafusion::sql::planner::PlannerContext::new();
@@ -2967,6 +2986,7 @@
         };
         match merge_clause.action {
             MergeAction::Update { assignments } => {
+                updated_ops.push(op.clone());
                 for assignment in assignments {
                     match assignment.target {
                         AssignmentTarget::ColumnName(mut column) => {
@@ -2993,6 +3013,7 @@
                 }
             }
             MergeAction::Insert(insert) => {
+                inserted_ops.push(op.clone());
                 let MergeInsertKind::Values(values) = insert.kind else {
                     return Err(ex_error::OnlyMergeStatementsSnafu.build());
                 };
@@ -3027,6 +3048,19 @@
         }
     }
     let exprs = collect_merge_clause_expressions(target_schema, updates, inserts)?;
+    let mut exprs = exprs;
+
+    let merge_updated_expr = updated_ops
+        .into_iter()
+        .fold(lit(false), or)
+        .alias(MERGE_UPDATED_COLUMN);
+    exprs.push(merge_updated_expr);
+
+    let merge_inserted_expr = inserted_ops
+        .into_iter()
+        .fold(lit(false), or)
+        .alias(MERGE_INSERTED_COLUMN);
+    exprs.push(merge_inserted_expr);
 
     Ok(exprs)
 }
diff --git a/crates/executor/src/tests/s3_tables/snapshots/merge_into.snap b/crates/executor/src/tests/s3_tables/snapshots/merge_into.snap
index 0db897f40..2e0223db5 100644
--- a/crates/executor/src/tests/s3_tables/snapshots/merge_into.snap
+++ b/crates/executor/src/tests/s3_tables/snapshots/merge_into.snap
@@ -4,9 +4,10 @@ description: "MERGE INTO embucket.tests.first AS tgt USING embucket.tests.second
 ---
 Ok(
     [
-        "+------------------------+",
-        "| number of rows updated |",
-        "+------------------------+",
-        "+------------------------+",
+        "+-------------------------+------------------------+",
+        "| number of rows inserted | number of rows updated |",
+        "+-------------------------+------------------------+",
+        "| 1                       | 1                      |",
+        "+-------------------------+------------------------+",
     ],
 )
diff --git a/crates/executor/src/tests/sql/ddl/merge_into.rs b/crates/executor/src/tests/sql/ddl/merge_into.rs
index 79bc06aef..e4acaeec2 100644
--- a/crates/executor/src/tests/sql/ddl/merge_into.rs
+++ b/crates/executor/src/tests/sql/ddl/merge_into.rs
@@ -13,6 +13,19 @@ test_query!(
     snapshot_path = "merge_into"
 );
 
+test_query!(
+    merge_into_only_update_rowcount,
+    "MERGE INTO merge_target USING merge_source ON merge_target.id = merge_source.id
+    WHEN MATCHED THEN UPDATE SET merge_target.description = merge_source.description",
+    setup_queries = [
+        "CREATE TABLE embucket.public.merge_target (ID INTEGER, description VARCHAR)",
+        "CREATE TABLE embucket.public.merge_source (ID INTEGER, description VARCHAR)",
+        "INSERT INTO embucket.public.merge_target VALUES (1, 'existing row')",
+        "INSERT INTO embucket.public.merge_source VALUES (1, 'updated row')",
(1, 'updated row')", + ], + snapshot_path = "merge_into" +); + test_query!( merge_into_insert_and_update, "SELECT count(CASE WHEN description = 'updated row' THEN 1 ELSE NULL END) updated, count(CASE WHEN description = 'existing row' THEN 1 ELSE NULL END) existing FROM embucket.public.merge_target", @@ -26,6 +39,18 @@ test_query!( snapshot_path = "merge_into" ); +test_query!( + merge_into_only_insert_rowcount, + "MERGE INTO merge_target USING merge_source ON merge_target.id = merge_source.id + WHEN NOT MATCHED THEN INSERT (id, description) VALUES (merge_source.id, merge_source.description)", + setup_queries = [ + "CREATE TABLE embucket.public.merge_target (ID INTEGER, description VARCHAR)", + "CREATE TABLE embucket.public.merge_source (ID INTEGER, description VARCHAR)", + "INSERT INTO embucket.public.merge_source VALUES (1, 'new row'), (2, 'new row')", + ], + snapshot_path = "merge_into" +); + test_query!( merge_into_empty_source, "SELECT count(CASE WHEN description = 'updated row' THEN 1 ELSE NULL END) updated, count(CASE WHEN description = 'existing row' THEN 1 ELSE NULL END) existing FROM embucket.public.merge_target", diff --git a/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_insert_rowcount.snap b/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_insert_rowcount.snap new file mode 100644 index 000000000..9aca9e55a --- /dev/null +++ b/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_insert_rowcount.snap @@ -0,0 +1,16 @@ +--- +source: crates/executor/src/tests/sql/ddl/merge_into.rs +description: "\"MERGE INTO merge_target USING merge_source ON merge_target.id = merge_source.id\\n WHEN NOT MATCHED THEN INSERT (id, description) VALUES (merge_source.id, merge_source.description)\"" +info: "Setup queries: CREATE TABLE embucket.public.merge_target (ID INTEGER, description VARCHAR); CREATE TABLE embucket.public.merge_source (ID INTEGER, description VARCHAR); INSERT INTO embucket.public.merge_source VALUES (1, 'new row'), (2, 'new row')" +--- +Ok( + [ + "+-------------------------+", + "| number of rows inserted |", + "+-------------------------+", + "| 2 |", + "+-------------------------+", + ], +) + + diff --git a/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_update_rowcount.snap b/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_update_rowcount.snap new file mode 100644 index 000000000..b99ed0f29 --- /dev/null +++ b/crates/executor/src/tests/sql/ddl/snapshots/merge_into/query_merge_into_only_update_rowcount.snap @@ -0,0 +1,16 @@ +--- +source: crates/executor/src/tests/sql/ddl/merge_into.rs +description: "\"MERGE INTO merge_target USING merge_source ON merge_target.id = merge_source.id\\n WHEN MATCHED THEN UPDATE SET merge_target.description = merge_source.description\"" +info: "Setup queries: CREATE TABLE embucket.public.merge_target (ID INTEGER, description VARCHAR); CREATE TABLE embucket.public.merge_source (ID INTEGER, description VARCHAR); INSERT INTO embucket.public.merge_target VALUES (1, 'existing row'); INSERT INTO embucket.public.merge_source VALUES (1, 'updated row')" +--- +Ok( + [ + "+------------------------+", + "| number of rows updated |", + "+------------------------+", + "| 1 |", + "+------------------------+", + ], +) + +