diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index 341f76c1d147c..f0f509b4c2464 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -5,7 +5,7 @@ | | Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, SUM(shipping.volume) AS revenue | | | Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] | | | SubqueryAlias: shipping | -| | Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume | +| | Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, date_part(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume | | | Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") | | | Projection: lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey, n1.n_name | | | Inner Join: supplier.s_nationkey = n1.n_nationkey | @@ -33,7 +33,7 @@ | | CoalesceBatchesExec: target_batch_size=8192 | | | RepartitionExec: partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 2), input_partitions=2 | | | AggregateExec: mode=Partial, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[SUM(shipping.volume)] | -| | ProjectionExec: expr=[n_name@4 as supp_nation, n_name@6 as cust_nation, datepart(YEAR, l_shipdate@2) as l_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume] | +| | ProjectionExec: expr=[n_name@4 as supp_nation, n_name@6 as cust_nation, date_part(YEAR, l_shipdate@2) as l_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume] | | | CoalesceBatchesExec: target_batch_size=8192 | | | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "c_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })], filter=n_name@0 = FRANCE AND n_name@1 = GERMANY OR n_name@0 = GERMANY AND n_name@1 = FRANCE | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -85,4 +85,4 @@ | | FilterExec: n_name@1 = GERMANY OR n_name@1 = FRANCE | | | MemoryExec: partitions=0, partition_sizes=[] | | | | -+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ \ No newline at end of file ++---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/benchmarks/expected-plans/q8.txt b/benchmarks/expected-plans/q8.txt index b356b9ded3bc0..06f95ebb82850 100644 --- a/benchmarks/expected-plans/q8.txt +++ b/benchmarks/expected-plans/q8.txt @@ -5,7 +5,7 @@ | | Projection: all_nations.o_year, CAST(CAST(SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(SUM(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share | | | Aggregate: groupBy=[[all_nations.o_year]], aggr=[[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)]] | | | SubqueryAlias: all_nations | -| | Projection: datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation | +| | Projection: date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation | | | Inner Join: n1.n_regionkey = region.r_regionkey | | | Projection: lineitem.l_extendedprice, lineitem.l_discount, orders.o_orderdate, n1.n_regionkey, n2.n_name | | | Inner Join: supplier.s_nationkey = n2.n_nationkey | @@ -41,7 +41,7 @@ | | CoalesceBatchesExec: target_batch_size=8192 | | | RepartitionExec: partitioning=Hash([Column { name: "o_year", index: 0 }], 2), input_partitions=2 | | | AggregateExec: mode=Partial, gby=[o_year@0 as o_year], aggr=[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)] | -| | ProjectionExec: expr=[datepart(YEAR, o_orderdate@2) as o_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume, n_name@4 as nation] | +| | ProjectionExec: expr=[date_part(YEAR, o_orderdate@2) as o_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume, n_name@4 as nation] | | | CoalesceBatchesExec: target_batch_size=8192 | | | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 3 }, Column { name: "r_regionkey", index: 0 })] | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -110,4 +110,4 @@ | | FilterExec: r_name@1 = AMERICA | | | MemoryExec: partitions=0, partition_sizes=[] | | | | -+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ \ No newline at end of file ++---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/benchmarks/expected-plans/q9.txt b/benchmarks/expected-plans/q9.txt index 24a5e7f16e134..9356113bb0e80 100644 --- a/benchmarks/expected-plans/q9.txt +++ b/benchmarks/expected-plans/q9.txt @@ -5,7 +5,7 @@ | | Projection: profit.nation, profit.o_year, SUM(profit.amount) AS sum_profit | | | Aggregate: groupBy=[[profit.nation, profit.o_year]], aggr=[[SUM(profit.amount)]] | | | SubqueryAlias: profit | -| | Projection: nation.n_name AS nation, datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) - partsupp.ps_supplycost * lineitem.l_quantity AS amount | +| | Projection: nation.n_name AS nation, date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) - partsupp.ps_supplycost * lineitem.l_quantity AS amount | | | Inner Join: supplier.s_nationkey = nation.n_nationkey | | | Projection: lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, partsupp.ps_supplycost, orders.o_orderdate | | | Inner Join: lineitem.l_orderkey = orders.o_orderkey | @@ -30,7 +30,7 @@ | | CoalesceBatchesExec: target_batch_size=8192 | | | RepartitionExec: partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2), input_partitions=2 | | | AggregateExec: mode=Partial, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] | -| | ProjectionExec: expr=[n_name@7 as nation, datepart(YEAR, o_orderdate@5) as o_year, l_extendedprice@1 * (Some(1),20,0 - l_discount@2) - ps_supplycost@4 * l_quantity@0 as amount] | +| | ProjectionExec: expr=[n_name@7 as nation, date_part(YEAR, o_orderdate@5) as o_year, l_extendedprice@1 * (Some(1),20,0 - l_discount@2) - ps_supplycost@4 * l_quantity@0 as amount] | | | CoalesceBatchesExec: target_batch_size=8192 | | | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 3 }, Column { name: "n_nationkey", index: 0 })] | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -76,4 +76,4 @@ | | RepartitionExec: partitioning=Hash([Column { name: "n_nationkey", index: 0 }], 2), input_partitions=0 | | | MemoryExec: partitions=0, partition_sizes=[] | | | | -+---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ \ No newline at end of file ++---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9f59d672e5cb5..57513528be64c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -610,9 +610,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f1e31e207a6b8fb791a38ea3105e6cb541f55e4d029902d3039a4ad07cc4105" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64-simd" @@ -1057,6 +1057,7 @@ dependencies = [ "ahash", "arrow", "datafusion-common", + "lazy_static", "sqlparser", ] @@ -1670,9 +1671,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "io-lifetimes" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi 0.3.1", "libc", @@ -2388,9 +2389,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.2" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1a59b5d8e97dee33696bf13c5ba8ab85341c002922fba050069326b9c498974" +checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390" dependencies = [ "aho-corasick", "memchr", @@ -3121,9 +3122,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "unicode-normalization" diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs index 2f4e4d9d8c981..23af19983d3c4 100644 --- a/datafusion/core/tests/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe_functions.rs @@ -99,14 +99,14 @@ async fn test_fn_bit_length() -> Result<()> { let expr = bit_length(col("a")); let expected = vec![ - "+-------------------+", - "| bitlength(test.a) |", - "+-------------------+", - "| 48 |", - "| 48 |", - "| 48 |", - "| 72 |", - "+-------------------+", + "+--------------------+", + "| bit_length(test.a) |", + "+--------------------+", + "| 48 |", + "| 48 |", + "| 48 |", + "| 72 |", + "+--------------------+", ]; assert_fn_batches!(expr, expected); @@ -196,14 +196,14 @@ async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a")); let expected = vec![ - "+-------------------------+", - "| characterlength(test.a) |", - "+-------------------------+", - "| 6 |", - "| 6 |", - "| 6 |", - "| 9 |", - "+-------------------------+", + "+--------------------------+", + "| character_length(test.a) |", + "+--------------------------+", + "| 6 |", + "| 6 |", + "| 6 |", + "| 9 |", + "+--------------------------+", ]; assert_fn_batches!(expr, expected); @@ -395,14 +395,14 @@ async fn test_fn_regexp_match() -> Result<()> { let expr = regexp_match(vec![col("a"), lit("[a-z]")]); let expected = vec![ - "+-----------------------------------+", - "| regexpmatch(test.a,Utf8(\"[a-z]\")) |", - "+-----------------------------------+", - "| [a] |", - "| [a] |", - "| [d] |", - "| [b] |", - "+-----------------------------------+", + "+------------------------------------+", + "| regexp_match(test.a,Utf8(\"[a-z]\")) |", + "+------------------------------------+", + "| [a] |", + "| [a] |", + "| [d] |", + "| [b] |", + "+------------------------------------+", ]; assert_fn_batches!(expr, expected); @@ -416,14 +416,14 @@ async fn test_fn_regexp_replace() -> Result<()> { let expr = regexp_replace(vec![col("a"), lit("[a-z]"), lit("x"), lit("g")]); let expected = vec![ - "+---------------------------------------------------------+", - "| regexpreplace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\"),Utf8(\"g\")) |", - "+---------------------------------------------------------+", - "| xxxDEF |", - "| xxx123 |", - "| CBAxxx |", - "| 123AxxDxx |", - "+---------------------------------------------------------+", + "+----------------------------------------------------------+", + "| regexp_replace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\"),Utf8(\"g\")) |", + "+----------------------------------------------------------+", + "| xxxDEF |", + "| xxx123 |", + "| CBAxxx |", + "| 123AxxDxx |", + "+----------------------------------------------------------+", ]; assert_fn_batches!(expr, expected); @@ -581,14 +581,14 @@ async fn test_fn_split_part() -> Result<()> { let expr = split_part(col("a"), lit("b"), lit(1)); let expected = vec![ - "+--------------------------------------+", - "| splitpart(test.a,Utf8(\"b\"),Int32(1)) |", - "+--------------------------------------+", - "| a |", - "| a |", - "| CBAdef |", - "| 123A |", - "+--------------------------------------+", + "+---------------------------------------+", + "| split_part(test.a,Utf8(\"b\"),Int32(1)) |", + "+---------------------------------------+", + "| a |", + "| a |", + "| CBAdef |", + "| 123A |", + "+---------------------------------------+", ]; assert_fn_batches!(expr, expected); @@ -600,14 +600,14 @@ async fn test_fn_starts_with() -> Result<()> { let expr = starts_with(col("a"), lit("abc")); let expected = vec![ - "+--------------------------------+", - "| startswith(test.a,Utf8(\"abc\")) |", - "+--------------------------------+", - "| true |", - "| true |", - "| false |", - "| false |", - "+--------------------------------+", + "+---------------------------------+", + "| starts_with(test.a,Utf8(\"abc\")) |", + "+---------------------------------+", + "| true |", + "| true |", + "| false |", + "| false |", + "+---------------------------------+", ]; assert_fn_batches!(expr, expected); @@ -679,14 +679,14 @@ async fn test_fn_to_hex() -> Result<()> { let expr = to_hex(col("b")); let expected = vec![ - "+---------------+", - "| tohex(test.b) |", - "+---------------+", - "| 1 |", - "| a |", - "| a |", - "| 64 |", - "+---------------+", + "+----------------+", + "| to_hex(test.b) |", + "+----------------+", + "| 1 |", + "| a |", + "| a |", + "| 64 |", + "+----------------+", ]; assert_fn_batches!(expr, expected); diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index a1284cf9bbad8..68a3a7008e84f 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -711,44 +711,44 @@ async fn test_arrow_typeof() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------------------------------------------------+", - "| arrowtypeof(datetrunc(Utf8(\"minute\"),totimestampseconds(Int64(61)))) |", - "+----------------------------------------------------------------------+", - "| Timestamp(Second, None) |", - "+----------------------------------------------------------------------+", + "+--------------------------------------------------------------------------+", + "| arrow_typeof(date_trunc(Utf8(\"minute\"),to_timestamp_seconds(Int64(61)))) |", + "+--------------------------------------------------------------------------+", + "| Timestamp(Second, None) |", + "+--------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); let sql = "select arrow_typeof(date_trunc('second', to_timestamp_millis(61)));"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+---------------------------------------------------------------------+", - "| arrowtypeof(datetrunc(Utf8(\"second\"),totimestampmillis(Int64(61)))) |", - "+---------------------------------------------------------------------+", - "| Timestamp(Millisecond, None) |", - "+---------------------------------------------------------------------+", + "+-------------------------------------------------------------------------+", + "| arrow_typeof(date_trunc(Utf8(\"second\"),to_timestamp_millis(Int64(61)))) |", + "+-------------------------------------------------------------------------+", + "| Timestamp(Millisecond, None) |", + "+-------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); let sql = "select arrow_typeof(date_trunc('millisecond', to_timestamp_micros(61)));"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------------------------------------------+", - "| arrowtypeof(datetrunc(Utf8(\"millisecond\"),totimestampmicros(Int64(61)))) |", - "+--------------------------------------------------------------------------+", - "| Timestamp(Microsecond, None) |", - "+--------------------------------------------------------------------------+", + "+------------------------------------------------------------------------------+", + "| arrow_typeof(date_trunc(Utf8(\"millisecond\"),to_timestamp_micros(Int64(61)))) |", + "+------------------------------------------------------------------------------+", + "| Timestamp(Microsecond, None) |", + "+------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); let sql = "select arrow_typeof(date_trunc('microsecond', to_timestamp(61)));"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------------------------------------+", - "| arrowtypeof(datetrunc(Utf8(\"microsecond\"),totimestamp(Int64(61)))) |", - "+--------------------------------------------------------------------+", - "| Timestamp(Nanosecond, None) |", - "+--------------------------------------------------------------------+", + "+-----------------------------------------------------------------------+", + "| arrow_typeof(date_trunc(Utf8(\"microsecond\"),to_timestamp(Int64(61)))) |", + "+-----------------------------------------------------------------------+", + "| Timestamp(Nanosecond, None) |", + "+-----------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -767,7 +767,7 @@ async fn cast_timestamp_to_timestamptz() -> Result<()> { let expected = vec![ "+-----------------------------+---------------------------------------+", - "| table_a.ts | arrowtypeof(table_a.ts) |", + "| table_a.ts | arrow_typeof(table_a.ts) |", "+-----------------------------+---------------------------------------+", "| 2020-09-08T13:42:29.190855Z | Timestamp(Nanosecond, Some(\"+00:00\")) |", "| 2020-09-08T12:42:29.190855Z | Timestamp(Nanosecond, Some(\"+00:00\")) |", @@ -1018,7 +1018,7 @@ async fn test_ts_dt_binary_ops() -> Result<()> { } assert_eq!( res, - Some("Projection: now() = currentdate()\n EmptyRelation".to_string()) + Some("Projection: now() = current_date()\n EmptyRelation".to_string()) ); Ok(()) diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt index 13b8a74860cc8..2b61fe0600ac1 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt @@ -1223,7 +1223,7 @@ statement error Error during planning: No function matches the given name and ar SELECT pi(3.14); # error message for wrong function signature (Any: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'arrowtypeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrowtypeof\(Any\) +statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) SELECT arrow_typeof(1, 1); # error message for wrong function signature (OneOf: fixed number of args of arbitrary types) diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index b2d336a89d56d..160d76deb13c7 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,6 +38,7 @@ path = "src/lib.rs" ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } arrow = { workspace = true } datafusion-common = { path = "../common", version = "25.0.0" } +lazy_static = { version = "^1.4.0" } sqlparser = "0.34" [dev-dependencies] diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6633e4c0a9178..26b88a71a706f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -19,6 +19,7 @@ use crate::Volatility; use datafusion_common::{DataFusionError, Result}; +use lazy_static::lazy_static; use std::fmt; use std::str::FromStr; @@ -202,6 +203,120 @@ pub enum BuiltinScalarFunction { ArrowTypeof, } +lazy_static! { + /// Mapping between SQL function names to `BuiltinScalarFunction` types. + /// Note that multiple SQL function names can represent the same `BuiltinScalarFunction`. These are treated as aliases. + /// In case of such aliases, the first SQL function name in the vector is used when displaying the function. + static ref NAME_TO_FUNCTION: Vec<(&'static str, BuiltinScalarFunction)> = vec![ + // math functions + ("abs", BuiltinScalarFunction::Abs), + ("acos", BuiltinScalarFunction::Acos), + ("acosh", BuiltinScalarFunction::Acosh), + ("asin", BuiltinScalarFunction::Asin), + ("asinh", BuiltinScalarFunction::Asinh), + ("atan", BuiltinScalarFunction::Atan), + ("atanh", BuiltinScalarFunction::Atanh), + ("atan2", BuiltinScalarFunction::Atan2), + ("cbrt", BuiltinScalarFunction::Cbrt), + ("ceil", BuiltinScalarFunction::Ceil), + ("cos", BuiltinScalarFunction::Cos), + ("cosh", BuiltinScalarFunction::Cosh), + ("degrees", BuiltinScalarFunction::Degrees), + ("exp", BuiltinScalarFunction::Exp), + ("factorial", BuiltinScalarFunction::Factorial), + ("floor", BuiltinScalarFunction::Floor), + ("gcd", BuiltinScalarFunction::Gcd), + ("lcm", BuiltinScalarFunction::Lcm), + ("ln", BuiltinScalarFunction::Ln), + ("log", BuiltinScalarFunction::Log), + ("log10", BuiltinScalarFunction::Log10), + ("log2", BuiltinScalarFunction::Log2), + ("pi", BuiltinScalarFunction::Pi), + ("power", BuiltinScalarFunction::Power), + ("pow", BuiltinScalarFunction::Power), + ("radians", BuiltinScalarFunction::Radians), + ("random", BuiltinScalarFunction::Random), + ("round", BuiltinScalarFunction::Round), + ("signum", BuiltinScalarFunction::Signum), + ("sin", BuiltinScalarFunction::Sin), + ("sinh", BuiltinScalarFunction::Sinh), + ("sqrt", BuiltinScalarFunction::Sqrt), + ("tan", BuiltinScalarFunction::Tan), + ("tanh", BuiltinScalarFunction::Tanh), + ("trunc", BuiltinScalarFunction::Trunc), + + // conditional functions + ("coalesce", BuiltinScalarFunction::Coalesce), + ("nullif", BuiltinScalarFunction::NullIf), + + // string functions + ("ascii", BuiltinScalarFunction::Ascii), + ("bit_length", BuiltinScalarFunction::BitLength), + ("btrim", BuiltinScalarFunction::Btrim), + ("character_length", BuiltinScalarFunction::CharacterLength), + ("char_length", BuiltinScalarFunction::CharacterLength), + ("concat", BuiltinScalarFunction::Concat), + ("concat_ws", BuiltinScalarFunction::ConcatWithSeparator), + ("chr", BuiltinScalarFunction::Chr), + ("initcap", BuiltinScalarFunction::InitCap), + ("left", BuiltinScalarFunction::Left), + ("length", BuiltinScalarFunction::CharacterLength), + ("lower", BuiltinScalarFunction::Lower), + ("lpad", BuiltinScalarFunction::Lpad), + ("ltrim", BuiltinScalarFunction::Ltrim), + ("octet_length", BuiltinScalarFunction::OctetLength), + ("repeat", BuiltinScalarFunction::Repeat), + ("replace", BuiltinScalarFunction::Replace), + ("reverse", BuiltinScalarFunction::Reverse), + ("right", BuiltinScalarFunction::Right), + ("rpad", BuiltinScalarFunction::Rpad), + ("rtrim", BuiltinScalarFunction::Rtrim), + ("split_part", BuiltinScalarFunction::SplitPart), + ("starts_with", BuiltinScalarFunction::StartsWith), + ("strpos", BuiltinScalarFunction::Strpos), + ("substr", BuiltinScalarFunction::Substr), + ("to_hex", BuiltinScalarFunction::ToHex), + ("translate", BuiltinScalarFunction::Translate), + ("trim", BuiltinScalarFunction::Trim), + ("upper", BuiltinScalarFunction::Upper), + ("uuid", BuiltinScalarFunction::Uuid), + + // regex functions + ("regexp_match", BuiltinScalarFunction::RegexpMatch), + ("regexp_replace", BuiltinScalarFunction::RegexpReplace), + + // time/date functions + ("now", BuiltinScalarFunction::Now), + ("current_date", BuiltinScalarFunction::CurrentDate), + ("current_time", BuiltinScalarFunction::CurrentTime), + ("date_bin", BuiltinScalarFunction::DateBin), + ("date_trunc", BuiltinScalarFunction::DateTrunc), + ("datetrunc", BuiltinScalarFunction::DateTrunc), + ("date_part", BuiltinScalarFunction::DatePart), + ("datepart", BuiltinScalarFunction::DatePart), + ("to_timestamp", BuiltinScalarFunction::ToTimestamp), + ("to_timestamp_millis", BuiltinScalarFunction::ToTimestampMillis), + ("to_timestamp_micros", BuiltinScalarFunction::ToTimestampMicros), + ("to_timestamp_seconds", BuiltinScalarFunction::ToTimestampSeconds), + ("from_unixtime", BuiltinScalarFunction::FromUnixtime), + + // hashing functions + ("digest", BuiltinScalarFunction::Digest), + ("md5", BuiltinScalarFunction::MD5), + ("sha224", BuiltinScalarFunction::SHA224), + ("sha256", BuiltinScalarFunction::SHA256), + ("sha384", BuiltinScalarFunction::SHA384), + ("sha512", BuiltinScalarFunction::SHA512), + + // other functions + ("struct", BuiltinScalarFunction::Struct), + ("arrow_typeof", BuiltinScalarFunction::ArrowTypeof), + + // array functions + ("make_array", BuiltinScalarFunction::MakeArray), + ]; +} + impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. @@ -316,7 +431,13 @@ impl BuiltinScalarFunction { impl fmt::Display for BuiltinScalarFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // lowercase of the debug. + for (func_name, func) in NAME_TO_FUNCTION.iter() { + if func == self { + return write!(f, "{}", func_name); + } + } + + // Should not be reached write!(f, "{}", format!("{self:?}").to_lowercase()) } } @@ -324,116 +445,32 @@ impl fmt::Display for BuiltinScalarFunction { impl FromStr for BuiltinScalarFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { - Ok(match name { - // math functions - "abs" => BuiltinScalarFunction::Abs, - "acos" => BuiltinScalarFunction::Acos, - "acosh" => BuiltinScalarFunction::Acosh, - "asin" => BuiltinScalarFunction::Asin, - "asinh" => BuiltinScalarFunction::Asinh, - "atan" => BuiltinScalarFunction::Atan, - "atanh" => BuiltinScalarFunction::Atanh, - "atan2" => BuiltinScalarFunction::Atan2, - "cbrt" => BuiltinScalarFunction::Cbrt, - "ceil" => BuiltinScalarFunction::Ceil, - "cos" => BuiltinScalarFunction::Cos, - "cosh" => BuiltinScalarFunction::Cosh, - "degrees" => BuiltinScalarFunction::Degrees, - "exp" => BuiltinScalarFunction::Exp, - "factorial" => BuiltinScalarFunction::Factorial, - "floor" => BuiltinScalarFunction::Floor, - "gcd" => BuiltinScalarFunction::Gcd, - "lcm" => BuiltinScalarFunction::Lcm, - "ln" => BuiltinScalarFunction::Ln, - "log" => BuiltinScalarFunction::Log, - "log10" => BuiltinScalarFunction::Log10, - "log2" => BuiltinScalarFunction::Log2, - "pi" => BuiltinScalarFunction::Pi, - "power" | "pow" => BuiltinScalarFunction::Power, - "radians" => BuiltinScalarFunction::Radians, - "random" => BuiltinScalarFunction::Random, - "round" => BuiltinScalarFunction::Round, - "signum" => BuiltinScalarFunction::Signum, - "sin" => BuiltinScalarFunction::Sin, - "sinh" => BuiltinScalarFunction::Sinh, - "sqrt" => BuiltinScalarFunction::Sqrt, - "tan" => BuiltinScalarFunction::Tan, - "tanh" => BuiltinScalarFunction::Tanh, - "trunc" => BuiltinScalarFunction::Trunc, - - // conditional functions - "coalesce" => BuiltinScalarFunction::Coalesce, - "nullif" => BuiltinScalarFunction::NullIf, - - // string functions - "ascii" => BuiltinScalarFunction::Ascii, - "bit_length" => BuiltinScalarFunction::BitLength, - "btrim" => BuiltinScalarFunction::Btrim, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, - "concat" => BuiltinScalarFunction::Concat, - "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, - "chr" => BuiltinScalarFunction::Chr, - "initcap" => BuiltinScalarFunction::InitCap, - "left" => BuiltinScalarFunction::Left, - "length" => BuiltinScalarFunction::CharacterLength, - "lower" => BuiltinScalarFunction::Lower, - "lpad" => BuiltinScalarFunction::Lpad, - "ltrim" => BuiltinScalarFunction::Ltrim, - "octet_length" => BuiltinScalarFunction::OctetLength, - "repeat" => BuiltinScalarFunction::Repeat, - "replace" => BuiltinScalarFunction::Replace, - "reverse" => BuiltinScalarFunction::Reverse, - "right" => BuiltinScalarFunction::Right, - "rpad" => BuiltinScalarFunction::Rpad, - "rtrim" => BuiltinScalarFunction::Rtrim, - "split_part" => BuiltinScalarFunction::SplitPart, - "starts_with" => BuiltinScalarFunction::StartsWith, - "strpos" => BuiltinScalarFunction::Strpos, - "substr" => BuiltinScalarFunction::Substr, - "to_hex" => BuiltinScalarFunction::ToHex, - "translate" => BuiltinScalarFunction::Translate, - "trim" => BuiltinScalarFunction::Trim, - "upper" => BuiltinScalarFunction::Upper, - "uuid" => BuiltinScalarFunction::Uuid, - - // regex functions - "regexp_match" => BuiltinScalarFunction::RegexpMatch, - "regexp_replace" => BuiltinScalarFunction::RegexpReplace, - - // time/date functions - "now" => BuiltinScalarFunction::Now, - "current_date" => BuiltinScalarFunction::CurrentDate, - "current_time" => BuiltinScalarFunction::CurrentTime, - "date_bin" => BuiltinScalarFunction::DateBin, - "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc, - "date_part" | "datepart" => BuiltinScalarFunction::DatePart, - "to_timestamp" => BuiltinScalarFunction::ToTimestamp, - "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis, - "to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros, - "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds, - "from_unixtime" => BuiltinScalarFunction::FromUnixtime, - - // hashing functions - "digest" => BuiltinScalarFunction::Digest, - "md5" => BuiltinScalarFunction::MD5, - "sha224" => BuiltinScalarFunction::SHA224, - "sha256" => BuiltinScalarFunction::SHA256, - "sha384" => BuiltinScalarFunction::SHA384, - "sha512" => BuiltinScalarFunction::SHA512, + for (func_name, func) in NAME_TO_FUNCTION.iter() { + if name == *func_name { + return Ok(func.clone()); + } + } - // other functions - "struct" => BuiltinScalarFunction::Struct, - "arrow_typeof" => BuiltinScalarFunction::ArrowTypeof, + Err(DataFusionError::Plan(format!( + "There is no built-in function named {name}" + ))) + } +} - // array functions - "make_array" => BuiltinScalarFunction::MakeArray, +#[cfg(test)] +mod tests { + use super::*; - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))) - } - }) + #[test] + // Test for BuiltinScalarFunction's Display and from_str() implementations. + // For each variant in BuiltinScalarFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. + fn test_display_and_from_str() { + for (_, func_original) in NAME_TO_FUNCTION.iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltinScalarFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, *func_original); + } } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index fbb61fb1a31ac..1761d237c3d36 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1212,7 +1212,7 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = - "Projection: concatwithseparator(Utf8(\"-\"), a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; + "Projection: concat_ws(Utf8(\"-\"), a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 83c7923f07d2a..42850178136e5 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -419,7 +419,7 @@ mod tests { .project(proj)? .build()?; - let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ + let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ \n TableScan: test" .to_string(); let actual = get_optimized_plan_formatted(&plan, &Utc::now()); @@ -559,7 +559,7 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("18636") AS totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") + let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") TableScan: test"#; let actual = get_optimized_plan_formatted(&plan, &time); diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 761d6539b23c0..a4220797453ed 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -225,7 +225,7 @@ fn concat_ws_literals() -> Result<()> { FROM test"; let plan = test_sql(sql)?; let expected = - "Projection: concatwithseparator(Utf8(\"-\"), Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ + "Projection: concat_ws(Utf8(\"-\"), Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 452761454afc0..a0d4d3d542511 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -2479,7 +2479,7 @@ fn select_groupby_orderby() { // expect that this is not an ambiguous reference let expected = "Sort: birth_date ASC NULLS LAST\ - \n Projection: AVG(person.age) AS value, datetrunc(Utf8(\"month\"), person.birth_date) AS birth_date\ + \n Projection: AVG(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ \n Aggregate: groupBy=[[person.birth_date]], aggr=[[AVG(person.age)]]\ \n TableScan: person"; quick_test(sql, expected);