Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 124 additions & 38 deletions datafusion/sql/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::TableReference;
use std::collections::BTreeSet;
use std::ops::ControlFlow;

use datafusion_common::{DataFusionError, Result};

use crate::TableReference;
use crate::parser::{CopyToSource, CopyToStatement, Statement as DFStatement};
use crate::planner::object_name_to_table_reference;
use sqlparser::ast::*;
Expand All @@ -45,27 +47,40 @@ const INFORMATION_SCHEMA_TABLES: &[&str] = &[
PARAMETERS,
];

// Collect table/CTE references as `TableReference`s and normalize them during traversal.
// This avoids a second normalization/conversion pass after visiting the AST.
struct RelationVisitor {
relations: BTreeSet<ObjectName>,
all_ctes: BTreeSet<ObjectName>,
ctes_in_scope: Vec<ObjectName>,
relations: BTreeSet<TableReference>,
all_ctes: BTreeSet<TableReference>,
ctes_in_scope: Vec<TableReference>,
enable_ident_normalization: bool,
}

impl RelationVisitor {
/// Record the reference to `relation`, if it's not a CTE reference.
fn insert_relation(&mut self, relation: &ObjectName) {
if !self.relations.contains(relation) && !self.ctes_in_scope.contains(relation) {
self.relations.insert(relation.clone());
fn insert_relation(&mut self, relation: &ObjectName) -> ControlFlow<DataFusionError> {
match object_name_to_table_reference(
relation.clone(),
self.enable_ident_normalization,
) {
Ok(relation) => {
if !self.relations.contains(&relation)
&& !self.ctes_in_scope.contains(&relation)
{
self.relations.insert(relation);
}
ControlFlow::Continue(())
}
Err(e) => ControlFlow::Break(e),
}
}
}

impl Visitor for RelationVisitor {
type Break = ();
type Break = DataFusionError;

fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
self.insert_relation(relation);
ControlFlow::Continue(())
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
self.insert_relation(relation)
}

fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
Expand All @@ -78,10 +93,16 @@ impl Visitor for RelationVisitor {
if !with.recursive {
// This is a bit hackish as the CTE will be visited again as part of visiting `q`,
// but thankfully `insert_relation` is idempotent.
let _ = cte.visit(self);
cte.visit(self)?;
}
let cte_name = ObjectName::from(vec![cte.alias.name.clone()]);
match object_name_to_table_reference(
cte_name,
self.enable_ident_normalization,
) {
Ok(cte_ref) => self.ctes_in_scope.push(cte_ref),
Err(e) => return ControlFlow::Break(e),
}
self.ctes_in_scope
.push(ObjectName::from(vec![cte.alias.name.clone()]));
}
}
ControlFlow::Continue(())
Expand All @@ -97,13 +118,13 @@ impl Visitor for RelationVisitor {
ControlFlow::Continue(())
}

fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
if let Statement::ShowCreate {
obj_type: ShowCreateObject::Table | ShowCreateObject::View,
obj_name,
} = statement
{
self.insert_relation(obj_name)
self.insert_relation(obj_name)?;
}

// SHOW statements will later be rewritten into a SELECT from the information_schema
Expand All @@ -120,35 +141,53 @@ impl Visitor for RelationVisitor {
);
if requires_information_schema {
for s in INFORMATION_SCHEMA_TABLES {
self.relations.insert(ObjectName::from(vec![
// Information schema references are synthesized here, so convert directly.
let obj = ObjectName::from(vec![
Ident::new(INFORMATION_SCHEMA),
Ident::new(*s),
]));
]);
match object_name_to_table_reference(obj, self.enable_ident_normalization)
{
Ok(tbl_ref) => {
self.relations.insert(tbl_ref);
}
Err(e) => return ControlFlow::Break(e),
}
}
}
ControlFlow::Continue(())
}
}

fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
fn control_flow_to_result(flow: ControlFlow<DataFusionError>) -> Result<()> {
match flow {
ControlFlow::Continue(()) => Ok(()),
ControlFlow::Break(err) => Err(err),
}
}

fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) -> Result<()> {
match statement {
DFStatement::Statement(s) => {
let _ = s.as_ref().visit(visitor);
control_flow_to_result(s.as_ref().visit(visitor))?;
}
DFStatement::CreateExternalTable(table) => {
visitor.relations.insert(table.name.clone());
control_flow_to_result(visitor.insert_relation(&table.name))?;
}
DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
CopyToSource::Relation(table_name) => {
visitor.insert_relation(table_name);
control_flow_to_result(visitor.insert_relation(table_name))?;
}
CopyToSource::Query(query) => {
let _ = query.visit(visitor);
control_flow_to_result(query.visit(visitor))?;
}
},
DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
DFStatement::Explain(explain) => {
visit_statement(&explain.statement, visitor)?;
}
DFStatement::Reset(_) => {}
}
Ok(())
}

/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
Expand Down Expand Up @@ -188,26 +227,20 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
pub fn resolve_table_references(
statement: &crate::parser::Statement,
enable_ident_normalization: bool,
) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
) -> Result<(Vec<TableReference>, Vec<TableReference>)> {
let mut visitor = RelationVisitor {
relations: BTreeSet::new(),
all_ctes: BTreeSet::new(),
ctes_in_scope: vec![],
enable_ident_normalization,
};

visit_statement(statement, &mut visitor);

let table_refs = visitor
.relations
.into_iter()
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
.collect::<datafusion_common::Result<_>>()?;
let ctes = visitor
.all_ctes
.into_iter()
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
.collect::<datafusion_common::Result<_>>()?;
Ok((table_refs, ctes))
visit_statement(statement, &mut visitor)?;

Ok((
visitor.relations.into_iter().collect(),
visitor.all_ctes.into_iter().collect(),
))
}

