diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs
index 7c959572a8263..82bf8a9e09616 100644
--- a/datafusion/spark/src/function/url/mod.rs
+++ b/datafusion/spark/src/function/url/mod.rs
@@ -20,15 +20,26 @@ use datafusion_functions::make_udf_function;
 use std::sync::Arc;
 
 pub mod parse_url;
+pub mod try_parse_url;
 
 make_udf_function!(parse_url::ParseUrl, parse_url);
+make_udf_function!(try_parse_url::TryParseUrl, try_parse_url);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
 
-    export_functions!((parse_url, "Extracts a part from a URL.", args));
+    export_functions!((
+        parse_url,
+        "Extracts a part from a URL, throwing an error if an invalid URL is provided.",
+        args
+    ));
+    export_functions!((
+        try_parse_url,
+        "Same as parse_url but returns NULL if an invalid URL is provided.",
+        args
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![parse_url()]
+    vec![parse_url(), try_parse_url()]
 }
diff --git a/datafusion/spark/src/function/url/parse_url.rs b/datafusion/spark/src/function/url/parse_url.rs
index f9c33060cc5ef..d93c260b4f340 100644
--- a/datafusion/spark/src/function/url/parse_url.rs
+++ b/datafusion/spark/src/function/url/parse_url.rs
@@ -26,13 +26,13 @@ use arrow::datatypes::DataType;
 use datafusion_common::cast::{
     as_large_string_array, as_string_array, as_string_view_array,
 };
-use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result};
+use datafusion_common::{exec_datafusion_err, exec_err, Result};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
     Volatility,
 };
 use datafusion_functions::utils::make_scalar_function;
-use url::Url;
+use url::{ParseError, Url};
 
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct ParseUrl {
@@ -49,20 +49,7 @@ impl ParseUrl {
     pub fn new() -> Self {
         Self {
             signature: Signature::one_of(
-                vec![
-                    TypeSignature::Uniform(
-                        1,
-                        vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
-                    ),
-                    TypeSignature::Uniform(
-                        2,
-                        vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
-                    ),
-                    TypeSignature::Uniform(
-                        3,
-                        vec![DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8],
-                    ),
-                ],
+                vec![TypeSignature::String(2), TypeSignature::String(3)],
                 Volatility::Immutable,
             ),
         }
     }
@@ -95,11 +82,22 @@ impl ParseUrl {
     /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed
     ///
     fn parse(value: &str, part: &str, key: Option<&str>) -> Result<Option<String>> {
-        Url::parse(value)
-            .map_err(|e| exec_datafusion_err!("{e:?}"))
+        let url: std::result::Result<Url, ParseError> = Url::parse(value);
+        if let Err(ParseError::RelativeUrlWithoutBase) = url {
+            return if !value.contains("://") {
+                Ok(None)
+            } else {
+                Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02"))
+            };
+        };
+        url.map_err(|e| exec_datafusion_err!("{e:?}"))
             .map(|url| match part {
                 "HOST" => url.host_str().map(String::from),
-                "PATH" => Some(url.path().to_string()),
+                "PATH" => {
+                    let path: String = url.path().to_string();
+                    let path: String = if path == "/" { "".to_string() } else { path };
+                    Some(path)
+                }
                 "QUERY" => match key {
                     None => url.query().map(String::from),
                     Some(key) => url
@@ -146,35 +144,7 @@ impl ScalarUDFImpl for ParseUrl {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        if arg_types.len() < 2 || arg_types.len() > 3 {
-            return plan_err!(
-                "{} expects 2 or 3 arguments, but got {}",
-                self.name(),
-                arg_types.len()
-            );
-        }
-        match arg_types.len() {
-            2 | 3 => {
-                if arg_types
-                    .iter()
-                    .any(|arg| matches!(arg, DataType::LargeUtf8))
-                {
-                    Ok(DataType::LargeUtf8)
-                } else if arg_types
-                    .iter()
-                    .any(|arg| matches!(arg, DataType::Utf8View))
-                {
-                    Ok(DataType::Utf8View)
-                } else {
-                    Ok(DataType::Utf8)
-                }
-            }
-            _ => plan_err!(
-                "`{}` expects 2 or 3 arguments, got {}",
-                &self.name(),
-                arg_types.len()
-            ),
-        }
+        Ok(arg_types[0].clone())
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -200,6 +170,13 @@ impl ScalarUDFImpl for ParseUrl {
 /// - The output array type (StringArray or LargeStringArray) is determined by input types
 ///
 fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
+    spark_handled_parse_url(args, |x| x)
+}
+
+pub fn spark_handled_parse_url(
+    args: &[ArrayRef],
+    handler_err: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
+) -> Result<ArrayRef> {
     if args.len() < 2 || args.len() > 3 {
         return exec_err!(
             "{} expects 2 or 3 arguments, but got {}",
@@ -212,6 +189,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
     let part = &args[1];
 
     let result = if args.len() == 3 {
+        // In this case, the 'key' argument is passed
         let key = &args[2];
 
         match (url.data_type(), part.data_type(), key.data_type()) {
@@ -220,6 +198,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_string_array(url)?,
                     as_string_array(part)?,
                     as_string_array(key)?,
+                    handler_err,
                 )
             }
             (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
@@ -227,6 +206,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_string_view_array(url)?,
                     as_string_view_array(part)?,
                     as_string_view_array(key)?,
+                    handler_err,
                 )
             }
             (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
@@ -234,6 +214,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_large_string_array(url)?,
                     as_large_string_array(part)?,
                     as_large_string_array(key)?,
+                    handler_err,
                 )
             }
             _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
@@ -253,6 +234,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_string_array(url)?,
                     as_string_array(part)?,
                     &key,
+                    handler_err,
                 )
             }
             (DataType::Utf8View, DataType::Utf8View) => {
@@ -260,6 +242,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_string_view_array(url)?,
                     as_string_view_array(part)?,
                     &key,
+                    handler_err,
                 )
             }
             (DataType::LargeUtf8, DataType::LargeUtf8) => {
@@ -267,6 +250,7 @@ fn spark_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
                     as_large_string_array(url)?,
                     as_large_string_array(part)?,
                     &key,
+                    handler_err,
                 )
             }
             _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args),
@@ -279,6 +263,7 @@ fn process_parse_url<'a, A, B, C, T>(
     url_array: &'a A,
     part_array: &'a B,
     key_array: &'a C,
+    handle: impl Fn(Result<Option<String>>) -> Result<Option<String>>,
 ) -> Result<ArrayRef>
 where
     &'a A: StringArrayType<'a>,
@@ -292,7 +277,7 @@ where
         .zip(key_array.iter())
         .map(|((url, part), key)| {
             if let (Some(url), Some(part), key) = (url, part, key) {
-                ParseUrl::parse(url, part, key)
+                handle(ParseUrl::parse(url, part, key))
             } else {
                 Ok(None)
             }
@@ -300,3 +285,148 @@ where
         .collect::<Result<GenericStringArray<T>>>()
         .map(|array| Arc::new(array) as ArrayRef)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{ArrayRef, Int32Array, StringArray};
+    use datafusion_common::Result;
+    use std::array::from_ref;
+    use std::sync::Arc;
+
+    fn sa(vals: &[Option<&str>]) -> ArrayRef {
+        Arc::new(StringArray::from(vals.to_vec())) as ArrayRef
+    }
+
+    #[test]
+    fn test_parse_host() -> Result<()> {
+        let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?;
+        assert_eq!(got, Some("example.com".to_string()));
+        Ok(())
+    }
+
+    #[test]
+    fn test_parse_query_no_key_vs_with_key() -> Result<()> {
+        let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?;
+        assert_eq!(got_all, Some("a=1&b=2".to_string()));
+
+        let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?;
+        assert_eq!(got_a, Some("1".to_string()));
+
+        let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?;
+        assert_eq!(got_c, None);
+        Ok(())
+    }
+
+    #[test]
+    fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> {
+        let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag";
+        assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string()));
+        assert_eq!(
+            ParseUrl::parse(url, "PROTOCOL", None)?,
+            Some("ftp".to_string())
+        );
+        assert_eq!(
+            ParseUrl::parse(url, "USERINFO", None)?,
+            Some("user:pwd".to_string())
+        );
+        assert_eq!(
+            ParseUrl::parse(url, "FILE", None)?,
+            Some("/files?x=1".to_string())
+        );
+        assert_eq!(
+            ParseUrl::parse(url, "AUTHORITY", None)?,
+            Some("user:pwd@ftp.example.com".to_string())
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_parse_path_root_is_empty_string() -> Result<()> {
+        let got = ParseUrl::parse("https://example.com/", "PATH", None)?;
+        assert_eq!(got, Some("".to_string()));
+        Ok(())
+    }
+
+    #[test]
+    fn test_parse_malformed_url_returns_error() -> Result<()> {
+        let got = ParseUrl::parse("notaurl", "HOST", None)?;
+        assert_eq!(got, None);
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_utf8_two_args() -> Result<()> {
+        let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]);
+        let parts = sa(&[Some("HOST"), Some("PATH")]);
+
+        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
+        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(out_sa.len(), 2);
+        assert_eq!(out_sa.value(0), "example.com");
+        assert_eq!(out_sa.value(1), "");
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_utf8_three_args_query_key() -> Result<()> {
+        let urls = sa(&[
+            Some("https://example.com/a?x=1&y=2"),
+            Some("https://ex.com/?a=1"),
+        ]);
+        let parts = sa(&[Some("QUERY"), Some("QUERY")]);
+        let keys = sa(&[Some("y"), Some("b")]);
+
+        let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?;
+        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(out_sa.len(), 2);
+        assert_eq!(out_sa.value(0), "2");
+        assert!(out_sa.is_null(1));
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_userinfo_and_nulls() -> Result<()> {
+        let urls = sa(&[
+            Some("ftp://user:pwd@ftp.example.com:21/files"),
+            Some("https://example.com"),
+            None,
+        ]);
+        let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]);
+
+        let out = spark_handled_parse_url(&[urls, parts], |x| x)?;
+        let out_sa = out.as_any().downcast_ref::<StringArray>().unwrap();
+
+        assert_eq!(out_sa.len(), 3);
+        assert_eq!(out_sa.value(0), "user:pwd");
+        assert!(out_sa.is_null(1));
+        assert!(out_sa.is_null(2));
+        Ok(())
+    }
+
+    #[test]
+    fn test_invalid_arg_count() {
+        let urls = sa(&[Some("https://example.com")]);
+        let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err();
+        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
+
+        let parts = sa(&[Some("HOST")]);
+        let keys = sa(&[Some("x")]);
+        let err =
+            spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x)
+                .unwrap_err();
+        assert!(format!("{err}").contains("expects 2 or 3 arguments"));
+    }
+
+    #[test]
+    fn test_non_string_types_error() {
+        let urls = sa(&[Some("https://example.com")]);
+        let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef;
+
+        let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err();
+        let msg = format!("{err}");
+        assert!(msg.contains("expects STRING arguments"));
+    }
+}
diff --git a/datafusion/spark/src/function/url/try_parse_url.rs b/datafusion/spark/src/function/url/try_parse_url.rs
new file mode 100644
index 0000000000000..c04850f3a6bf0
--- /dev/null
+++ b/datafusion/spark/src/function/url/try_parse_url.rs
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+
+use crate::function::url::parse_url::{spark_handled_parse_url, ParseUrl};
+use arrow::array::ArrayRef;
+use arrow::datatypes::DataType;
+use datafusion_common::Result;
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility,
+};
+use datafusion_functions::utils::make_scalar_function;
+
+/// TRY_PARSE_URL function for tolerant URL component extraction (never errors; returns NULL on invalid or missing parts).
+///
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct TryParseUrl {
+    signature: Signature,
+}
+
+impl Default for TryParseUrl {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl TryParseUrl {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![TypeSignature::String(2), TypeSignature::String(3)],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for TryParseUrl {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "try_parse_url"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        let parse_url: ParseUrl = ParseUrl::new();
+        parse_url.return_type(arg_types)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let ScalarFunctionArgs { args, .. } = args;
+        make_scalar_function(spark_try_parse_url, vec![])(&args)
+    }
+}
+
+fn spark_try_parse_url(args: &[ArrayRef]) -> Result<ArrayRef> {
+    spark_handled_parse_url(args, |x| match x {
+        Err(_) => Ok(None),
+        result => result,
+    })
+}
diff --git a/datafusion/sqllogictest/test_files/spark/url/parse_url.slt b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt
index cca07ceb6ff43..f2dc55f75598a 100644
--- a/datafusion/sqllogictest/test_files/spark/url/parse_url.slt
+++ b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt
@@ -70,8 +70,108 @@ SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'U
 ----
 userinfo
 
-statement error parse_url expects 2 or 3 arguments, but got 1
-SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string);
+query T
+SELECT parse_url('https://example.com/a?x=1', 'QUERY', 'x');
+----
+1
+
+query T
+SELECT parse_url('https://example.com/a?x=1', 'query', 'x');
+----
+NULL
+
+query T
+SELECT parse_url('www.example.com/path?x=1', 'HOST');
+----
+NULL
+
+query T
+SELECT parse_url('www.example.com/path?x=1', 'host');
+----
+NULL
+
+query T
+SELECT parse_url('https://example.com/?a=1', 'QUERY', 'b');
+----
+NULL
+
+query T
+SELECT parse_url('https://example.com/?a=1', 'query', 'b');
+----
+NULL
+
+query T
+SELECT parse_url('https://example.com/path#frag', 'REF');
+----
+frag
+
+query T
+SELECT parse_url('https://example.com/path#frag', 'ref');
+----
+NULL
+
+query T
+SELECT parse_url('ftp://user:pwd@ftp.example.com:21/files', 'USERINFO');
+----
+user:pwd
+
+query T
+SELECT parse_url('ftp://user:pwd@ftp.example.com:21/files', 'userinfo');
+----
+NULL
+
+query T
+SELECT parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'HOST');
+----
+[2001:db8::2]
+
+query T
+SELECT parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'host');
+----
+NULL
+
+query T
+SELECT parse_url('notaurl', 'HOST');
+----
+NULL
+
+query T
+SELECT parse_url('notaurl', 'host');
+----
+NULL
+
+query T
+SELECT parse_url('https://example.com', 'PATH');
+----
+(empty)
+
+query T
+SELECT parse_url('https://example.com', 'path');
+----
+NULL
+
+query T
+SELECT parse_url('https://example.com/a/b?x=1&y=2#frag', 'PROTOCOL');
+----
+https
+
+query T
+SELECT parse_url('https://example.com/a/b?x=1&y=2#frag', 'protocol');
+----
+NULL
+
+query T
+SELECT parse_url('https://ex.com/?Tag=ok', 'QUERY', 'tag');
+----
+NULL
+
+query T
+SELECT parse_url('https://ex.com/?Tag=ok', 'query', 'tag');
+----
+NULL
 
 statement error 'parse_url' does not support zero arguments
 SELECT parse_url();
+
+query error DataFusion error: Execution error: The url is invalid: inva lid://spark\.apache\.org/path\?query=1\. Use `try_parse_url` to tolerate invalid URL and return NULL instead\. SQLSTATE: 22P02
+SELECT parse_url('inva lid://spark.apache.org/path?query=1', 'QUERY');
diff --git a/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt b/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt
new file mode 100644
index 0000000000000..403747c63c77c
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file was originally created by a porting script from:
+# https://github.com/lakehq/sail/blob/b6095cc7fccaf016b47f009ba93b2357dc781a7d/python/pysail/tests/spark/function/test_try_parse_url.txt
+# This file is part of the implementation of the datafusion-spark function library.
+# For more information, please see:
+# https://github.com/apache/datafusion/issues/15914
+
+query T
+SELECT try_parse_url('https://example.com/a?x=1', 'QUERY', 'x');
+----
+1
+
+query T
+SELECT try_parse_url('https://example.com/a?x=1', 'query', 'x');
+----
+NULL
+
+query T
+SELECT try_parse_url('www.example.com/path?x=1', 'HOST');
+----
+NULL
+
+query T
+SELECT try_parse_url('www.example.com/path?x=1', 'host');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://example.com/?a=1', 'QUERY', 'b');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://example.com/?a=1', 'query', 'b');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://example.com/path#frag', 'REF');
+----
+frag
+
+query T
+SELECT try_parse_url('https://example.com/path#frag', 'ref');
+----
+NULL
+
+query T
+SELECT try_parse_url('ftp://user:pwd@ftp.example.com:21/files', 'USERINFO');
+----
+user:pwd
+
+query T
+SELECT try_parse_url('ftp://user:pwd@ftp.example.com:21/files', 'userinfo');
+----
+NULL
+
+query T
+SELECT try_parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'HOST');
+----
+[2001:db8::2]
+
+query T
+SELECT try_parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'host');
+----
+NULL
+
+query T
+SELECT try_parse_url('notaurl', 'HOST');
+----
+NULL
+
+query T
+SELECT try_parse_url('notaurl', 'host');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://example.com', 'PATH');
+----
+(empty)
+
+query T
+SELECT try_parse_url('https://example.com', 'path');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://example.com/a/b?x=1&y=2#frag', 'PROTOCOL');
+----
+https
+
+query T
+SELECT try_parse_url('https://example.com/a/b?x=1&y=2#frag', 'protocol');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://ex.com/?Tag=ok', 'QUERY', 'tag');
+----
+NULL
+
+query T
+SELECT try_parse_url('https://ex.com/?Tag=ok', 'query', 'tag');
+----
+NULL
+
+query T
+SELECT try_parse_url('inva lid://spark.apache.org/path?query=1', 'QUERY');
+----
+NULL
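For reviewers who want to exercise the new function outside of sqllogictest, here is a minimal sketch (not part of the patch) that registers TryParseUrl on a stock SessionContext and runs the two queries covered by try_parse_url.slt. The module path `datafusion_spark::function::url::try_parse_url` is assumed from the layout in this diff, and the tokio/`register_udf` wiring is ordinary DataFusion usage rather than anything this change adds.

// Illustrative sketch only: exercises try_parse_url end to end.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion_expr::ScalarUDF;
// Assumed public path, mirroring the module layout in this diff.
use datafusion_spark::function::url::try_parse_url::TryParseUrl;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the UDF directly; a real setup would more likely register the
    // whole datafusion-spark function set at once.
    ctx.register_udf(ScalarUDF::new_from_impl(TryParseUrl::new()));

    // Valid URL: same result as parse_url, i.e. the QUERY value "1".
    ctx.sql("SELECT try_parse_url('https://example.com/a?x=1', 'QUERY', 'x')")
        .await?
        .show()
        .await?;

    // Invalid URL: parse_url raises SQLSTATE 22P02 here, try_parse_url returns NULL.
    ctx.sql("SELECT try_parse_url('inva lid://spark.apache.org/path?query=1', 'QUERY')")
        .await?
        .show()
        .await?;

    Ok(())
}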