diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 51bc1b6de285b..5f3490f535a46 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -41,6 +41,7 @@ use datafusion_common::cast::as_boolean_array;
 use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
 use datafusion_common::{
     DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err,
+    not_impl_err,
 };
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
@@ -133,8 +134,20 @@ impl AggregateUDFImpl for FirstValue {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(arg_types[0].clone())
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_field is implemented")
+    }
+
+    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
+        // Output field copies the first argument's data type and metadata
+        Ok(Arc::new(
+            Field::new(
+                self.name(),
+                arg_fields[0].data_type().clone(),
+                true, // always nullable, there may be no rows
+            )
+            .with_metadata(arg_fields[0].metadata().clone()),
+        ))
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
@@ -1071,8 +1084,20 @@ impl AggregateUDFImpl for LastValue {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(arg_types[0].clone())
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_field is implemented")
+    }
+
+    fn return_field(&self, arg_fields: &[FieldRef]) -> Result<FieldRef> {
+        // Output field copies the first argument's data type and metadata
+        Ok(Arc::new(
+            Field::new(
+                self.name(),
+                arg_fields[0].data_type().clone(),
+                true, // always nullable, there may be no rows
+            )
+            .with_metadata(arg_fields[0].metadata().clone()),
+        ))
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs
index b499401e5589c..110108a554014 100644
--- a/datafusion/sqllogictest/src/test_context.rs
+++ b/datafusion/sqllogictest/src/test_context.rs
@@ -32,7 +32,7 @@ use arrow::record_batch::RecordBatch;
 use datafusion::catalog::{
     CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session,
 };
-use datafusion::common::{not_impl_err, DataFusionError, Result};
+use datafusion::common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue};
 use datafusion::functions::math::abs;
 use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
 use datafusion::logical_expr::{
@@ -398,6 +398,58 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {
         .unwrap();
 
     ctx.register_batch("table_with_metadata", batch).unwrap();
+
+    // Register the get_metadata UDF for testing metadata preservation
+    ctx.register_udf(ScalarUDF::from(GetMetadataUdf::new()));
+}
+
+/// UDF to extract metadata from a field for testing purposes
+/// Usage: get_metadata(expr, 'key') -> metadata value of `expr`'s field for the string literal 'key', or NULL
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct GetMetadataUdf {
+    signature: Signature,
+}
+
+impl GetMetadataUdf {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(2, Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for GetMetadataUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "get_metadata"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        // Get the metadata key from the second argument (must be a string literal)
+        let key = match &args.args[1] {
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(k))) => k.clone(),
+            _ => {
+                return exec_err!("get_metadata second argument must be a string literal")
+            }
+        };
+
+        // Get metadata from the first argument's field
+        let metadata_value = args.arg_fields[0].metadata().get(&key).cloned();
+
+        // Return as a scalar (same value for all rows)
+        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(metadata_value)))
+    }
 }
 
 /// Create a UDF function named "example". See the `sample_udf.rs` example
diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt
index 8753d39cb7ef7..d7df46e4f90fc 100644
--- a/datafusion/sqllogictest/test_files/metadata.slt
+++ b/datafusion/sqllogictest/test_files/metadata.slt
@@ -235,7 +235,56 @@ order by 1 asc nulls last;
 3 1
 NULL 1
 
+# Regression test: first_value should preserve metadata
+query IT
+select first_value(id order by id asc nulls last), get_metadata(first_value(id order by id asc nulls last), 'metadata_key')
+from table_with_metadata;
+----
+1 the id field
+
+# Regression test: last_value should preserve metadata
+query IT
+select last_value(id order by id asc nulls first), get_metadata(last_value(id order by id asc nulls first), 'metadata_key')
+from table_with_metadata;
+----
+3 the id field
 
+# Regression test: DISTINCT ON should preserve metadata (uses first_value internally)
+query ITTT
+select distinct on (id) id, get_metadata(id, 'metadata_key'), name, get_metadata(name, 'metadata_key')
+from table_with_metadata order by id asc nulls last;
+----
+1 the id field NULL the name field
+3 the id field baz the name field
+NULL the id field bar the name field
+
+# Regression test: DISTINCT should preserve metadata
+query ITTT
+with res AS (
+  select distinct id, name from table_with_metadata
+)
+select id, get_metadata(id, 'metadata_key'), name, get_metadata(name, 'metadata_key')
+from res
+order by id asc nulls last;
+----
+1 the id field NULL the name field
+3 the id field baz the name field
+NULL the id field bar the name field
+
+# Regression test: grouped columns should preserve metadata
+query ITTT
+with res AS (
+  select name, count(*), id
+  from table_with_metadata
+  group by id, name
+)
+select id, get_metadata(id, 'metadata_key'), name, get_metadata(name, 'metadata_key')
+from res
+order by id asc nulls last, name asc nulls last
+----
+1 the id field NULL the name field
+3 the id field baz the name field
+NULL the id field bar the name field
 
 statement ok
 drop table table_with_metadata;