From 9b9eddbc404a08fe7cf75e9d115a4f15a08a2823 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 11 Dec 2025 11:14:59 +0530 Subject: [PATCH] fix: derive custom nullability for spark map_from_arrays --- .../spark/src/function/map/map_from_arrays.rs | 88 +++++++++++++++++-- 1 file changed, 79 insertions(+), 9 deletions(-) diff --git a/datafusion/spark/src/function/map/map_from_arrays.rs b/datafusion/spark/src/function/map/map_from_arrays.rs index 987548e353e44..dc155616dd77b 100644 --- a/datafusion/spark/src/function/map/map_from_arrays.rs +++ b/datafusion/spark/src/function/map/map_from_arrays.rs @@ -23,11 +23,14 @@ use crate::function::map::utils::{ }; use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::kernels::cast; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; /// Spark-compatible `map_from_arrays` expression /// @@ -63,12 +66,23 @@ impl ScalarUDFImpl for MapFromArrays { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [key_type, value_type] = take_function_args("map_from_arrays", arg_types)?; - Ok(map_type_from_key_value_types( - get_element_type(key_type)?, - get_element_type(value_type)?, - )) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [keys_field, values_field] = args.arg_fields else { + return internal_err!("map_from_arrays expects exactly 2 arguments"); + }; + + let map_type = map_type_from_key_value_types( + get_element_type(keys_field.data_type())?, + get_element_type(values_field.data_type())?, + ); + // Spark marks map_from_arrays as null intolerant, so the output is + // nullable if either input is nullable. + let nullable = keys_field.is_nullable() || values_field.is_nullable(); + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) } fn invoke_with_args( @@ -103,3 +117,59 @@ fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { values.nulls(), ) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Field; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn test_map_from_arrays_nullability_and_type() { + let func = MapFromArrays::new(); + + let keys_field: FieldRef = Arc::new(Field::new( + "keys", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + false, + )); + let values_field: FieldRef = Arc::new(Field::new( + "values", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + )); + + let out = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&keys_field), Arc::clone(&values_field)], + scalar_arguments: &[None, None], + }) + .expect("return_field_from_args should succeed"); + + let expected_type = + map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8); + assert_eq!(out.data_type(), &expected_type); + assert!( + !out.is_nullable(), + "map_from_arrays should be non-nullable when both inputs are non-nullable" + ); + + let nullable_keys: FieldRef = Arc::new(Field::new( + "keys", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + true, + )); + + let out_nullable = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_keys, values_field], + scalar_arguments: &[None, None], + }) + .expect("return_field_from_args should succeed"); + + assert!( + out_nullable.is_nullable(), + "map_from_arrays should be nullable when any input is nullable" + ); + } +}