Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

/** The {@link Procedure}s including all the stored procedures. */
Expand All @@ -62,6 +63,10 @@ public static ProcedureBuilder newBuilder(String name) {
return builderSupplier != null ? builderSupplier.get() : null;
}

public static Set<String> names() {
return BUILDERS.keySet();
}

private static Map<String, Supplier<ProcedureBuilder>> initProcedureBuilders() {
ImmutableMap.Builder<String, Supplier<ProcedureBuilder>> procedureBuilders =
ImmutableMap.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.spark.sql.catalyst.parser.extensions

import org.apache.paimon.spark.SparkProcedures

import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
Expand All @@ -34,6 +36,8 @@ import org.apache.spark.sql.types.{DataType, StructType}

import java.util.Locale

import scala.collection.JavaConverters._

/* This file is based on source code from the Iceberg Project (http://iceberg.apache.org/), licensed by the Apache
* Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. */
Expand Down Expand Up @@ -100,8 +104,15 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf
.replaceAll("--.*?\\n", " ")
.replaceAll("\\s+", " ")
.replaceAll("/\\*.*?\\*/", " ")
.replaceAll("`", "")
.trim()
normalized.startsWith("call") || isTagRefDdl(normalized)
isPaimonProcedure(normalized) || isTagRefDdl(normalized)
}

// All builtin paimon procedures are under the 'sys' namespace
private def isPaimonProcedure(normalized: String): Boolean = {
normalized.startsWith("call") &&
SparkProcedures.names().asScala.map("sys." + _).exists(normalized.contains)
}

private def isTagRefDdl(normalized: String): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,37 @@ public void stopSparkSession() {
}
}

@Test
public void testDelegateUnsupportedProcedure() {
assertThatThrownBy(() -> parser.parsePlan("CALL cat.d.t()"))
.isInstanceOf(ParseException.class)
.satisfies(
exception -> {
ParseException parseException = (ParseException) exception;
assertThat(parseException.getErrorClass())
.isEqualTo("PARSE_SYNTAX_ERROR");
assertThat(parseException.getMessageParameters().get("error"))
.isEqualTo("'CALL'");
});
}

@Test
public void testCallWithBackticks() throws ParseException {
PaimonCallStatement call =
(PaimonCallStatement) parser.parsePlan("CALL cat.`sys`.`rollback`()");
assertThat(JavaConverters.seqAsJavaList(call.name()))
.isEqualTo(Arrays.asList("cat", "sys", "rollback"));
assertThat(call.args().size()).isEqualTo(0);
}

@Test
public void testCallWithNamedArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan(
"CALL catalog.system.named_args_func(arg1 => 1, arg2 => 'test', arg3 => true)");
"CALL catalog.sys.rollback(arg1 => 1, arg2 => 'test', arg3 => true)");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "named_args_func"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(3);
assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType);
assertArgument(callStatement, 1, "arg2", "test", DataTypes.StringType);
Expand All @@ -98,11 +121,11 @@ public void testCallWithPositionalArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan(
"CALL catalog.system.positional_args_func(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4"
"CALL catalog.sys.rollback(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4"
+ ".0e1,500e-1BD, "
+ "TIMESTAMP '2017-02-03T10:37:30.00Z')");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "positional_args_func"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(8);
assertArgument(callStatement, 0, 1, DataTypes.IntegerType);
assertArgument(
Expand All @@ -127,19 +150,19 @@ public void testCallWithPositionalArguments() throws ParseException {
public void testCallWithMixedArguments() throws ParseException {
PaimonCallStatement callStatement =
(PaimonCallStatement)
parser.parsePlan("CALL catalog.system.mixed_function(arg1 => 1, 'test')");
parser.parsePlan("CALL catalog.sys.rollback(arg1 => 1, 'test')");
assertThat(JavaConverters.seqAsJavaList(callStatement.name()))
.isEqualTo(Arrays.asList("catalog", "system", "mixed_function"));
.isEqualTo(Arrays.asList("catalog", "sys", "rollback"));
assertThat(callStatement.args().size()).isEqualTo(2);
assertArgument(callStatement, 0, "arg1", 1, DataTypes.IntegerType);
assertArgument(callStatement, 1, "test", DataTypes.StringType);
}

@Test
public void testCallWithParseException() {
assertThatThrownBy(() -> parser.parsePlan("CALL catalog.system func abc"))
assertThatThrownBy(() -> parser.parsePlan("CALL catalog.sys.rollback abc"))
.isInstanceOf(PaimonParseException.class)
.hasMessageContaining("missing '(' at 'func'");
.hasMessageContaining("missing '(' at 'abc'");
}

private void assertArgument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
package org.apache.paimon.spark.procedure

import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.spark.analysis.NoSuchProcedureException

import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.parser.extensions.PaimonParseException
import org.assertj.core.api.Assertions.assertThatThrownBy

Expand All @@ -32,7 +32,7 @@ abstract class ProcedureTestBase extends PaimonSparkTestBase {
|""".stripMargin)

assertThatThrownBy(() => spark.sql("CALL sys.unknown_procedure(table => 'test.T')"))
.isInstanceOf(classOf[NoSuchProcedureException])
.isInstanceOf(classOf[ParseException])
}

test(s"test parse exception") {
Expand Down
Loading