diff --git a/native/core/src/execution/expressions/array.rs b/native/core/src/execution/expressions/array.rs new file mode 100644 index 0000000000..7e8921c8d8 --- /dev/null +++ b/native/core/src/execution/expressions/array.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array expression builders + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use datafusion_comet_spark_expr::{ArrayExistsExpr, LambdaVariableExpr}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::expression_registry::ExpressionBuilder; +use crate::execution::planner::PhysicalPlanner; +use crate::execution::serde::to_arrow_datatype; +use crate::extract_expr; + +pub struct ArrayExistsBuilder; + +impl ExpressionBuilder for ArrayExistsBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, ArrayExists); + let array_expr = + planner.create_expr(expr.array.as_ref().unwrap(), Arc::clone(&input_schema))?; + let lambda_body = planner.create_expr(expr.lambda_body.as_ref().unwrap(), input_schema)?; + Ok(Arc::new(ArrayExistsExpr::new( + array_expr, + lambda_body, + expr.follow_three_valued_logic, + ))) + } +} + +pub struct LambdaVariableBuilder; + +impl ExpressionBuilder for LambdaVariableBuilder { + fn build( + &self, + spark_expr: &Expr, + _input_schema: SchemaRef, + _planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, LambdaVariable); + let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(LambdaVariableExpr::new(data_type))) + } +} diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index c2b144b7dd..6333bbd211 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -18,6 +18,7 @@ //! Native DataFusion expressions pub mod arithmetic; +pub mod array; pub mod bitwise; pub mod comparison; pub mod logical; diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs index bf3904d9c1..50e9ae1f29 100644 --- a/native/core/src/execution/planner/expression_registry.rs +++ b/native/core/src/execution/planner/expression_registry.rs @@ -103,6 +103,8 @@ pub enum ExpressionType { Randn, SparkPartitionId, MonotonicallyIncreasingId, + ArrayExists, + LambdaVariable, // Time functions Hour, @@ -185,6 +187,9 @@ impl ExpressionRegistry { // Register temporal expressions self.register_temporal_expressions(); + // Register array expressions + self.register_array_expressions(); + // Register random expressions self.register_random_expressions(); @@ -312,6 +317,18 @@ impl ExpressionRegistry { ); } + /// Register array expression builders + fn register_array_expressions(&mut self) { + use crate::execution::expressions::array::*; + + self.builders + .insert(ExpressionType::ArrayExists, Box::new(ArrayExistsBuilder)); + self.builders.insert( + ExpressionType::LambdaVariable, + Box::new(LambdaVariableBuilder), + ); + } + /// Extract expression type from Spark protobuf expression fn get_expression_type(spark_expr: &Expr) -> Result { match spark_expr.expr_struct.as_ref() { @@ -376,6 +393,8 @@ impl ExpressionRegistry { Some(ExprStruct::MonotonicallyIncreasingId(_)) => { Ok(ExpressionType::MonotonicallyIncreasingId) } + Some(ExprStruct::ArrayExists(_)) => Ok(ExpressionType::ArrayExists), + Some(ExprStruct::LambdaVariable(_)) => Ok(ExpressionType::LambdaVariable), Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 5701577463..736cb5e9f1 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -88,6 +88,8 @@ message Expr { UnixTimestamp unix_timestamp = 65; FromJson from_json = 66; ToCsv to_csv = 67; + ArrayExists array_exists = 68; + LambdaVariable lambda_variable = 69; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -484,3 +486,17 @@ message ArrayJoin { message Rand { int64 seed = 1; } + +message ArrayExists { + Expr array = 1; + Expr lambda_body = 2; + bool follow_three_valued_logic = 3; +} + +// Currently only supports a single lambda variable per expression. The variable +// is resolved by column index (always the last column in the expanded batch +// constructed by ArrayExistsExpr). Extending to multi-argument lambdas +// (e.g. transform(array, (x, i) -> ...)) would require adding an identifier. +message LambdaVariable { + DataType datatype = 1; +} diff --git a/native/spark-expr/src/array_funcs/array_exists.rs b/native/spark-expr/src/array_funcs/array_exists.rs new file mode 100644 index 0000000000..dd23d13f6b --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_exists.rs @@ -0,0 +1,575 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, BooleanArray, LargeListArray, ListArray, NullArray}; +use arrow::buffer::NullBuffer; +use arrow::compute::kernels::take::take; +use arrow::datatypes::{DataType, Field, Schema, UInt32Type}; +use arrow::record_batch::RecordBatch; +use datafusion::common::{DataFusionError, Result as DataFusionResult}; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +const LAMBDA_VAR_COLUMN: &str = "__comet_lambda_var"; + +/// Collect all column indices referenced by an expression tree. +fn collect_referenced_columns(expr: &Arc, indices: &mut HashSet) { + if let Some(col) = expr + .as_any() + .downcast_ref::() + { + indices.insert(col.index()); + } + for child in expr.children() { + collect_referenced_columns(child, indices); + } +} + +/// Decomposed list array: offsets as usize, values, and optional null buffer. +struct ListComponents { + offsets: Vec, + values: ArrayRef, + nulls: Option, +} + +impl ListComponents { + fn is_null(&self, row: usize) -> bool { + self.nulls.as_ref().is_some_and(|n| n.is_null(row)) + } +} + +fn decompose_list(array: &dyn Array) -> DataFusionResult { + if let Some(list) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: list.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(list.values()), + nulls: list.nulls().cloned(), + }) + } else if let Some(large) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: large.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(large.values()), + nulls: large.nulls().cloned(), + }) + } else { + Err(DataFusionError::Internal( + "ArrayExists expects a ListArray or LargeListArray input".to_string(), + )) + } +} + +/// Spark-compatible `array_exists(array, x -> predicate(x))`. +/// +/// Evaluates the lambda body vectorized over all elements in a single pass rather +/// than per-element to avoid repeated batch construction overhead. +#[derive(Debug, Eq)] +pub struct ArrayExistsExpr { + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, +} + +impl Hash for ArrayExistsExpr { + fn hash(&self, state: &mut H) { + self.array_expr.hash(state); + self.lambda_body.hash(state); + self.follow_three_valued_logic.hash(state); + } +} + +impl PartialEq for ArrayExistsExpr { + fn eq(&self, other: &Self) -> bool { + self.array_expr.eq(&other.array_expr) + && self.lambda_body.eq(&other.lambda_body) + && self + .follow_three_valued_logic + .eq(&other.follow_three_valued_logic) + } +} + +impl ArrayExistsExpr { + pub fn new( + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, + ) -> Self { + Self { + array_expr, + lambda_body, + follow_three_valued_logic, + } + } +} + +impl PhysicalExpr for ArrayExistsExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let num_rows = batch.num_rows(); + + let array_value = self.array_expr.evaluate(batch)?.into_array(num_rows)?; + let list = decompose_list(array_value.as_ref())?; + let total_elements = list.values.len(); + + if total_elements == 0 { + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list.is_null(row) { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + return Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))); + } + + let mut repeat_indices = Vec::with_capacity(total_elements); + for row in 0..num_rows { + let start = list.offsets[row]; + let end = list.offsets[row + 1]; + for _ in start..end { + repeat_indices.push(row as u32); + } + } + + let repeat_indices_array = arrow::array::PrimitiveArray::::from(repeat_indices); + + // Only expand columns that are actually referenced by the lambda body. + // Unreferenced columns get a cheap NullArray placeholder to avoid costly take(). + let mut referenced_columns = HashSet::new(); + collect_referenced_columns(&self.lambda_body, &mut referenced_columns); + + let mut expanded_columns: Vec = Vec::with_capacity(batch.num_columns() + 1); + let mut expanded_fields: Vec> = Vec::with_capacity(batch.num_columns() + 1); + + for (i, col) in batch.columns().iter().enumerate() { + if referenced_columns.contains(&i) { + let expanded = take(col.as_ref(), &repeat_indices_array, None)?; + expanded_columns.push(expanded); + expanded_fields.push(Arc::new(batch.schema().field(i).clone())); + } else { + // Use a cheap NullArray placeholder for columns not referenced by the lambda + expanded_columns.push(Arc::new(NullArray::new(total_elements))); + expanded_fields.push(Arc::new(Field::new( + batch.schema().field(i).name(), + DataType::Null, + true, + ))); + } + } + + let element_field = Arc::new(Field::new( + LAMBDA_VAR_COLUMN, + list.values.data_type().clone(), + true, + )); + expanded_columns.push(Arc::clone(&list.values)); + expanded_fields.push(element_field); + + let expanded_schema = Arc::new(Schema::new(expanded_fields)); + let expanded_batch = RecordBatch::try_new(expanded_schema, expanded_columns)?; + + let body_result = self + .lambda_body + .evaluate(&expanded_batch)? + .into_array(total_elements)?; + + let body_booleans = body_result + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "ArrayExists lambda body must return BooleanArray".to_string(), + ) + })?; + + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list.is_null(row) { + result_builder.append_null(); + continue; + } + + let start = list.offsets[row]; + let end = list.offsets[row + 1]; + + if start == end { + result_builder.append_value(false); + continue; + } + + let mut found_true = false; + let mut found_null = false; + + for idx in start..end { + if body_booleans.is_null(idx) { + found_null = true; + } else if body_booleans.value(idx) { + found_true = true; + break; + } + } + + if found_true { + result_builder.append_value(true); + } else if found_null && self.follow_three_valued_logic { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + + Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.array_expr, &self.lambda_body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + match children.len() { + 2 => Ok(Arc::new(ArrayExistsExpr::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.follow_three_valued_logic, + ))), + _ => Err(DataFusionError::Internal( + "ArrayExistsExpr should have exactly two children".to_string(), + )), + } + } +} + +impl Display for ArrayExistsExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayExists [array: {:?}, lambda_body: {:?}]", + self.array_expr, self.lambda_body + ) + } +} + +#[derive(Debug, Eq)] +pub struct LambdaVariableExpr { + data_type: DataType, +} + +impl Hash for LambdaVariableExpr { + fn hash(&self, state: &mut H) { + self.data_type.hash(state); + } +} + +impl PartialEq for LambdaVariableExpr { + fn eq(&self, other: &Self) -> bool { + self.data_type == other.data_type + } +} + +impl LambdaVariableExpr { + pub fn new(data_type: DataType) -> Self { + Self { data_type } + } +} + +impl PhysicalExpr for LambdaVariableExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + // The lambda variable is always the last column, appended by ArrayExistsExpr + let idx = batch.num_columns() - 1; + Ok(ColumnarValue::Array(Arc::clone(batch.column(idx)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if children.is_empty() { + Ok(Arc::new(LambdaVariableExpr::new(self.data_type.clone()))) + } else { + Err(DataFusionError::Internal( + "LambdaVariableExpr should have no children".to_string(), + )) + } + } +} + +impl Display for LambdaVariableExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LambdaVariable({})", self.data_type) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::array::ListArray; + use arrow::datatypes::Int32Type; + use datafusion::physical_expr::expressions::{Column, Literal}; + use datafusion::{ + common::ScalarValue, logical_expr::Operator, physical_expr::expressions::BinaryExpr, + }; + + fn make_lambda_var_expr() -> Arc { + Arc::new(LambdaVariableExpr::new(DataType::Int32)) + } + + fn make_gt_predicate(threshold: i32) -> Arc { + Arc::new(BinaryExpr::new( + make_lambda_var_expr(), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(threshold)))), + )) + } + + #[test] + fn test_basic_exists() -> DataFusionResult<()> { + // exists(array(1, 2, 3), x -> x > 2) = true + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_empty_array() -> DataFusionResult<()> { + // exists(array(), x -> x > 0) = false + let list = + ListArray::from_iter_primitive::(vec![ + Some(Vec::>::new()), + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_null_array() -> DataFusionResult<()> { + // exists(null, x -> x > 0) = null + let list = + ListArray::from_iter_primitive::(vec![None::>>]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_three_valued_logic() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 5) = null (three-valued logic) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(5); + + // With three-valued logic: result should be null + let expr = ArrayExistsExpr::new(Arc::clone(&array_expr), Arc::clone(&lambda_body), true); + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + + // Without three-valued logic: result should be false + let expr2 = ArrayExistsExpr::new(array_expr, lambda_body, false); + let result2 = expr2.evaluate(&batch)?.into_array(1)?; + let bools2 = result2.as_any().downcast_ref::().unwrap(); + assert!(!bools2.is_null(0)); + assert!(!bools2.value(0)); + Ok(()) + } + + #[test] + fn test_null_elements_with_match() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 2) = true (because 3 > 2) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.is_null(0)); + assert!(bools.value(0)); + Ok(()) + } + + #[test] + fn test_multiple_rows() -> DataFusionResult<()> { + // Row 0: [1, 2, 3] -> x > 2 -> true + // Row 1: [1, 2] -> x > 2 -> false + // Row 2: null -> null + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(1), Some(2)]), + None, + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(3)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.value(1)); + assert!(bools.is_null(2)); + Ok(()) + } + + #[test] + fn test_multi_column_batch() -> DataFusionResult<()> { + // Verify batch expansion works correctly with additional columns + use arrow::array::Int32Array; + + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(5)]), + ]); + let extra_col = Int32Array::from(vec![100, 200]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("arr", list.data_type().clone(), true), + Field::new("extra", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list), Arc::new(extra_col)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(15); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(2)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); // [10, 20] has 20 > 15 + assert!(!bools.value(1)); // [5] has no element > 15 + Ok(()) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 2bd1b9631b..4599944d99 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -16,12 +16,14 @@ // under the License. mod array_compact; +mod array_exists; mod array_insert; mod get_array_struct_fields; mod list_extract; mod size; pub use array_compact::SparkArrayCompact; +pub use array_exists::{ArrayExistsExpr, LambdaVariableExpr}; pub use array_insert::ArrayInsert; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 59fb0f9819..d188d90de0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -52,6 +52,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayContains] -> CometArrayContains, classOf[ArrayDistinct] -> CometArrayDistinct, classOf[ArrayExcept] -> CometArrayExcept, + classOf[ArrayExists] -> CometArrayExists, classOf[ArrayFilter] -> CometArrayFilter, classOf[ArrayInsert] -> CometArrayInsert, classOf[ArrayIntersect] -> CometArrayIntersect, @@ -66,6 +67,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ElementAt] -> CometElementAt, classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, + classOf[NamedLambdaVariable] -> CometNamedLambdaVariable, classOf[Size] -> CometSize) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index f107d5b309..0ff8490d2b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -569,6 +569,97 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { } } +object CometArrayExists extends CometExpressionSerde[ArrayExists] { + + /** Check if a lambda body contains nested lambda expressions (e.g., nested exists calls). */ + private def containsNestedLambda(expr: Expression): Boolean = { + expr match { + case _: LambdaFunction => true + case _ => expr.children.exists(containsNestedLambda) + } + } + + override def getSupportLevel(expr: ArrayExists): SupportLevel = { + if (!expr.followThreeValuedLogic) { + return Unsupported(Some("legacy non-three-valued logic mode is not supported")) + } + val elementType = expr.argument.dataType.asInstanceOf[ArrayType].elementType + elementType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType => + Compatible() + case _ => Unsupported(Some(s"element type not supported: $elementType")) + } + } + + override def convert( + expr: ArrayExists, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExprProto = exprToProto(expr.argument, inputs, binding) + if (arrayExprProto.isEmpty) { + withInfo(expr, expr.argument) + return None + } + + expr.function match { + case LambdaFunction(body, Seq(lambdaVar: NamedLambdaVariable), _) => + // Detect nested lambdas that we cannot support yet + if (containsNestedLambda(body)) { + withInfo(expr, "nested lambda expressions are not supported") + return None + } + + val bodyProto = exprToProto(body, inputs, binding) + if (bodyProto.isEmpty) { + withInfo(expr, body) + return None + } + + val arrayExistsBuilder = ExprOuterClass.ArrayExists + .newBuilder() + .setArray(arrayExprProto.get) + .setLambdaBody(bodyProto.get) + .setFollowThreeValuedLogic(expr.followThreeValuedLogic) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayExists(arrayExistsBuilder) + .build()) + + case other => + withInfo(expr, s"Unsupported lambda function form: $other") + None + } + } + +} + +object CometNamedLambdaVariable extends CometExpressionSerde[NamedLambdaVariable] { + + override def convert( + expr: NamedLambdaVariable, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val dataTypeProto = serializeDataType(expr.dataType) + if (dataTypeProto.isEmpty) { + withInfo(expr, s"Cannot serialize data type: ${expr.dataType}") + return None + } + + val lambdaVarBuilder = ExprOuterClass.LambdaVariable + .newBuilder() + .setDatatype(dataTypeProto.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setLambdaVariable(lambdaVarBuilder) + .build()) + } +} + object CometSize extends CometExpressionSerde[Size] { override def getSupportLevel(expr: Size): SupportLevel = { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql new file mode 100644 index 0000000000..5e04db6fee --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql @@ -0,0 +1,110 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_array_exists(arr_int array, arr_str array, arr_double array, arr_bool array, arr_long array, threshold int) USING parquet + +statement +INSERT INTO test_array_exists VALUES (array(1, 2, 3), array('a', 'bb', 'ccc'), array(1.5, 2.5, 3.5), array(false, false, true), array(100, 200, 300), 2), (array(1, 2), array('a', 'b'), array(0.5, 1.5), array(false, false), array(10, 20), 5), (array(), array(), array(), array(), array(), 0), (NULL, NULL, NULL, NULL, NULL, 1), (array(1, NULL, 3), array('a', NULL, 'ccc'), array(1.0, NULL, 3.0), array(true, NULL, false), array(100, NULL, 300), 2) + +-- basic: element satisfies predicate +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- no match +query +SELECT exists(arr_int, x -> x > 100) FROM test_array_exists + +-- empty array returns false +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists + +-- null array returns null +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists WHERE arr_int IS NULL + +-- predicate referencing outer column +query +SELECT exists(arr_int, x -> x > threshold) FROM test_array_exists + +-- three-valued logic: null elements with no match -> null +query +SELECT exists(arr_int, x -> x > 5) FROM test_array_exists + +-- null elements but match exists -> true +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- string type +query +SELECT exists(arr_str, x -> length(x) > 2) FROM test_array_exists + +-- double type +query +SELECT exists(arr_double, x -> x > 2.0) FROM test_array_exists + +-- boolean type +query +SELECT exists(arr_bool, x -> x) FROM test_array_exists + +-- long type +query +SELECT exists(arr_long, x -> x > 250) FROM test_array_exists + +-- literal arrays +query +SELECT exists(array(1, 2, 3), x -> x > 2) + +query +SELECT exists(array(1, 2, 3), x -> x > 5) + +-- empty literal array has NullType element type, which is unsupported +query spark_answer_only +SELECT exists(array(), x -> cast(x as int) > 0) + +query +SELECT exists(cast(NULL as array), x -> x > 0) + +-- null elements in literal array with three-valued logic +query +SELECT exists(array(1, NULL, 3), x -> x > 5) + +-- null elements in literal array with match +query +SELECT exists(array(1, NULL, 3), x -> x > 2) + +-- null elements with IS NULL predicate (non-null result despite null elements) +query +SELECT exists(array(0, null, 2, 3, null), x -> x IS NULL) + +-- timestamp type +query +SELECT exists(array(timestamp'2024-01-01 00:00:00', timestamp'2024-06-15 12:30:00'), x -> x > timestamp'2024-03-01 00:00:00') + +-- timestamp_ntz type +query +SELECT exists(array(timestamp_ntz'2024-01-01 00:00:00', timestamp_ntz'2024-06-15 12:30:00'), x -> x > timestamp_ntz'2024-03-01 00:00:00') + +-- complex predicate +query +SELECT exists(arr_int, x -> x > 1 AND x < 3) FROM test_array_exists + +-- predicate with modulo +query +SELECT exists(arr_int, x -> x % 2 = 0) FROM test_array_exists diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index bb519492db..ffb1d7befb 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -930,6 +930,158 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } + test("array_exists - DataFrame API") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array, threshold int) using parquet") + sql(s"insert into $table values (array(1, 2, 3), 2)") + sql(s"insert into $table values (array(1, 2), 5)") + sql(s"insert into $table values (array(), 0)") + sql(s"insert into $table values (null, 1)") + sql(s"insert into $table values (array(1, null, 3), 2)") + + val df = spark.table(table) + + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > col("threshold")))) + checkSparkAnswerAndOperator( + df.select( + exists(col("arr"), x => x > 0).as("any_positive"), + exists(col("arr"), x => x > 100).as("any_large"))) + } + } + + test("array_exists - DataFrame API with decimal") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1.50, 2.75, 3.25))") + sql(s"insert into $table values (array(0.10, 0.20))") + + val df = spark.table(table) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2.0))) + } + } + + test("array_exists - DataFrame API with date") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(date'2024-01-01', date'2024-06-15'))") + sql(s"insert into $table values (array(date'2023-01-01'))") + + val df = spark.table(table) + checkSparkAnswerAndOperator( + df.select(exists(col("arr"), x => x > lit("2024-03-01").cast("date")))) + } + } + + test("array_exists - fallback for unsupported element type") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(X'01', X'02'))") + + val df = spark.table(table) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => x.isNotNull)), + "element type not supported") + } + } + + test("array_exists - fallback with UDF in lambda") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, 2, 3))") + sql(s"insert into $table values (array(4, 5, 6))") + sql(s"insert into $table values (null)") + + val isEven = udf((x: Int) => x % 2 == 0) + + val df = spark.table(table) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => isEven(x))), + "scalaudf is not supported") + } + } + + test("array_exists - literal false and literal null lambdas") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, 2, 3))") + sql(s"insert into $table values (array())") + sql(s"insert into $table values (null)") + + val df = spark.table(table) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(false)))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(true)))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(null).cast("boolean")))) + } + } + + test("array_exists - fallback for legacy non-three-valued logic") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, null, 3))") + + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "false") { + val df = spark.table(table) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => x > 2)), + "legacy non-three-valued logic mode is not supported") + } + } + } + + test("array_exists - DataFrame API with timestamp") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql( + s"insert into $table values (array(timestamp'2024-01-01 00:00:00', timestamp'2024-06-15 12:30:00'))") + sql(s"insert into $table values (array(timestamp'2023-01-01 00:00:00'))") + + val df = spark.table(table) + checkSparkAnswerAndOperator( + df.select(exists(col("arr"), x => x > lit("2024-03-01 00:00:00").cast("timestamp")))) + } + } + + test("array_exists - CaseWhen/If in lambda") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, 2, 3))") + sql(s"insert into $table values (array(-1, 0, 1))") + sql(s"insert into $table values (null)") + + val df = spark.table(table) + checkSparkAnswerAndOperator( + df.selectExpr("exists(arr, x -> CASE WHEN x > 0 THEN true ELSE false END)")) + checkSparkAnswerAndOperator(df.selectExpr("exists(arr, x -> IF(x > 0, true, false))")) + } + } + + test("array_exists - nested lambda falls back") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr1 array, arr2 array) using parquet") + sql(s"insert into $table values (array(1, 2, 3), array(4, 5, 6))") + sql(s"insert into $table values (array(10, 20), array(1, 2))") + sql(s"insert into $table values (array(1), array(1))") + + val df = spark.table(table) + // nested lambda: exists(arr1, x -> exists(arr2, y -> y > x)) + // nested lambdas are not yet supported, should fall back to Spark + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr1"), x => exists(col("arr2"), y => y > x))), + "nested lambda expressions are not supported") + } + } + // https://github.com/apache/datafusion-comet/issues/3375 test("(ansi) array access out of bounds - GetArrayItem") { withSQLConf(