diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index f67f4b5632..c2b372690d 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -182,6 +182,14 @@ The following Spark expressions are currently available. Any known compatibility
 | VariancePop | |
 | VarianceSamp | |
 
+## Complex Types
+
+| Expression        | Notes |
+| ----------------- | ----- |
+| CreateNamedStruct |       |
+| GetElementAt      |       |
+| StructsToJson     |       |
+
 ## Other
 
 | Expression | Notes |
@@ -191,4 +199,3 @@ The following Spark expressions are currently available. Any known compatibility
 | ScalarSubquery | |
 | Coalesce | |
 | NormalizeNaNAndZero | |
-| CreateNamedStruct | |
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index aee0471a5d..45de2bca97 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -95,7 +95,7 @@ use datafusion_comet_proto::{
 };
 use datafusion_comet_spark_expr::{
     Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike,
-    SecondExpr, TimestampTruncExpr,
+    SecondExpr, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
 use datafusion_common::{
@@ -655,6 +655,10 @@ impl PhysicalPlanner {
                     self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
                 Ok(Arc::new(GetStructField::new(child, expr.ordinal as usize)))
             }
+            ExprStruct::ToJson(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
+                Ok(Arc::new(ToJson::new(child, &expr.timezone)))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index fa8f79ace2..5fd96ac0e2 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -79,6 +79,7 @@ message Expr {
     BloomFilterMightContain bloom_filter_might_contain = 52;
     CreateNamedStruct create_named_struct = 53;
     GetStructField get_struct_field = 54;
+    ToJson to_json = 55;
   }
 }
 
@@ -343,6 +344,15 @@ message StringSpace {
   Expr child = 1;
 }
 
+message ToJson {
+  Expr child = 1;
+  string timezone = 2;
+  string date_format = 3;
+  string timestamp_format = 4;
+  string timestamp_ntz_format = 5;
+  bool ignore_null_fields = 6;
+}
+
 message Hour {
   Expr child = 1;
   string timezone = 2;
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index ed8cdc2fe1..12dbfdcdc7 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -582,7 +582,7 @@ pub fn spark_cast(
     arg: ColumnarValue,
     data_type: &DataType,
     eval_mode: EvalMode,
-    timezone: String,
+    timezone: &str,
     allow_incompat: bool,
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
@@ -1414,7 +1414,7 @@ impl PhysicalExpr for Cast {
             arg,
             &self.data_type,
             self.eval_mode,
-            self.timezone.clone(),
+            &self.timezone,
             self.allow_incompat,
         )
     }
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 9fb16f94d9..6233a29eb7 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -30,6 +30,7 @@ pub mod spark_hash;
 mod structs;
 mod temporal;
 pub mod timezone;
+mod to_json;
 pub mod utils;
 mod xxhash64;
 
@@ -39,6 +40,7 @@ pub use if_expr::IfExpr;
 pub use regexp::RLike;
 pub use structs::{CreateNamedStruct, GetStructField};
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
+pub use to_json::ToJson;
 
 /// Spark supports three evaluation modes when evaluating expressions, which affect
 /// the behavior when processing input values that are invalid or would result in an
diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs
new file mode 100644
index 0000000000..2b9a2c5407
--- /dev/null
+++ b/native/spark-expr/src/to_json.rs
@@ -0,0 +1,342 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// TODO upstream this to DataFusion as long as we have a way to specify all
+// of the Spark-specific compatibility features that we need (including
+// being able to specify Spark-compatible cast from all types to string)
+
+use crate::{spark_cast, EvalMode};
+use arrow_array::builder::StringBuilder;
+use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray};
+use arrow_schema::{DataType, Schema};
+use datafusion_common::Result;
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Debug, Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+/// Spark-compatible `to_json` expression that renders each input row as a JSON string.
+#[derive(Debug, Hash)]
+pub struct ToJson {
+    /// The input to convert to JSON
+    expr: Arc<dyn PhysicalExpr>,
+    /// Timezone to use when converting timestamps to JSON
+    timezone: String,
+}
+
+impl ToJson {
+    /// Create a `to_json` expression over `expr` that casts values using `timezone`.
+    pub fn new(expr: Arc<dyn PhysicalExpr>, timezone: &str) -> Self {
+        Self {
+            expr,
+            timezone: timezone.to_owned(),
+        }
+    }
+}
+
+impl Display for ToJson {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "to_json({}, timezone={})", self.expr, self.timezone)
+    }
+}
+
+impl PartialEq<dyn Any> for ToJson {
+    fn eq(&self, other: &dyn Any) -> bool {
+        if let Some(other) = other.downcast_ref::<Self>() {
+            self.expr.eq(&other.expr) && self.timezone.eq(&other.timezone)
+        } else {
+            false
+        }
+    }
+}
+
+impl PhysicalExpr for ToJson {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _: &Schema) -> Result<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.expr.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let input = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
+        Ok(ColumnarValue::Array(array_to_json_string(
+            &input,
+            &self.timezone,
+        )?))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        assert!(children.len() == 1);
+        Ok(Arc::new(Self::new(
+            Arc::clone(&children[0]),
+            &self.timezone,
+        )))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        // the derived `Hash` impl already covers both `expr` and `timezone`
+        let mut s = state;
+        self.hash(&mut s);
+    }
+}
+
+/// Convert an array into a JSON value string representation.
+///
+/// Struct arrays are rendered as JSON objects; every other type is cast to a string
+/// with `spark_cast` in legacy evaluation mode.
+fn array_to_json_string(arr: &Arc<dyn Array>, timezone: &str) -> Result<ArrayRef> {
+    if let Some(struct_array) = arr.as_any().downcast_ref::<StructArray>() {
+        struct_to_json(struct_array, timezone)
+    } else {
+        spark_cast(
+            ColumnarValue::Array(Arc::clone(arr)),
+            &DataType::Utf8,
+            EvalMode::Legacy,
+            timezone,
+            false,
+        )?
+        .into_array(arr.len())
+    }
+}
+
+/// Escape `input` so that it can be placed between double quotes in a JSON document.
+///
+/// Quotes and backslashes are prefixed with a backslash, the control characters that
+/// have a short form (`\t`, `\r`, `\n`, `\f`, `\b`) use it, and any other character
+/// below U+0020 is written as `\u00XX` because JSON forbids raw control characters
+/// inside strings.
+fn escape_string(input: &str) -> String {
+    let mut escaped_string = String::with_capacity(input.len());
+    for c in input.chars() {
+        match c {
+            '"' => escaped_string.push_str("\\\""),
+            '\\' => escaped_string.push_str("\\\\"),
+            '\t' => escaped_string.push_str("\\t"),
+            '\r' => escaped_string.push_str("\\r"),
+            '\n' => escaped_string.push_str("\\n"),
+            '\x0C' => escaped_string.push_str("\\f"),
+            '\x08' => escaped_string.push_str("\\b"),
+            c if c < '\x20' => escaped_string.push_str(&format!("\\u{:04X}", c as u32)),
+            _ => escaped_string.push(c),
+        }
+    }
+    escaped_string
+}
+
+/// Render every row of a struct array as a JSON object string.
+///
+/// Null rows produce null output, and fields whose value is null are omitted from the
+/// object. Nested structs are rendered recursively.
+fn struct_to_json(array: &StructArray, timezone: &str) -> Result<ArrayRef> {
+    // get field names and escape any quotes
+    let field_names: Vec<String> = array
+        .fields()
+        .iter()
+        .map(|f| escape_string(f.name().as_str()))
+        .collect();
+    // determine which fields need to have their values quoted
+    let is_string: Vec<bool> = array
+        .fields()
+        .iter()
+        .map(|f| match f.data_type() {
+            DataType::Utf8 | DataType::LargeUtf8 => true,
+            DataType::Dictionary(_, dt) => {
+                matches!(dt.as_ref(), DataType::Utf8 | DataType::LargeUtf8)
+            }
+            _ => false,
+        })
+        .collect();
+    // create JSON string representation of each column
+    let string_arrays: Vec<ArrayRef> = array
+        .columns()
+        .iter()
+        .map(|arr| array_to_json_string(arr, timezone))
+        .collect::<Result<Vec<_>>>()?;
+    let string_arrays: Vec<&StringArray> = string_arrays
+        .iter()
+        .map(|arr| {
+            arr.as_any()
+                .downcast_ref::<StringArray>()
+                .expect("string array")
+        })
+        .collect();
+    // build the JSON string containing entries in the format `"field_name":field_value`
+    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
+    let mut json = String::with_capacity(array.len() * 16);
+    for row_index in 0..array.len() {
+        if array.is_null(row_index) {
+            builder.append_null();
+        } else {
+            json.clear();
+            let mut any_fields_written = false;
+            json.push('{');
+            for col_index in 0..string_arrays.len() {
+                if !string_arrays[col_index].is_null(row_index) {
+                    if any_fields_written {
+                        json.push(',');
+                    }
+                    // quoted field name
+                    json.push('"');
+                    json.push_str(&field_names[col_index]);
+                    json.push_str("\":");
+                    // value
+                    let string_value = string_arrays[col_index].value(row_index);
+                    if is_string[col_index] {
+                        json.push('"');
+                        json.push_str(&escape_string(string_value));
+                        json.push('"');
+                    } else {
+                        json.push_str(string_value);
+                    }
+                    any_fields_written = true;
+                }
+            }
+            json.push('}');
+            builder.append_value(&json);
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+#[cfg(test)]
+mod test {
+    use crate::to_json::{escape_string, struct_to_json};
+    use arrow_array::types::Int32Type;
+    use arrow_array::{Array, PrimitiveArray, StringArray};
+    use arrow_array::{ArrayRef, BooleanArray, Int32Array, StructArray};
+    use arrow_schema::{DataType, Field};
+    use datafusion_common::Result;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_primitives() -> Result<()> {
+        let bools: ArrayRef = create_bools();
+        let ints: ArrayRef = create_ints();
+        let strings: ArrayRef = create_strings();
+        let struct_array = StructArray::from(vec![
+            (Arc::new(Field::new("a", DataType::Boolean, true)), bools),
+            (Arc::new(Field::new("b", DataType::Int32, true)), ints),
+            (Arc::new(Field::new("c", DataType::Utf8, true)), strings),
+        ]);
+        let json = struct_to_json(&struct_array, "UTC")?;
+        let json = json
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("string array");
+        assert_eq!(4, json.len());
+        assert_eq!(r#"{"b":123}"#, json.value(0));
+        assert_eq!(r#"{"a":true,"c":"foo"}"#, json.value(1));
+        assert_eq!(r#"{"a":false,"b":2147483647,"c":"bar"}"#, json.value(2));
+        assert_eq!(r#"{"a":false,"b":-2147483648,"c":""}"#, json.value(3));
+        Ok(())
+    }
+
+    #[test]
+    fn test_nested_struct() -> Result<()> {
+        let bools: ArrayRef = create_bools();
+        let ints: ArrayRef = create_ints();
+
+        // create first struct array
+        let struct_fields = vec![
+            Arc::new(Field::new("a", DataType::Boolean, true)),
+            Arc::new(Field::new("b", DataType::Int32, true)),
+        ];
+        let struct_values = vec![bools, ints];
+        let struct_array = StructArray::from(
+            struct_fields
+                .clone()
+                .into_iter()
+                .zip(struct_values)
+                .collect::<Vec<_>>(),
+        );
+
+        // create second struct array containing the first struct array
+        let struct_fields2 = vec![Arc::new(Field::new(
+            "a",
+            DataType::Struct(struct_fields.into()),
+            true,
+        ))];
+        let struct_values2: Vec<Arc<dyn Array>> = vec![Arc::new(struct_array.clone())];
+        let struct_array2 = StructArray::from(
+            struct_fields2
+                .into_iter()
+                .zip(struct_values2)
+                .collect::<Vec<_>>(),
+        );
+
+        let json = struct_to_json(&struct_array2, "UTC")?;
+        let json = json
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("string array");
+        assert_eq!(4, json.len());
+        assert_eq!(r#"{"a":{"b":123}}"#, json.value(0));
+        assert_eq!(r#"{"a":{"a":true}}"#, json.value(1));
+        assert_eq!(r#"{"a":{"a":false,"b":2147483647}}"#, json.value(2));
+        assert_eq!(r#"{"a":{"a":false,"b":-2147483648}}"#, json.value(3));
+        Ok(())
+    }
+
+    #[test]
+    fn test_escape_string() {
+        assert_eq!(r#"a\"b\\c"#, escape_string(r#"a"b\c"#));
+        assert_eq!(r"\t\r\n\f\b", escape_string("\t\r\n\x0C\x08"));
+        assert_eq!(r"\u0001\u001F", escape_string("\x01\x1F"));
+    }
+
+    fn create_ints() -> Arc<PrimitiveArray<Int32Type>> {
+        Arc::new(Int32Array::from(vec![
+            Some(123),
+            None,
+            Some(i32::MAX),
+            Some(i32::MIN),
+        ]))
+    }
+
+    fn create_bools() -> Arc<BooleanArray> {
+        Arc::new(BooleanArray::from(vec![
+            None,
+            Some(true),
+            Some(false),
+            Some(false),
+        ]))
+    }
+
+    fn create_strings() -> Arc<StringArray> {
+        Arc::new(StringArray::from(vec![
+            None,
+            Some("foo"),
+            Some("bar"),
+            Some(""),
+        ]))
+    }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 45f9d19924..717ea8911d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1258,6 +1258,66 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             None
           }
 
+        case StructsToJson(options, child, timezoneId) =>
+          if (options.nonEmpty) {
+            withInfo(expr, "StructsToJson with options is not supported")
+            None
+          } else {
+
+            def isSupportedType(dt: DataType): Boolean = {
+              dt match {
+                case StructType(fields) =>
+                  fields.forall(f => isSupportedType(f.dataType))
+                case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+                    DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+                    DataTypes.DoubleType | DataTypes.StringType =>
+                  true
+                case DataTypes.DateType | DataTypes.TimestampType =>
+                  // TODO implement these types with tests for formatting options and timezone
+                  false
+                case _: MapType | _: ArrayType =>
+                  // Spark supports map and array in StructsToJson but this is not yet
+                  // implemented in Comet
+                  false
+                case _ => false
+              }
+            }
+
+            val isSupported = child.dataType match {
+              case s: StructType =>
+                s.fields.forall(f => isSupportedType(f.dataType))
+              case _: MapType | _: ArrayType =>
+                // Spark supports map and array in StructsToJson but this is not yet
+                // implemented in Comet
+                false
+              case _ =>
+                false
+            }
+
+            if (isSupported) {
+              exprToProto(child, input, binding) match {
+                case Some(p) =>
+                  val toJson = ExprOuterClass.ToJson
+                    .newBuilder()
+                    .setChild(p)
+                    .setTimezone(timezoneId.getOrElse("UTC"))
+                    .setIgnoreNullFields(true)
+                    .build()
+                  Some(
+                    ExprOuterClass.Expr
+                      .newBuilder()
+                      .setToJson(toJson)
+                      .build())
+                case _ =>
+                  withInfo(expr, child)
+                  None
+              }
+            } else {
+              withInfo(expr, "Unsupported data type", child)
+              None
+            }
+          }
+
         case Like(left, right, escapeChar) =>
           if (escapeChar == '\\') {
             val leftExpr = exprToProtoInternal(left, inputs)
diff --git a/spark/src/test/resources/tpcds-micro-benchmarks/to_json.sql b/spark/src/test/resources/tpcds-micro-benchmarks/to_json.sql
new file mode 100644
index 0000000000..f78af92d64
--- /dev/null
+++ b/spark/src/test/resources/tpcds-micro-benchmarks/to_json.sql
@@ -0,0 +1,20 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+--
+--   http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied.  See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+
+-- This is not part of TPC-DS but runs on TPC-DS data
+
+SELECT to_json(named_struct('id', i_item_sk, 'desc', i_item_desc, 'color', i_color)) FROM item;
\ No newline at end of file
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index c50823b582..c86cfa84ec 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1953,6 +1953,79 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("to_json") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 100).map(i => {
+          val str = if (i % 2 == 0) {
+            "even"
+          } else {
+            "odd"
+          }
+          (i.toByte, i.toShort, i, i.toLong, i * 1.2f, -i * 1.2d, str, i.toString)
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val fields = Range.inclusive(1, 8).map(n => s"'col$n', _$n").mkString(", ")
+
+        checkSparkAnswerAndOperator(s"SELECT to_json(named_struct($fields)) FROM tbl")
+        checkSparkAnswerAndOperator(
+          s"SELECT to_json(named_struct('nested', named_struct($fields))) FROM tbl")
+      }
+    }
+  }
+
+  test("to_json escaping of field names and string values") {
+    val gen = new DataGenerator(new Random(42))
+    val chars = "\\'\"abc\t\r\n\f\b"
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 100).map(i => {
+          val str1 = gen.generateString(chars, 8)
+          val str2 = gen.generateString(chars, 8)
+          (i.toString, str1, str2)
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val fields = Range.inclusive(1, 3)
+          .map(n => {
+            val columnName = s"""column "$n""""
+            s"'$columnName', _$n"
+          })
+          .mkString(", ")
+
+        checkSparkAnswerAndOperator(
+          """SELECT 'column "1"' x, """ +
+            s"to_json(named_struct($fields)) FROM tbl ORDER BY x")
+      }
+    }
+  }
+
+  test("to_json unicode") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withParquetTable(
+        (0 until 100).map(i => {
+          (i.toString, "\uD83E\uDD11", "\u018F")
+        }),
+        "tbl",
+        withDictionary = dictionaryEnabled) {
+
+        val fields = Range.inclusive(1, 3)
+          .map(n => {
+            val columnName = s"""column "$n""""
+            s"'$columnName', _$n"
+          })
+          .mkString(", ")
+
+        checkSparkAnswerAndOperator(
+          """SELECT 'column "1"' x, """ +
+            s"to_json(named_struct($fields)) FROM tbl ORDER BY x")
+      }
+    }
+  }
+
   test("struct and named_struct with dictionary") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withParquetTable(
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
index 40a84a1252..0839790ae7 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala
@@ -74,7 +74,8 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
     "join_inner",
     "join_left_outer",
     "join_semi",
-    "rlike")
+    "rlike",
+    "to_json")
 
   override def runQueries(
       queryLocation: String,