#[cfg(test)]
Expand Down Expand Up @@ -270,4 +303,57 @@ mod tests {
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].to_string(), "nodes");
}

#[test]
fn resolve_table_references_cte_with_quoted_reference() {
use crate::parser::DFParser;

let query = r#"with barbaz as (select 1) select * from "barbaz""#;
let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].to_string(), "barbaz");
// Quoted reference should still resolve to the CTE when normalization is on
assert_eq!(table_refs.len(), 0);
}

#[test]
fn resolve_table_references_cte_with_quoted_reference_normalization_off() {
use crate::parser::DFParser;

let query = r#"with barbaz as (select 1) select * from "barbaz""#;
let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap();
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].to_string(), "barbaz");
// Even with normalization off, quoted reference matches same-case CTE name
assert_eq!(table_refs.len(), 0);
}

#[test]
fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_on() {
use crate::parser::DFParser;

let query = r#"with FOObar as (select 1) select * from "FOObar""#;
let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
// CTE name is normalized to lowercase, quoted reference preserves case, so they differ
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].to_string(), "foobar");
assert_eq!(table_refs.len(), 1);
assert_eq!(table_refs[0].to_string(), "FOObar");
}

#[test]
fn resolve_table_references_cte_with_quoted_reference_uppercase_normalization_off() {
use crate::parser::DFParser;

let query = r#"with FOObar as (select 1) select * from "FOObar""#;
let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
let (table_refs, ctes) = resolve_table_references(&statement, false).unwrap();
// Without normalization, cases match exactly, so quoted reference resolves to the CTE
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].to_string(), "FOObar");
assert_eq!(table_refs.len(), 0);
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also please add some .slt tests so we can see this working when running end to end in queries?

Instructions are here: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion — I’ve added an end-to-end .slt regression

}
104 changes: 103 additions & 1 deletion datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use arrow::buffer::ScalarBuffer;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields};
use arrow::record_batch::RecordBatch;
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, Session,
};
use datafusion::common::{DataFusionError, Result, not_impl_err};
use datafusion::functions::math::abs;
Expand Down Expand Up @@ -96,6 +96,10 @@ impl TestContext {

let file_name = relative_path.file_name().unwrap().to_str().unwrap();
match file_name {
"cte_quoted_reference.slt" => {
info!("Registering strict catalog provider for CTE tests");
register_strict_orders_catalog(test_ctx.session_ctx());
}
"information_schema_table_types.slt" => {
info!("Registering local temporary table");
register_temp_table(test_ctx.session_ctx()).await;
Expand Down Expand Up @@ -171,6 +175,104 @@ impl TestContext {
}
}

// ==============================================================================
// Strict Catalog / Schema Provider (sqllogictest-only)
// ==============================================================================
//
// The goal of `cte_quoted_reference.slt` is to exercise end-to-end query planning
// while detecting *unexpected* catalog lookups.
//
// Specifically, if DataFusion incorrectly treats a CTE reference (e.g. `"barbaz"`)
// as a real table reference, the planner will attempt to resolve it through the
// schema provider. The types below deliberately `panic!` on any lookup other than
// the one table we expect (`orders`).
//
// This makes the "extra provider lookup" bug observable in an end-to-end test,
// rather than being silently ignored by default providers that return `Ok(None)`
// for unknown tables.

#[derive(Debug)]
struct StrictOrdersCatalog {
schema: Arc<dyn SchemaProvider>,
}

impl CatalogProvider for StrictOrdersCatalog {
fn as_any(&self) -> &dyn Any {
self
}

fn schema_names(&self) -> Vec<String> {
vec!["public".to_string()]
}

fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
(name == "public").then(|| Arc::clone(&self.schema))
}
}

#[derive(Debug)]
struct StrictOrdersSchema {
orders: Arc<dyn TableProvider>,
}

#[async_trait]
impl SchemaProvider for StrictOrdersSchema {
fn as_any(&self) -> &dyn Any {
self
}

fn table_names(&self) -> Vec<String> {
vec!["orders".to_string()]
}

async fn table(
&self,
name: &str,
) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
match name {
"orders" => Ok(Some(Arc::clone(&self.orders))),
other => panic!(
"unexpected table lookup: {other}. This maybe indicates a CTE reference was \
incorrectly treated as a catalog table reference."
),
}
}

fn table_exist(&self, name: &str) -> bool {
name == "orders"
}
}

fn register_strict_orders_catalog(ctx: &SessionContext) {
let schema = Arc::new(Schema::new(vec![Field::new(
"order_id",
DataType::Int32,
false,
)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)
.expect("record batch should be valid");

let orders =
MemTable::try_new(schema, vec![vec![batch]]).expect("memtable should be valid");

let schema_provider: Arc<dyn SchemaProvider> = Arc::new(StrictOrdersSchema {
orders: Arc::new(orders),
});

// Override the default "datafusion" catalog for this test file so that any
// unexpected lookup is caught immediately.
ctx.register_catalog(
"datafusion",
Arc::new(StrictOrdersCatalog {
schema: schema_provider,
}),
);
}

#[cfg(feature = "avro")]
pub async fn register_avro_tables(ctx: &mut TestContext) {
use datafusion::prelude::AvroReadOptions;
Expand Down
Loading