Skip to content
67 changes: 67 additions & 0 deletions native/core/src/execution/expressions/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Array expression builders

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_proto::spark_expression::Expr;
use datafusion_comet_spark_expr::{ArrayExistsExpr, LambdaVariableExpr};

use crate::execution::operators::ExecutionError;
use crate::execution::planner::expression_registry::ExpressionBuilder;
use crate::execution::planner::PhysicalPlanner;
use crate::execution::serde::to_arrow_datatype;
use crate::extract_expr;

pub struct ArrayExistsBuilder;

impl ExpressionBuilder for ArrayExistsBuilder {
fn build(
&self,
spark_expr: &Expr,
input_schema: SchemaRef,
planner: &PhysicalPlanner,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let expr = extract_expr!(spark_expr, ArrayExists);
let array_expr =
planner.create_expr(expr.array.as_ref().unwrap(), Arc::clone(&input_schema))?;
let lambda_body = planner.create_expr(expr.lambda_body.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(ArrayExistsExpr::new(
array_expr,
lambda_body,
expr.follow_three_valued_logic,
)))
}
}

pub struct LambdaVariableBuilder;

impl ExpressionBuilder for LambdaVariableBuilder {
fn build(
&self,
spark_expr: &Expr,
_input_schema: SchemaRef,
_planner: &PhysicalPlanner,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let expr = extract_expr!(spark_expr, LambdaVariable);
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(LambdaVariableExpr::new(data_type)))
}
}
1 change: 1 addition & 0 deletions native/core/src/execution/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Native DataFusion expressions
pub mod arithmetic;
pub mod array;
pub mod bitwise;
pub mod comparison;
pub mod logical;
Expand Down
19 changes: 19 additions & 0 deletions native/core/src/execution/planner/expression_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ pub enum ExpressionType {
Randn,
SparkPartitionId,
MonotonicallyIncreasingId,
ArrayExists,
LambdaVariable,

// Time functions
Hour,
Expand Down Expand Up @@ -185,6 +187,9 @@ impl ExpressionRegistry {
// Register temporal expressions
self.register_temporal_expressions();

// Register array expressions
self.register_array_expressions();

// Register random expressions
self.register_random_expressions();

Expand Down Expand Up @@ -312,6 +317,18 @@ impl ExpressionRegistry {
);
}

/// Installs the builders for array-related expressions (`ArrayExists` and
/// `LambdaVariable`) into this registry's builder map.
fn register_array_expressions(&mut self) {
    use crate::execution::expressions::array::*;

    let builders = &mut self.builders;
    builders.insert(ExpressionType::ArrayExists, Box::new(ArrayExistsBuilder));
    builders.insert(
        ExpressionType::LambdaVariable,
        Box::new(LambdaVariableBuilder),
    );
}

/// Extract expression type from Spark protobuf expression
fn get_expression_type(spark_expr: &Expr) -> Result<ExpressionType, ExecutionError> {
match spark_expr.expr_struct.as_ref() {
Expand Down Expand Up @@ -376,6 +393,8 @@ impl ExpressionRegistry {
Some(ExprStruct::MonotonicallyIncreasingId(_)) => {
Ok(ExpressionType::MonotonicallyIncreasingId)
}
Some(ExprStruct::ArrayExists(_)) => Ok(ExpressionType::ArrayExists),
Some(ExprStruct::LambdaVariable(_)) => Ok(ExpressionType::LambdaVariable),

Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour),
Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute),
Expand Down
16 changes: 16 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ message Expr {
UnixTimestamp unix_timestamp = 65;
FromJson from_json = 66;
ToCsv to_csv = 67;
ArrayExists array_exists = 68;
LambdaVariable lambda_variable = 69;
}

// Optional QueryContext for error reporting (contains SQL text and position)
Expand Down Expand Up @@ -484,3 +486,17 @@ message ArrayJoin {
// Payload for the random-number expression: carries only the generator seed.
message Rand {
// 64-bit signed seed value.
int64 seed = 1;
}

// Serialized form of an array "exists" expression: an array operand plus a
// lambda body that is planned as an ordinary expression on the native side.
message ArrayExists {
// Expression producing the array operand.
Expr array = 1;
// Body of the lambda; the native builder plans it against the same input
// schema as the array operand.
Expr lambda_body = 2;
// Passed through unchanged to the native ArrayExistsExpr.
// NOTE(review): presumably selects SQL three-valued (null-aware) semantics
// for the result - confirm against the Spark-side serializer.
bool follow_three_valued_logic = 3;
}

// Currently only supports a single lambda variable per expression. The variable
// is resolved by column index (always the last column in the expanded batch
// constructed by ArrayExistsExpr). Extending to multi-argument lambdas
// (e.g. transform(array, (x, i) -> ...)) would require adding an identifier.
message LambdaVariable {
// Type of the lambda variable; the native builder converts it to an Arrow
// data type when constructing the LambdaVariableExpr.
DataType datatype = 1;
}
Loading
Loading