diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 8d9736eb640ad..64937eafa158d 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -20,15 +20,17 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; +use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::sync::Arc; use crate::error::{unqualified_field_not_found, DataFusionError, Result, SchemaError}; -use crate::{field_not_found, Column, OwnedTableReference, TableReference}; +use crate::{ + field_not_found, Column, FunctionalDependencies, OwnedTableReference, TableReference, +}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; -use std::fmt::{Display, Formatter}; /// A reference-counted reference to a `DFSchema`. pub type DFSchemaRef = Arc; @@ -40,6 +42,8 @@ pub struct DFSchema { fields: Vec, /// Additional metadata in form of key value pairs metadata: HashMap, + /// Stores functional dependencies in the schema. + functional_dependencies: FunctionalDependencies, } impl DFSchema { @@ -48,6 +52,7 @@ impl DFSchema { Self { fields: vec![], metadata: HashMap::new(), + functional_dependencies: FunctionalDependencies::empty(), } } @@ -97,7 +102,11 @@ impl DFSchema { )); } } - Ok(Self { fields, metadata }) + Ok(Self { + fields, + metadata, + functional_dependencies: FunctionalDependencies::empty(), + }) } /// Create a `DFSchema` from an Arrow schema and a given qualifier @@ -116,6 +125,15 @@ impl DFSchema { ) } + /// Assigns functional dependencies. + pub fn with_functional_dependencies( + mut self, + functional_dependencies: FunctionalDependencies, + ) -> Self { + self.functional_dependencies = functional_dependencies; + self + } + /// Create a new schema that contains the fields from this schema followed by the fields /// from the supplied schema. An error will be returned if there are duplicate field names. pub fn join(&self, schema: &DFSchema) -> Result { @@ -471,6 +489,11 @@ impl DFSchema { pub fn metadata(&self) -> &HashMap { &self.metadata } + + /// Get functional dependencies + pub fn functional_dependencies(&self) -> &FunctionalDependencies { + &self.functional_dependencies + } } impl From for Schema { diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs new file mode 100644 index 0000000000000..b7a8e768ec78b --- /dev/null +++ b/datafusion/common/src/functional_dependencies.rs @@ -0,0 +1,520 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! FunctionalDependencies keeps track of functional dependencies +//! inside DFSchema. 
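As a rough sketch of how the new types in this file compose (an illustrative example, not part of the diff, using only the APIs defined below and re-exported from `datafusion_common` later in this patch): a PRIMARY KEY on column 0 of a three-column table becomes a non-nullable, single-mode dependence from index 0 to every column, and projecting the schema simply remaps those indices.

use datafusion_common::{Dependency, FunctionalDependence, FunctionalDependencies};

fn main() {
    // PRIMARY KEY on column 0 of a 3-column table: the key determines every
    // column, cannot be NULL, and each key value occurs at most once.
    let pk = FunctionalDependence::new(vec![0], vec![0, 1, 2], false)
        .with_mode(Dependency::Single);
    let deps = FunctionalDependencies::new(vec![pk]);

    // Projecting columns [0, 2] keeps the dependence and remaps indices:
    // in the two-column output, [0] -> [0, 1, 2] becomes [0] -> [0, 1].
    let projected = deps.project_functional_dependencies(&[0, 2], 2);
    let expected = FunctionalDependencies::new(vec![FunctionalDependence::new(
        vec![0],
        vec![0, 1],
        false,
    )
    .with_mode(Dependency::Single)]);
    assert_eq!(projected, expected);
}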
+ +use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; +use sqlparser::ast::TableConstraint; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; + +/// This object defines a constraint on a table. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Constraint { + /// Columns with the given indices form a composite primary key (they are + /// jointly unique and not nullable): + PrimaryKey(Vec), + /// Columns with the given indices form a composite unique key: + Unique(Vec), +} + +/// This object encapsulates a list of functional constraints: +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Constraints { + inner: Vec, +} + +impl Constraints { + /// Create empty constraints + pub fn empty() -> Self { + Constraints::new(vec![]) + } + + // This method is private. + // Outside callers can either create empty constraint using `Constraints::empty` API. + // or create constraint from table constraints using `Constraints::new_from_table_constraints` API. + fn new(constraints: Vec) -> Self { + Self { inner: constraints } + } + + /// Convert each `TableConstraint` to corresponding `Constraint` + pub fn new_from_table_constraints( + constraints: &[TableConstraint], + df_schema: &DFSchemaRef, + ) -> Result { + let constraints = constraints + .iter() + .map(|c: &TableConstraint| match c { + TableConstraint::Unique { + columns, + is_primary, + .. + } => { + // Get primary key and/or unique indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = df_schema + .fields() + .iter() + .position(|item| { + item.qualified_name() == pk.value.clone() + }) + .ok_or_else(|| { + DataFusionError::Execution( + "Primary key doesn't exist".to_string(), + ) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(if *is_primary { + Constraint::PrimaryKey(indices) + } else { + Constraint::Unique(indices) + }) + } + TableConstraint::ForeignKey { .. } => Err(DataFusionError::Plan( + "Foreign key constraints are not currently supported".to_string(), + )), + TableConstraint::Check { .. } => Err(DataFusionError::Plan( + "Check constraints are not currently supported".to_string(), + )), + TableConstraint::Index { .. } => Err(DataFusionError::Plan( + "Indexes are not currently supported".to_string(), + )), + TableConstraint::FulltextOrSpatial { .. } => Err(DataFusionError::Plan( + "Indexes are not currently supported".to_string(), + )), + }) + .collect::>>()?; + Ok(Constraints::new(constraints)) + } + + /// Check whether constraints is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +impl Display for Constraints { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let pk: Vec = self.inner.iter().map(|c| format!("{:?}", c)).collect(); + let pk = pk.join(", "); + if !pk.is_empty() { + write!(f, " constraints=[{pk}]") + } else { + write!(f, "") + } + } +} + +/// This object defines a functional dependence in the schema. A functional +/// dependence defines a relationship between determinant keys and dependent +/// columns. A determinant key is a column, or a set of columns, whose value +/// uniquely determines values of some other (dependent) columns. If two rows +/// have the same determinant key, dependent columns in these rows are +/// necessarily the same. If the determinant key is unique, the set of +/// dependent columns is equal to the entire schema and the determinant key can +/// serve as a primary key. 
Note that a primary key may "downgrade" into a +/// determinant key due to an operation such as a join, and this object is +/// used to track dependence relationships in such cases. For more information +/// on functional dependencies, see: +/// +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionalDependence { + // Column indices of the (possibly composite) determinant key: + pub source_indices: Vec, + // Column indices of dependent column(s): + pub target_indices: Vec, + /// Flag indicating whether one of the `source_indices` can receive NULL values. + /// For a data source, if the constraint in question is `Constraint::Unique`, + /// this flag is `true`. If the constraint in question is `Constraint::PrimaryKey`, + /// this flag is `false`. + /// Note that as the schema changes between different stages in a plan, + /// such as after LEFT JOIN or RIGHT JOIN operations, this property may + /// change. + pub nullable: bool, + // The functional dependency mode: + pub mode: Dependency, +} + +/// Describes functional dependency mode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Dependency { + Single, // A determinant key may occur only once. + Multi, // A determinant key may occur multiple times (in multiple rows). +} + +impl FunctionalDependence { + // Creates a new functional dependence. + pub fn new( + source_indices: Vec, + target_indices: Vec, + nullable: bool, + ) -> Self { + Self { + source_indices, + target_indices, + nullable, + // Start with the least restrictive mode by default: + mode: Dependency::Multi, + } + } + + pub fn with_mode(mut self, mode: Dependency) -> Self { + self.mode = mode; + self + } +} + +/// This object encapsulates all functional dependencies in a given relation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionalDependencies { + deps: Vec, +} + +impl FunctionalDependencies { + /// Creates an empty `FunctionalDependencies` object. + pub fn empty() -> Self { + Self { deps: vec![] } + } + + /// Creates a new `FunctionalDependencies` object from a vector of + /// `FunctionalDependence` objects. + pub fn new(dependencies: Vec) -> Self { + Self { deps: dependencies } + } + + /// Creates a new `FunctionalDependencies` object from the given constraints. + pub fn new_from_constraints( + constraints: Option<&Constraints>, + n_field: usize, + ) -> Self { + if let Some(Constraints { inner: constraints }) = constraints { + // Construct dependency objects based on each individual constraint: + let dependencies = constraints + .iter() + .map(|constraint| { + // All the field indices are associated with the whole table + // since we are dealing with table level constraints: + let dependency = match constraint { + Constraint::PrimaryKey(indices) => FunctionalDependence::new( + indices.to_vec(), + (0..n_field).collect::>(), + false, + ), + Constraint::Unique(indices) => FunctionalDependence::new( + indices.to_vec(), + (0..n_field).collect::>(), + true, + ), + }; + // As primary keys are guaranteed to be unique, set the + // functional dependency mode to `Dependency::Single`: + dependency.with_mode(Dependency::Single) + }) + .collect::>(); + Self::new(dependencies) + } else { + // There is no constraint, return an empty object: + Self::empty() + } + } + + pub fn with_dependency(mut self, mode: Dependency) -> Self { + self.deps.iter_mut().for_each(|item| item.mode = mode); + self + } + + /// Merges the given functional dependencies with these. 
+ pub fn extend(&mut self, other: FunctionalDependencies) { + self.deps.extend(other.deps); + } + + /// Adds the `offset` value to `source_indices` and `target_indices` for + /// each functional dependency. + pub fn add_offset(&mut self, offset: usize) { + self.deps.iter_mut().for_each( + |FunctionalDependence { + source_indices, + target_indices, + .. + }| { + *source_indices = add_offset_to_vec(source_indices, offset); + *target_indices = add_offset_to_vec(target_indices, offset); + }, + ) + } + + /// Updates `source_indices` and `target_indices` of each functional + /// dependence using the index mapping given in `proj_indices`. + /// + /// Assume that `proj_indices` is \[2, 5, 8\] and we have a functional + /// dependence \[5\] (`source_indices`) -> \[5, 8\] (`target_indices`). + /// In the updated schema, fields at indices \[2, 5, 8\] will transform + /// to \[0, 1, 2\]. Therefore, the resulting functional dependence will + /// be \[1\] -> \[1, 2\]. + pub fn project_functional_dependencies( + &self, + proj_indices: &[usize], + // The argument `n_out` denotes the schema field length, which is needed + // to correctly associate a `Single`-mode dependence with the whole table. + n_out: usize, + ) -> FunctionalDependencies { + let mut projected_func_dependencies = vec![]; + for FunctionalDependence { + source_indices, + target_indices, + nullable, + mode, + } in &self.deps + { + let new_source_indices = + update_elements_with_matching_indices(source_indices, proj_indices); + let new_target_indices = if *mode == Dependency::Single { + // Associate with all of the fields in the schema: + (0..n_out).collect() + } else { + // Update associations according to projection: + update_elements_with_matching_indices(target_indices, proj_indices) + }; + // All of the composite indices should still be valid after projection; + // otherwise, functional dependency cannot be propagated. + if new_source_indices.len() == source_indices.len() { + let new_func_dependence = FunctionalDependence::new( + new_source_indices, + new_target_indices, + *nullable, + ) + .with_mode(*mode); + projected_func_dependencies.push(new_func_dependence); + } + } + FunctionalDependencies::new(projected_func_dependencies) + } + + /// This function joins this set of functional dependencies with the `other` + /// according to the given `join_type`. 
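For intuition, a hedged sketch (not part of the diff) of what the `join` method defined just below yields for a LEFT JOIN of two two-column relations, each with a primary key on its first column: the left key survives but may now repeat (Multi mode), while the right key is shifted by the left column count and downgraded to a nullable dependence.

use datafusion_common::{Dependency, FunctionalDependence, FunctionalDependencies, JoinType};

fn main() {
    // Two-column relations, each with a primary key on column 0:
    let pk = |n: usize| {
        FunctionalDependencies::new(vec![FunctionalDependence::new(
            vec![0],
            (0..n).collect(),
            false,
        )
        .with_mode(Dependency::Single)])
    };
    let left = pk(2);
    let right = pk(2);

    // LEFT JOIN: left keys may repeat after the join, right-side keys are
    // offset by the left column count and may gain NULLs (become nullable).
    let joined = left.join(&right, &JoinType::Left, 2);
    let expected = FunctionalDependencies::new(vec![
        FunctionalDependence::new(vec![0], vec![0, 1], false)
            .with_mode(Dependency::Multi),
        FunctionalDependence::new(vec![2], vec![2, 3], true)
            .with_mode(Dependency::Multi),
    ]);
    assert_eq!(joined, expected);
}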
+ pub fn join( + &self, + other: &FunctionalDependencies, + join_type: &JoinType, + left_cols_len: usize, + ) -> FunctionalDependencies { + // Get mutable copies of left and right side dependencies: + let mut right_func_dependencies = other.clone(); + let mut left_func_dependencies = self.clone(); + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right => { + // Add offset to right schema: + right_func_dependencies.add_offset(left_cols_len); + + // Result may have multiple values, update the dependency mode: + left_func_dependencies = + left_func_dependencies.with_dependency(Dependency::Multi); + right_func_dependencies = + right_func_dependencies.with_dependency(Dependency::Multi); + + if *join_type == JoinType::Left { + // Downgrade the right side, since it may have additional NULL values: + right_func_dependencies.downgrade_dependencies(); + } else if *join_type == JoinType::Right { + // Downgrade the left side, since it may have additional NULL values: + left_func_dependencies.downgrade_dependencies(); + } + // Combine left and right functional dependencies: + left_func_dependencies.extend(right_func_dependencies); + left_func_dependencies + } + JoinType::LeftSemi | JoinType::LeftAnti => { + // These joins preserve functional dependencies of the left side: + left_func_dependencies + } + JoinType::RightSemi | JoinType::RightAnti => { + // These joins preserve functional dependencies of the right side: + right_func_dependencies + } + JoinType::Full => { + // All of the functional dependencies are lost in a FULL join: + FunctionalDependencies::empty() + } + } + } + + /// This function downgrades a functional dependency when nullability becomes + /// a possibility: + /// - If the dependency in question is UNIQUE (i.e. nullable), a new null value + /// invalidates the dependency. + /// - If the dependency in question is PRIMARY KEY (i.e. not nullable), a new + /// null value turns it into UNIQUE mode. + fn downgrade_dependencies(&mut self) { + // Delete nullable dependencies, since they are no longer valid: + self.deps.retain(|item| !item.nullable); + self.deps.iter_mut().for_each(|item| item.nullable = true); + } + + /// This function ensures that functional dependencies involving uniquely + /// occuring determinant keys cover their entire table in terms of + /// dependent columns. + pub fn extend_target_indices(&mut self, n_out: usize) { + self.deps.iter_mut().for_each( + |FunctionalDependence { + mode, + target_indices, + .. + }| { + // If unique, cover the whole table: + if *mode == Dependency::Single { + *target_indices = (0..n_out).collect::>(); + } + }, + ) + } +} + +/// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression. +pub fn aggregate_functional_dependencies( + aggr_input_schema: &DFSchema, + group_by_expr_names: &[String], + aggr_schema: &DFSchema, +) -> FunctionalDependencies { + let mut aggregate_func_dependencies = vec![]; + let aggr_input_fields = aggr_input_schema.fields(); + let aggr_fields = aggr_schema.fields(); + // Association covers the whole table: + let target_indices = (0..aggr_schema.fields().len()).collect::>(); + // Get functional dependencies of the schema: + let func_dependencies = aggr_input_schema.functional_dependencies(); + for FunctionalDependence { + source_indices, + nullable, + mode, + .. 
+ } in &func_dependencies.deps + { + // Keep source indices in a `HashSet` to prevent duplicate entries: + let mut new_source_indices = HashSet::new(); + let source_field_names = source_indices + .iter() + .map(|&idx| aggr_input_fields[idx].qualified_name()) + .collect::>(); + for (idx, group_by_expr_name) in group_by_expr_names.iter().enumerate() { + // When one of the input determinant expressions matches with + // the GROUP BY expression, add the index of the GROUP BY + // expression as a new determinant key: + if source_field_names.contains(group_by_expr_name) { + new_source_indices.insert(idx); + } + } + // All of the composite indices occur in the GROUP BY expression: + if new_source_indices.len() == source_indices.len() { + aggregate_func_dependencies.push( + FunctionalDependence::new( + new_source_indices.into_iter().collect(), + target_indices.clone(), + *nullable, + ) + // input uniqueness stays the same when GROUP BY matches with input functional dependence determinants + .with_mode(*mode), + ); + } + } + // If we have a single GROUP BY key, we can guarantee uniqueness after + // aggregation: + if group_by_expr_names.len() == 1 { + // If `source_indices` contain 0, delete this functional dependency + // as it will be added anyway with mode `Dependency::Single`: + if let Some(idx) = aggregate_func_dependencies + .iter() + .position(|item| item.source_indices.contains(&0)) + { + // Delete the functional dependency that contains zeroth idx: + aggregate_func_dependencies.remove(idx); + } + // Add a new functional dependency associated with the whole table: + aggregate_func_dependencies.push( + // Use nullable property of the group by expression + FunctionalDependence::new( + vec![0], + target_indices, + aggr_fields[0].is_nullable(), + ) + .with_mode(Dependency::Single), + ); + } + FunctionalDependencies::new(aggregate_func_dependencies) +} + +/// Returns target indices, for the determinant keys that are inside +/// group by expressions. +pub fn get_target_functional_dependencies( + schema: &DFSchema, + group_by_expr_names: &[String], +) -> Option> { + let mut combined_target_indices = HashSet::new(); + let dependencies = schema.functional_dependencies(); + let field_names = schema + .fields() + .iter() + .map(|item| item.qualified_name()) + .collect::>(); + for FunctionalDependence { + source_indices, + target_indices, + .. + } in &dependencies.deps + { + let source_key_names = source_indices + .iter() + .map(|id_key_idx| field_names[*id_key_idx].clone()) + .collect::>(); + // If the GROUP BY expression contains a determinant key, we can use + // the associated fields after aggregation even if they are not part + // of the GROUP BY expression. + if source_key_names + .iter() + .all(|source_key_name| group_by_expr_names.contains(source_key_name)) + { + combined_target_indices.extend(target_indices.iter()); + } + } + (!combined_target_indices.is_empty()) + .then_some(combined_target_indices.iter().cloned().collect::>()) +} + +/// Updates entries inside the `entries` vector with their corresponding +/// indices inside the `proj_indices` vector. +fn update_elements_with_matching_indices( + entries: &[usize], + proj_indices: &[usize], +) -> Vec { + entries + .iter() + .filter_map(|val| proj_indices.iter().position(|proj_idx| proj_idx == val)) + .collect() +} + +/// Adds `offset` value to each entry inside `in_data`. 
+fn add_offset_to_vec>( + in_data: &[T], + offset: T, +) -> Vec { + in_data.iter().map(|&item| item + offset).collect() +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 63b4024579f1e..7a46f28b50736 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -23,6 +23,7 @@ pub mod delta; mod dfschema; pub mod display; mod error; +mod functional_dependencies; mod join_type; pub mod parsers; #[cfg(feature = "pyarrow")] @@ -41,6 +42,10 @@ pub use error::{ field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, SharedResult, }; +pub use functional_dependencies::{ + aggregate_functional_dependencies, get_target_functional_dependencies, Constraints, + Dependency, FunctionalDependence, FunctionalDependencies, +}; pub use join_type::{JoinConstraint, JoinType}; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::{OwnedSchemaReference, SchemaReference}; diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index c6fd87e7f18b3..bd3e832804cbc 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -17,12 +17,14 @@ //! Default TableSource implementation used in DataFusion physical plans +use std::any::Any; +use std::sync::Arc; + use crate::datasource::TableProvider; + use arrow::datatypes::SchemaRef; -use datafusion_common::DataFusionError; +use datafusion_common::{Constraints, DataFusionError}; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource}; -use std::any::Any; -use std::sync::Arc; /// DataFusion default table source, wrapping TableProvider /// @@ -52,6 +54,11 @@ impl TableSource for DefaultTableSource { self.table_provider.schema() } + /// Get a reference to applicable constraints, if any exists. + fn constraints(&self) -> Option<&Constraints> { + self.table_provider.constraints() + } + /// Tests whether the table provider can make use of any or all filter expressions /// to optimise data retrieval. 
fn supports_filters_pushdown( diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 967eb988eb593..4b6653c6889f3 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -26,7 +26,7 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion_common::SchemaExt; +use datafusion_common::{Constraints, SchemaExt}; use datafusion_execution::TaskContext; use tokio::sync::RwLock; use tokio::task::JoinSet; @@ -52,6 +52,7 @@ pub type PartitionData = Arc>>; pub struct MemTable { schema: SchemaRef, pub(crate) batches: Vec, + constraints: Option, } impl MemTable { @@ -76,9 +77,18 @@ impl MemTable { .into_iter() .map(|e| Arc::new(RwLock::new(e))) .collect::>(), + constraints: None, }) } + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + if !constraints.is_empty() { + self.constraints = Some(constraints); + } + self + } + /// Create a mem table by reading from another data source pub async fn load( t: Arc, @@ -153,6 +163,10 @@ impl TableProvider for MemTable { self.schema.clone() } + fn constraints(&self) -> Option<&Constraints> { + self.constraints.as_ref() + } + fn table_type(&self) -> TableType { TableType::Base } diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 11f30f33d1399..9c97935105890 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::{DataFusionError, Statistics}; +use datafusion_common::{Constraints, DataFusionError, Statistics}; use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; @@ -41,6 +41,11 @@ pub trait TableProvider: Sync + Send { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; + /// Get a reference to the constraints of the table. + fn constraints(&self) -> Option<&Constraints> { + None + } + /// Get the type of this table for metadata/catalog purposes. 
fn table_type(&self) -> TableType; diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 811c2ec7656ad..9a518940e27c8 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -470,19 +470,12 @@ impl SessionContext { input, if_not_exists, or_replace, - primary_key, + constraints, } = cmd; - if !primary_key.is_empty() { - Err(DataFusionError::Execution( - "Primary keys on MemoryTables are not currently supported!".to_string(), - ))?; - } - let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); let input = self.state().optimize(&input)?; let table = self.table(&name).await; - match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { @@ -500,11 +493,15 @@ impl SessionContext { "'IF NOT EXISTS' cannot coexist with 'REPLACE'".to_string(), )), (_, _, Err(_)) => { - let schema = Arc::new(input.schema().as_ref().into()); + let df_schema = input.schema(); + let schema = Arc::new(df_schema.as_ref().into()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; - let table = Arc::new(MemTable::try_new(schema, batches)?); + let table = Arc::new( + // pass constraints to the mem table. + MemTable::try_new(schema, batches)?.with_constraints(constraints), + ); self.register_table(&name, table)?; self.return_empty_dataframe() diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 913deebddfdbe..176f3cbe50013 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2186,7 +2186,7 @@ mod tests { dict_id: 0, \ dict_is_ordered: false, \ metadata: {} } }\ - ], metadata: {} }, \ + ], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ ExecutionPlan schema: Schema { fields: [\ Field { \ name: \"b\", \ @@ -2410,6 +2410,7 @@ mod tests { ); } } + struct ErrorExtensionPlanner {} #[async_trait] diff --git a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt index de57956f0ea82..b2677679c8b4e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ b/datafusion/core/tests/sqllogictests/test_files/groupby.slt @@ -2961,3 +2961,390 @@ ORDER BY s.sn 1 FRA 3 2022-01-02T12:00:00 EUR 200 0 GRC 4 2022-01-03T10:00:00 EUR 80 1 TUR 4 2022-01-03T10:00:00 TRY 100 + +# create a table for testing +statement ok +CREATE TABLE sales_global_with_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# create a table for testing, where primary key is composite +statement ok +CREATE TABLE sales_global_with_composite_pk (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + primary key(sn, ts) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 
200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# create a table for testing, where sn is unique key +statement ok +CREATE TABLE sales_global_with_unique (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT, + unique(sn) + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), + (1, 'TUR', NULL, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0) + +# when group by contains primary key expression +# we can use all the expressions in the table during selection +# (not just group by expressions + aggregation result) +query TT +EXPLAIN SELECT s.sn, s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +logical_plan +Sort: s.sn ASC NULLS LAST +--Projection: s.sn, s.amount, Int64(2) * CAST(s.sn AS Int64) +----Aggregate: groupBy=[[s.sn, s.amount]], aggr=[[]] +------SubqueryAlias: s +--------TableScan: sales_global_with_pk projection=[sn, amount] +physical_plan +SortPreservingMergeExec: [sn@0 ASC NULLS LAST] +--SortExec: expr=[sn@0 ASC NULLS LAST] +----ProjectionExec: expr=[sn@0 as sn, amount@1 as amount, 2 * CAST(sn@0 AS Int64) as Int64(2) * s.sn] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] +--------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] + +query IRI +SELECT s.sn, s.amount, 2*s.sn + FROM sales_global_with_pk AS s + GROUP BY sn + ORDER BY sn +---- +0 30 0 +1 50 2 +2 75 4 +3 200 6 +4 100 8 + +# Join should propagate primary key successfully +query TT +EXPLAIN SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn +---- +logical_plan +Sort: r.sn ASC NULLS LAST +--Projection: r.sn, SUM(l.amount), r.amount +----Aggregate: groupBy=[[r.sn, r.amount]], aggr=[[SUM(l.amount)]] +------Projection: l.amount, r.sn, r.amount +--------Inner Join: Filter: l.sn >= r.sn +----------SubqueryAlias: l +------------TableScan: sales_global_with_pk projection=[sn, amount] +----------SubqueryAlias: r +------------TableScan: sales_global_with_pk projection=[sn, amount] +physical_plan +SortPreservingMergeExec: [sn@0 ASC NULLS LAST] +--SortExec: expr=[sn@0 ASC NULLS LAST] +----ProjectionExec: expr=[sn@0 as sn, SUM(l.amount)@2 as SUM(l.amount), amount@1 as amount] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[SUM(l.amount)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[SUM(l.amount)] +--------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] +----------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1 +------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +------------------CoalescePartitionsExec +--------------------MemoryExec: partitions=8, partition_sizes=[1, 
0, 0, 0, 0, 0, 0, 0] + +query IRR +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn +---- +0 455 30 +1 425 50 +2 375 75 +3 300 200 +4 100 100 + +# when primary key consists of composite columns +# to associate it with other fields, aggregate should contain all the composite columns +query IRR +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_composite_pk AS l + JOIN sales_global_with_composite_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn, r.ts + ORDER BY r.sn +---- +0 455 30 +1 425 50 +2 375 75 +3 300 200 +4 100 100 + +# when primary key consists of composite columns +# to associate it with other fields, aggregate should contain all the composite columns +# if any of the composite column is missing, we cannot use associated indices, inside select expression +# below query should fail +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(l.amount\) +SELECT r.sn, SUM(l.amount), r.amount + FROM sales_global_with_composite_pk AS l + JOIN sales_global_with_composite_pk AS r + ON l.sn >= r.sn + GROUP BY r.sn + ORDER BY r.sn + +# left join should propagate right side constraint, +# if right side is a primary key (unique and doesn't contain null) +query IRR +SELECT r.sn, r.amount, SUM(r.amount) + FROM (SELECT * + FROM sales_global_with_pk as l + LEFT JOIN sales_global_with_pk as r + ON l.amount >= r.amount + 10) + GROUP BY r.sn +ORDER BY r.sn +---- +0 30 120 +1 50 150 +2 75 150 +4 100 100 +NULL NULL NULL + +# left join shouldn't propagate right side constraint, +# if right side is a unique key (unique and can contain null) +# Please note that, above query and this one is same except the constraint in the table. +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(r.amount\) +SELECT r.sn, r.amount, SUM(r.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY r.sn +ORDER BY r.sn + +# left semi join should propagate constraint of left side as is. +query IRR +SELECT l.sn, l.amount, SUM(l.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT SEMI JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY l.sn +ORDER BY l.sn +---- +1 50 50 +2 75 75 +3 200 200 +4 100 100 +NULL 100 100 + +# Similarly, left anti join should propagate constraint of left side as is. 
+query IRR +SELECT l.sn, l.amount, SUM(l.amount) + FROM (SELECT * + FROM sales_global_with_unique as l + LEFT ANTI JOIN sales_global_with_unique as r + ON l.amount >= r.amount + 10) + GROUP BY l.sn +ORDER BY l.sn +---- +0 30 30 + + +# primary key should be aware from which columns it is associated +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, SUM\(l.amount\) +SELECT l.sn, r.sn, SUM(l.amount), r.amount + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r + ON l.sn >= r.sn + GROUP BY l.sn + ORDER BY l.sn + +# window should propagate primary key successfully +query TT +EXPLAIN SELECT * + FROM(SELECT *, SUM(l.amount) OVER(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum_amount + FROM sales_global_with_pk AS l + ) as l + GROUP BY l.sn + ORDER BY l.sn +---- +logical_plan +Sort: l.sn ASC NULLS LAST +--Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, l.sum_amount +----Aggregate: groupBy=[[l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, l.sum_amount]], aggr=[[]] +------SubqueryAlias: l +--------Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum_amount +----------WindowAggr: windowExpr=[[SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +------------SubqueryAlias: l +--------------TableScan: sales_global_with_pk projection=[zip_code, country, sn, ts, currency, amount] +physical_plan +SortPreservingMergeExec: [sn@2 ASC NULLS LAST] +--SortExec: expr=[sn@2 ASC NULLS LAST] +----ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, zip_code@1 as zip_code, country@2 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([sn@0, zip_code@1, country@2, ts@3, currency@4, amount@5, sum_amount@6], 8), input_partitions=1 +------------AggregateExec: mode=Partial, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] +--------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] +----------------BoundedWindowAggExec: wdw=[SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +------------------CoalescePartitionsExec +--------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] + +query ITIPTRR +SELECT * + FROM(SELECT *, SUM(l.amount) OVER(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum_amount + FROM sales_global_with_pk AS l + ) as l + GROUP BY l.sn + ORDER BY l.sn +---- +0 GRC 0 2022-01-01T06:00:00 EUR 30 80 +1 FRA 1 2022-01-01T08:00:00 EUR 50 155 +1 TUR 2 2022-01-01T11:30:00 TRY 75 325 +1 FRA 3 2022-01-02T12:00:00 EUR 200 375 +1 TUR 4 2022-01-03T10:00:00 TRY 100 300 + +# join should 
propagate primary key correctly +query IRP +SELECT l.sn, SUM(l.amount), l.ts +FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn) +GROUP BY l.sn +ORDER BY l.sn +---- +0 30 2022-01-01T06:00:00 +1 100 2022-01-01T08:00:00 +2 225 2022-01-01T11:30:00 +3 800 2022-01-02T12:00:00 +4 500 2022-01-03T10:00:00 + +# Projection propagates primary keys correctly +# (we can use r.ts at the final projection, because it +# is associated with primary key r.sn) +query IRP +SELECT r.sn, SUM(r.amount), r.ts +FROM + (SELECT r.ts, r.sn, r.amount + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn)) +GROUP BY r.sn +ORDER BY r.sn +---- +0 150 2022-01-01T06:00:00 +1 200 2022-01-01T08:00:00 +2 225 2022-01-01T11:30:00 +3 400 2022-01-02T12:00:00 +4 100 2022-01-03T10:00:00 + +# after join, new window expressions shouldn't be associated with primary keys +statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, SUM\(r.amount\) +SELECT r.sn, SUM(r.amount), rn1 +FROM + (SELECT r.ts, r.sn, r.amount, + ROW_NUMBER() OVER() AS rn1 + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn)) +GROUP BY r.sn + +# aggregate should propagate primary key successfully +query IPR +SELECT sn, ts, sum1 +FROM ( + SELECT ts, sn, SUM(amount) as sum1 + FROM sales_global_with_pk + GROUP BY sn) +GROUP BY sn +ORDER BY sn +---- +0 2022-01-01T06:00:00 30 +1 2022-01-01T08:00:00 50 +2 2022-01-01T11:30:00 75 +3 2022-01-02T12:00:00 200 +4 2022-01-03T10:00:00 100 + +# aggregate should be able to introduce functional dependence +# (when group by contains single expression, group by expression +# becomes determinant, after aggregation; since we are sure that +# it will consist of unique values.) +# please note that ts is not primary key, still +# we can use sum1, after outer aggregation because +# after inner aggregation, ts becomes determinant +# of functional dependence. +query PR +SELECT ts, sum1 +FROM ( + SELECT ts, SUM(amount) as sum1 + FROM sales_global_with_pk + GROUP BY ts) +GROUP BY ts +ORDER BY ts +---- +2022-01-01T06:00:00 30 +2022-01-01T08:00:00 50 +2022-01-01T11:30:00 75 +2022-01-02T12:00:00 200 +2022-01-03T10:00:00 100 + +# aggregate should update its functional dependence +# mode, if it is guaranteed that, after aggregation +# group by expressions will be unique. 
+query IRI +SELECT * +FROM ( + SELECT *, ROW_NUMBER() OVER(ORDER BY l.sn) AS rn1 + FROM ( + SELECT l.sn, SUM(l.amount) + FROM ( + SELECT l.sn, l.amount, SUM(l.amount) as sum1 + FROM + (SELECT * + FROM sales_global_with_pk AS l + JOIN sales_global_with_pk AS r ON l.sn >= r.sn) + GROUP BY l.sn) + GROUP BY l.sn) + ) +GROUP BY l.sn +ORDER BY l.sn +---- +0 30 1 +1 50 2 +2 75 3 +3 200 4 +4 100 5 diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cd8940e134639..4d985456f9824 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -42,7 +42,8 @@ use crate::{ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{ display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - OwnedTableReference, Result, ScalarValue, TableReference, ToDFSchema, + FunctionalDependencies, OwnedTableReference, Result, ScalarValue, TableReference, + ToDFSchema, }; use std::any::Any; use std::cmp::Ordering; @@ -263,10 +264,16 @@ impl LogicalPlanBuilder { } let schema = table_source.schema(); + let func_dependencies = FunctionalDependencies::new_from_constraints( + table_source.constraints(), + schema.fields.len(), + ); let projected_schema = projection .as_ref() .map(|p| { + let projected_func_dependencies = + func_dependencies.project_functional_dependencies(p, p.len()); DFSchema::new_with_metadata( p.iter() .map(|i| { @@ -278,9 +285,14 @@ impl LogicalPlanBuilder { .collect(), schema.metadata().clone(), ) + .map(|df_schema| { + df_schema.with_functional_dependencies(projected_func_dependencies) + }) }) .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(table_name.clone(), &schema) + DFSchema::try_from_qualified_schema(table_name.clone(), &schema).map( + |df_schema| df_schema.with_functional_dependencies(func_dependencies), + ) })?; let table_scan = LogicalPlan::TableScan(TableScan { @@ -803,11 +815,12 @@ impl LogicalPlanBuilder { /// Apply a cross join pub fn cross_join(self, right: LogicalPlan) -> Result { - let schema = self.plan.schema().join(right.schema())?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; Ok(Self::from(LogicalPlan::CrossJoin(CrossJoin { left: Arc::new(self.plan), right: Arc::new(right), - schema: DFSchemaRef::new(schema), + schema: DFSchemaRef::new(join_schema), }))) } @@ -1086,10 +1099,15 @@ pub fn build_join_schema( right_fields.clone() } }; - + let func_dependencies = left.functional_dependencies().join( + right.functional_dependencies(), + join_type, + left_fields.len(), + ); let mut metadata = left.metadata().clone(); metadata.extend(right.metadata().clone()); - DFSchema::new_with_metadata(fields, metadata) + Ok(DFSchema::new_with_metadata(fields, metadata)? + .with_functional_dependencies(func_dependencies)) } /// Errors if one or more expressions have equal names. @@ -1400,10 +1418,11 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result { }) .collect::>(); - let schema = Arc::new(DFSchema::new_with_metadata( - fields, - input_schema.metadata().clone(), - )?); + let schema = Arc::new( + DFSchema::new_with_metadata(fields, input_schema.metadata().clone())? 
+ // We can use the existing functional dependencies: + .with_functional_dependencies(input_schema.functional_dependencies().clone()), + ); Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), @@ -1414,14 +1433,16 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result { #[cfg(test)] mod tests { - use crate::{expr, expr_fn::exists}; - use arrow::datatypes::{DataType, Field}; - use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; - use crate::logical_plan::StringifiedPlan; + use crate::{col, in_subquery, lit, scalar_subquery, sum}; + use crate::{expr, expr_fn::exists}; use super::*; - use crate::{col, in_subquery, lit, scalar_subquery, sum}; + + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{ + FunctionalDependence, OwnedTableReference, SchemaError, TableReference, + }; #[test] fn plan_builder_simple() -> Result<()> { @@ -1922,4 +1943,21 @@ mod tests { Ok(()) } + + #[test] + fn test_get_updated_id_keys() { + let fund_dependencies = + FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![1], + vec![0, 1, 2], + true, + )]); + let res = fund_dependencies.project_functional_dependencies(&[1, 2], 2); + let expected = FunctionalDependencies::new(vec![FunctionalDependence::new( + vec![0], + vec![0, 1], + true, + )]); + assert_eq!(res, expected); + } } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index e005f114719df..dc247da3642c0 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - parsers::CompressionTypeVariant, DFSchemaRef, OwnedTableReference, -}; -use datafusion_common::{Column, OwnedSchemaReference}; use std::collections::HashMap; use std::sync::Arc; use std::{ @@ -28,6 +24,11 @@ use std::{ use crate::{Expr, LogicalPlan}; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + Constraints, DFSchemaRef, OwnedSchemaReference, OwnedTableReference, +}; + /// Various types of DDL (CREATE / DROP) catalog manipulation #[derive(Clone, PartialEq, Eq, Hash)] pub enum DdlStatement { @@ -117,16 +118,10 @@ impl DdlStatement { } DdlStatement::CreateMemoryTable(CreateMemoryTable { name, - primary_key, + constraints, .. }) => { - let pk: Vec = - primary_key.iter().map(|c| c.name.to_string()).collect(); - let mut pk = pk.join(", "); - if !pk.is_empty() { - pk = format!(" primary_key=[{pk}]"); - } - write!(f, "CreateMemoryTable: {name:?}{pk}") + write!(f, "CreateMemoryTable: {name:?}{constraints}") } DdlStatement::CreateView(CreateView { name, .. }) => { write!(f, "CreateView: {name:?}") @@ -222,8 +217,8 @@ impl Hash for CreateExternalTable { pub struct CreateMemoryTable { /// The table name pub name: OwnedTableReference, - /// The ordered list of columns in the primary key, or an empty vector if none - pub primary_key: Vec, + /// The list of constraints in the schema, such as primary key, unique, etc. 
+ pub constraints: Constraints, /// The logical plan pub input: Arc, /// Option to not error if table already exists diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index be8270ddc3e38..c0af08a366186 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -30,24 +30,25 @@ use crate::utils::{ use crate::{ build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, }; + +use super::DdlStatement; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeVisitor, VisitRecursion, }; use datafusion_common::{ - plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, - OwnedTableReference, Result, ScalarValue, + aggregate_functional_dependencies, plan_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, FunctionalDependencies, OwnedTableReference, Result, ScalarValue, }; -use std::collections::{HashMap, HashSet}; -use std::fmt::{self, Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - // backwards compatibility pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; -use super::DdlStatement; +use std::collections::{HashMap, HashSet}; +use std::fmt::{self, Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by @@ -1300,6 +1301,11 @@ impl Projection { if expr.len() != schema.fields().len() { return Err(DataFusionError::Plan(format!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()))); } + // Update functional dependencies of `input` according to projection + // expressions: + let id_key_groups = calc_func_dependencies_for_project(&expr, &input)?; + let schema = schema.as_ref().clone(); + let schema = Arc::new(schema.with_functional_dependencies(id_key_groups)); Ok(Self { expr, input, @@ -1343,8 +1349,13 @@ impl SubqueryAlias { ) -> Result { let alias = alias.into(); let schema: Schema = plan.schema().as_ref().clone().into(); - let schema = - DFSchemaRef::new(DFSchema::try_from_qualified_schema(&alias, &schema)?); + // Since schema is the same, other than qualifier, we can use existing + // functional dependencies: + let func_dependencies = plan.schema().functional_dependencies().clone(); + let schema = DFSchemaRef::new( + DFSchema::try_from_qualified_schema(&alias, &schema)? + .with_functional_dependencies(func_dependencies), + ); Ok(SubqueryAlias { input: Arc::new(plan), alias, @@ -1420,10 +1431,18 @@ impl Window { .extend_from_slice(&exprlist_to_fields(window_expr.iter(), input.as_ref())?); let metadata = input.schema().metadata().clone(); + // Update functional dependencies for window: + let mut window_func_dependencies = + input.schema().functional_dependencies().clone(); + window_func_dependencies.extend_target_indices(window_fields.len()); + Ok(Window { input, window_expr, - schema: Arc::new(DFSchema::new_with_metadata(window_fields, metadata)?), + schema: Arc::new( + DFSchema::new_with_metadata(window_fields, metadata)? 
+ .with_functional_dependencies(window_func_dependencies), + ), }) } } @@ -1610,10 +1629,12 @@ impl Aggregate { let group_expr = enumerate_grouping_sets(group_expr)?; let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); + let schema = DFSchema::new_with_metadata( exprlist_to_fields(all_expr, &input)?, input.schema().metadata().clone(), )?; + Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -1642,6 +1663,13 @@ impl Aggregate { schema.fields().len() ))); } + + let aggregate_func_dependencies = + calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; + let new_schema = schema.as_ref().clone(); + let schema = Arc::new( + new_schema.with_functional_dependencies(aggregate_func_dependencies), + ); Ok(Self { input, group_expr, @@ -1651,6 +1679,71 @@ impl Aggregate { } } +/// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. +fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr + .iter() + .any(|expr| matches!(expr, Expr::GroupingSet(_))) +} + +/// Calculates functional dependencies for aggregate expressions. +fn calc_func_dependencies_for_aggregate( + // Expressions in the GROUP BY clause: + group_expr: &[Expr], + // Input plan of the aggregate: + input: &LogicalPlan, + // Aggregate schema + aggr_schema: &DFSchema, +) -> Result { + // We can do a case analysis on how to propagate functional dependencies based on + // whether the GROUP BY in question contains a grouping set expression: + // - If so, the functional dependencies will be empty because we cannot guarantee + // that GROUP BY expression results will be unique. + // - Otherwise, it may be possible to propagate functional dependencies. + if !contains_grouping_set(group_expr) { + let group_by_expr_names = group_expr + .iter() + .map(|item| item.display_name()) + .collect::>>()?; + let aggregate_func_dependencies = aggregate_functional_dependencies( + input.schema(), + &group_by_expr_names, + aggr_schema, + ); + Ok(aggregate_func_dependencies) + } else { + Ok(FunctionalDependencies::empty()) + } +} + +/// This function projects functional dependencies of the `input` plan according +/// to projection expressions `exprs`. +fn calc_func_dependencies_for_project( + exprs: &[Expr], + input: &LogicalPlan, +) -> Result { + let input_fields = input.schema().fields(); + // Calculate expression indices (if present) in the input schema. + let proj_indices = exprs + .iter() + .filter_map(|expr| { + let expr_name = match expr { + Expr::Alias(alias) => { + format!("{}", alias.expr) + } + _ => format!("{}", expr), + }; + input_fields + .iter() + .position(|item| item.qualified_name() == expr_name) + }) + .collect::>(); + Ok(input + .schema() + .functional_dependencies() + .project_functional_dependencies(&proj_indices, exprs.len())) +} + /// Sorts its input according to a list of sort expressions. #[derive(Clone, PartialEq, Eq, Hash)] pub struct Sort { diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index 2f5a8923e8bb6..b83ce778133b1 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -18,8 +18,10 @@ //! 
Table source use crate::{Expr, LogicalPlan}; + use arrow::datatypes::SchemaRef; -use datafusion_common::Result; +use datafusion_common::{Constraints, Result}; + use std::any::Any; /// Indicates whether and how a filter expression can be handled by a @@ -64,6 +66,11 @@ pub trait TableSource: Sync + Send { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; + /// Get primary key indices, if one exists. + fn constraints(&self) -> Option<&Constraints> { + None + } + /// Get the type of this table for metadata/catalog purposes. fn table_type(&self) -> TableType { TableType::Base diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3ddfea5105689..efaf291398427 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion, }; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, + Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + ScalarValue, TableReference, }; use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; use std::cmp::Ordering; @@ -446,7 +446,9 @@ pub fn expand_qualified_wildcard( ))); } let qualified_schema = - DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?; + DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? + // We can use the functional dependencies as is, since it only stores indices: + .with_functional_dependencies(schema.functional_dependencies().clone()); let excluded_columns = if let Some(WildcardAdditionalOptions { opt_exclude, opt_except, @@ -921,7 +923,7 @@ pub fn from_plan( })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(inputs[0].clone()), - primary_key: vec![], + constraints: Constraints::empty(), name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, @@ -1016,10 +1018,13 @@ pub fn from_plan( }) .collect::>(); - let schema = Arc::new(DFSchema::new_with_metadata( - fields, - input.schema().metadata().clone(), - )?); + let schema = Arc::new( + DFSchema::new_with_metadata(fields, input.schema().metadata().clone())? 
+ // We can use the existing functional dependencies as is: + .with_functional_dependencies( + input.schema().functional_dependencies().clone(), + ), + ); Ok(LogicalPlan::Unnest(Unnest { input, diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index f4a51478132bd..95061e38540c7 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -262,8 +262,10 @@ fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef { ) }) .collect::>(); - DFSchemaRef::new( - DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(), + Arc::new( + DFSchema::new_with_metadata(new_fields, schema.metadata().clone()) + .unwrap() + .with_functional_dependencies(schema.functional_dependencies().clone()), ) } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 2306593d424b3..920b9ea18f92b 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -493,12 +493,12 @@ mod tests { assert_eq!( "get table_scan rule\ncaused by\n\ Internal error: Optimizer rule 'get table_scan rule' failed, due to generate a different schema, \ - original schema: DFSchema { fields: [], metadata: {} }, \ + original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }], \ - metadata: {} }. \ + metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }. 
\ This was likely caused by a bug in DataFusion's code \ and we would welcome that you file an bug report in our issue tracker", err.to_string() diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index ba7e89094b0f3..a9e65b3e7c778 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -167,12 +167,13 @@ impl OptimizerRule for SingleDistinctToGroupBy { Vec::new(), )?); + let outer_fields = outer_group_exprs + .iter() + .chain(new_aggr_exprs.iter()) + .map(|expr| expr.to_field(&inner_schema)) + .collect::>>()?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( - outer_group_exprs - .iter() - .chain(new_aggr_exprs.iter()) - .map(|expr| expr.to_field(&inner_schema)) - .collect::>>()?, + outer_fields, input.schema().metadata().clone(), )?); diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 2d7771d8c753c..34b24b0594fa8 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{Constraints, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -86,7 +86,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let select_into = select.into.unwrap(); LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { name: self.object_name_to_table_reference(select_into.name)?, - primary_key: Vec::new(), + constraints: Constraints::empty(), input: Arc::new(plan), if_not_exists: false, or_replace: false, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 7841858bacdaf..daf79e969e1f2 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -15,12 +15,19 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs, resolve_columns, resolve_positions_to_exprs, }; -use datafusion_common::{DataFusionError, Result}; + +use datafusion_common::{ + get_target_functional_dependencies, DFSchemaRef, DataFusionError, Result, +}; +use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, }; @@ -32,12 +39,8 @@ use datafusion_expr::utils::{ use datafusion_expr::{ Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; - -use datafusion_expr::expr::Alias; use sqlparser::ast::{Distinct, Expr as SQLExpr, WildcardAdditionalOptions, WindowType}; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; -use std::collections::HashSet; -use std::sync::Arc; impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logic plan from an SQL select @@ -431,6 +434,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: Vec, aggr_exprs: Vec, ) -> Result<(LogicalPlan, Vec, Option)> { + let group_by_exprs = + get_updated_group_by_exprs(&group_by_exprs, select_exprs, input.schema())?; + // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? 
@@ -555,3 +561,40 @@ fn match_window_definitions(
     }
     Ok(())
 }
+
+/// Update the GROUP BY expressions according to functional dependencies
+fn get_updated_group_by_exprs(
+    group_by_exprs: &[Expr],
+    select_exprs: &[Expr],
+    schema: &DFSchemaRef,
+) -> Result<Vec<Expr>> {
+    let mut new_group_by_exprs = group_by_exprs.to_vec();
+    let fields = schema.fields();
+    let group_by_expr_names = group_by_exprs
+        .iter()
+        .map(|group_by_expr| group_by_expr.display_name())
+        .collect::<Result<Vec<_>>>()?;
+    // Get targets that can be used in the SELECT list, even if they do not occur in aggregation:
+    if let Some(target_indices) =
+        get_target_functional_dependencies(schema, &group_by_expr_names)
+    {
+        // Calculate the dependent field names determined by the GROUP BY expressions:
+        let associated_field_names = target_indices
+            .iter()
+            .map(|idx| fields[*idx].qualified_name())
+            .collect::<Vec<_>>();
+        // Expand GROUP BY expressions with select expressions: If a GROUP
+        // BY expression is a determinant key, we can also use its dependent
+        // columns in SELECT expressions.
+        for expr in select_exprs {
+            let expr_name = format!("{}", expr);
+            if !new_group_by_exprs.contains(expr)
+                && associated_field_names.contains(&expr_name)
+            {
+                new_group_by_exprs.push(expr.clone());
+            }
+        }
+    }
+
+    Ok(new_group_by_exprs)
+}
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index a5ff3633acad9..4af32337f77a1 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -23,12 +23,15 @@ use crate::planner::{
     object_name_to_qualifier, ContextProvider, PlannerContext, SqlToRel,
 };
 use crate::utils::normalize_ident;
+
 use arrow_schema::DataType;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
-    unqualified_field_not_found, Column, DFField, DFSchema, DFSchemaRef, DataFusionError,
-    ExprSchema, OwnedTableReference, Result, SchemaReference, TableReference, ToDFSchema,
+    unqualified_field_not_found, Column, Constraints, DFField, DFSchema, DFSchemaRef,
+    DataFusionError, ExprSchema, OwnedTableReference, Result, SchemaReference,
+    TableReference, ToDFSchema,
 };
+use datafusion_expr::expr::Placeholder;
 use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check;
 use datafusion_expr::logical_plan::builder::project;
 use datafusion_expr::logical_plan::DdlStatement;
@@ -45,12 +48,11 @@ use datafusion_expr::{
 use sqlparser::ast;
 use sqlparser::ast::{
     Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SchemaName,
-    SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableConstraint,
-    TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value,
+    SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableFactor,
+    TableWithJoins, TransactionMode, UnaryOperator, Value,
 };
-
-use datafusion_expr::expr::Placeholder;
 use sqlparser::parser::ParserError::ParserError;
+
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::sync::Arc;
 
@@ -132,8 +134,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 ..
             } if table_properties.is_empty() && with_options.is_empty() => match query {
                 Some(query) => {
-                    let primary_key = Self::primary_key_from_constraints(&constraints)?;
-
                     let plan = self.query_to_plan(*query, planner_context)?;
                     let input_schema = plan.schema();
@@ -163,10 +163,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         plan
                     };
 
+                    let constraints = Constraints::new_from_table_constraints(
+                        &constraints,
+                        plan.schema(),
+                    )?;
+
                     Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
                         CreateMemoryTable {
                             name: self.object_name_to_table_reference(name)?,
-                            primary_key,
+                            constraints,
                             input: Arc::new(plan),
                             if_not_exists,
                             or_replace,
@@ -175,19 +180,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 }
 
                 None => {
-                    let primary_key = Self::primary_key_from_constraints(&constraints)?;
-
                     let schema = self.build_schema(columns)?.to_dfschema_ref()?;
                     let plan = EmptyRelation {
                         produce_one_row: false,
                         schema,
                     };
                     let plan = LogicalPlan::EmptyRelation(plan);
-
+                    let constraints = Constraints::new_from_table_constraints(
+                        &constraints,
+                        plan.schema(),
+                    )?;
                     Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(
                         CreateMemoryTable {
                             name: self.object_name_to_table_reference(name)?,
-                            primary_key,
+                            constraints,
                             input: Arc::new(plan),
                             if_not_exists,
                             or_replace,
@@ -1160,54 +1166,4 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .get_table_provider(tables_reference)
             .is_ok()
     }
-
-    fn primary_key_from_constraints(
-        constraints: &[TableConstraint],
-    ) -> Result<Vec<Column>> {
-        let pk: Result<Vec<&Vec<Ident>>> = constraints
-            .iter()
-            .map(|c: &TableConstraint| match c {
-                TableConstraint::Unique {
-                    columns,
-                    is_primary,
-                    ..
-                } => match is_primary {
-                    true => Ok(columns),
-                    false => Err(DataFusionError::Plan(
-                        "Non-primary unique constraints are not supported".to_string(),
-                    )),
-                },
-                TableConstraint::ForeignKey { .. } => Err(DataFusionError::Plan(
-                    "Foreign key constraints are not currently supported".to_string(),
-                )),
-                TableConstraint::Check { .. } => Err(DataFusionError::Plan(
-                    "Check constraints are not currently supported".to_string(),
-                )),
-                TableConstraint::Index { .. } => Err(DataFusionError::Plan(
-                    "Indexes are not currently supported".to_string(),
-                )),
-                TableConstraint::FulltextOrSpatial { .. } => Err(DataFusionError::Plan(
-                    "Indexes are not currently supported".to_string(),
-                )),
-            })
-            .collect();
-        let pk = pk?;
-        let pk = match pk.as_slice() {
-            [] => return Ok(vec![]),
-            [pk] => pk,
-            _ => {
-                return Err(DataFusionError::Plan(
-                    "Only one primary key is supported!".to_string(),
-                ))?
-            }
-        };
-        let primary_key: Vec<Column> = pk
-            .iter()
-            .map(|c| Column {
-                relation: None,
-                name: c.value.clone(),
-            })
-            .collect();
-        Ok(primary_key)
-    }
 }
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 6b498084a4986..88dddc7336a6f 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -232,7 +232,7 @@ fn cast_to_invalid_decimal_type_precision_lt_scale() {
 fn plan_create_table_with_pk() {
     let sql = "create table person (id int, name string, primary key(id))";
     let plan = r#"
-CreateMemoryTable: Bare { table: "person" } primary_key=[id]
+CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])]
   EmptyRelation
 "#
     .trim();
@@ -251,10 +251,9 @@ CreateMemoryTable: Bare { table: "person" }
 }
 
 #[test]
-#[should_panic(expected = "Non-primary unique constraints are not supported")]
 fn plan_create_table_check_constraint() {
     let sql = "create table person (id int, name string, unique(id))";
-    let plan = "";
+    let plan = "CreateMemoryTable: Bare { table: \"person\" } constraints=[Unique([0])]\n  EmptyRelation";
     quick_test(sql, plan);
 }
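
A closing note on how these pieces fit together: the planner tests above show constraints being captured at CREATE TABLE time, and get_updated_group_by_exprs then uses the resulting functional dependencies so that a query grouping on a determinant key (for example, GROUP BY id on the person table) can also reference the key's dependent columns in the SELECT list without listing them in GROUP BY explicitly. The snippet below is a minimal sketch of the schema-propagation pattern used in count_wildcard_rule and the Unnest builder; rebuild_schema_preserving_deps is a hypothetical helper name, and the body assumes only the APIs introduced in this patch:

    use std::sync::Arc;
    use datafusion_common::{DFSchema, DFSchemaRef, Result};

    // Rebuild a schema (an identity rebuild here; a real optimizer rule would
    // transform the fields) while keeping its functional dependencies intact.
    fn rebuild_schema_preserving_deps(schema: &DFSchema) -> Result<DFSchemaRef> {
        let rebuilt =
            DFSchema::new_with_metadata(schema.fields().clone(), schema.metadata().clone())?
                // Without this call the rebuilt schema would report empty functional
                // dependencies, which is the class of bug fixed in count_wildcard_rule.
                .with_functional_dependencies(schema.functional_dependencies().clone());
        Ok(Arc::new(rebuilt))
    }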