-
Notifications
You must be signed in to change notification settings - Fork 95
fix: prevents exception on construction of FunctionConverter with duplicate functions
#564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6573680
1dc0ec9
f97fde7
03b0f24
afa6fdf
db13425
d21ef7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,8 @@ | |
|
|
||
| import com.google.common.collect.ArrayListMultimap; | ||
| import com.google.common.collect.ImmutableList; | ||
| import com.google.common.collect.ImmutableMap; | ||
| import com.google.common.collect.ImmutableListMultimap; | ||
| import com.google.common.collect.ListMultimap; | ||
| import com.google.common.collect.Multimap; | ||
| import com.google.common.collect.Multimaps; | ||
| import com.google.common.collect.Streams; | ||
|
|
@@ -45,6 +46,27 @@ | |
| import org.slf4j.Logger; | ||
| import org.slf4j.LoggerFactory; | ||
|
|
||
| /** | ||
| * Abstract base class for converting between Calcite {@link SqlOperator}s and Substrait function | ||
| * invocations. | ||
| * | ||
| * <p>This class handles bidirectional conversion: | ||
| * | ||
| * <ul> | ||
| * <li><b>Calcite → Substrait:</b> Subclasses implement {@code convert()} methods to convert | ||
| * Calcite calls to Substrait function invocations | ||
| * <li><b>Substrait → Calcite:</b> {@link #getSqlOperatorFromSubstraitFunc} converts Substrait | ||
| * function keys to Calcite {@link SqlOperator}s | ||
| * </ul> | ||
| * | ||
| * <p>When multiple functions with the same name and signature are passed into the constructor, a | ||
| * <b>last-wins precedence strategy</b> is used for resolution. The last function in the input list | ||
| * takes precedence during Calcite to Substrait conversion. | ||
| * | ||
| * @param <F> the function type (ScalarFunctionVariant, AggregateFunctionVariant, etc.) | ||
| * @param <T> the return type for Calcite→Substrait conversion | ||
| * @param <C> the call type being converted | ||
| */ | ||
| public abstract class FunctionConverter< | ||
| F extends SimpleExtension.Function, T, C extends FunctionConverter.GenericCall> { | ||
|
|
||
|
|
@@ -57,10 +79,32 @@ public abstract class FunctionConverter< | |
|
|
||
| protected final Multimap<String, SqlOperator> substraitFuncKeyToSqlOperatorMap; | ||
|
|
||
| /** | ||
| * Creates a FunctionConverter with the given functions. | ||
| * | ||
| * <p>If there are multiple functions provided with the same name and signature (e.g., from | ||
| * different extension URNs), the last one in the list will be given precedence during Calcite to | ||
| * Substrait conversion. | ||
| * | ||
| * @param functions the list of function variants to register | ||
| * @param typeFactory the Calcite type factory | ||
| */ | ||
| public FunctionConverter(List<F> functions, RelDataTypeFactory typeFactory) { | ||
| this(functions, Collections.EMPTY_LIST, typeFactory, TypeConverter.DEFAULT); | ||
| } | ||
|
|
||
| /** | ||
| * Creates a FunctionConverter with the given functions and additional signatures. | ||
| * | ||
| * <p>If there are multiple functions provided with the same name and signature (e.g., from | ||
| * different extension URNs), the last one in the list will be given precedence during Calcite to | ||
| * Substrait conversion. | ||
| * | ||
| * @param functions the list of function variants to register | ||
| * @param additionalSignatures additional Calcite operator signatures to map | ||
| * @param typeFactory the Calcite type factory | ||
| * @param typeConverter the type converter to use | ||
| */ | ||
| public FunctionConverter( | ||
| List<F> functions, | ||
| List<FunctionMappings.Sig> additionalSignatures, | ||
|
|
@@ -75,9 +119,9 @@ public FunctionConverter( | |
| this.typeFactory = typeFactory; | ||
| this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create(); | ||
|
|
||
| ArrayListMultimap<String, F> alm = ArrayListMultimap.<String, F>create(); | ||
| ArrayListMultimap<String, F> nameToFn = ArrayListMultimap.<String, F>create(); | ||
| for (F f : functions) { | ||
| alm.put(f.name().toLowerCase(Locale.ROOT), f); | ||
| nameToFn.put(f.name().toLowerCase(Locale.ROOT), f); | ||
| } | ||
|
|
||
| Multimap<String, FunctionMappings.Sig> calciteOperators = | ||
|
|
@@ -87,21 +131,21 @@ public FunctionConverter( | |
| FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create)); | ||
| IdentityHashMap<SqlOperator, FunctionFinder> matcherMap = | ||
| new IdentityHashMap<SqlOperator, FunctionFinder>(); | ||
| for (String key : alm.keySet()) { | ||
| for (String key : nameToFn.keySet()) { | ||
| Collection<Sig> sigs = calciteOperators.get(key); | ||
| if (sigs.isEmpty()) { | ||
| LOGGER.atDebug().log("No binding for function: {}", key); | ||
| } | ||
|
|
||
| for (Sig sig : sigs) { | ||
| List<F> implList = alm.get(key); | ||
| List<F> implList = nameToFn.get(key); | ||
| if (!implList.isEmpty()) { | ||
| matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| for (Entry<String, F> entry : alm.entries()) { | ||
| for (Entry<String, F> entry : nameToFn.entries()) { | ||
| String key = entry.getKey(); | ||
| F func = entry.getValue(); | ||
| for (FunctionMappings.Sig sig : calciteOperators.get(key)) { | ||
|
|
@@ -112,6 +156,17 @@ public FunctionConverter( | |
| this.signatures = matcherMap; | ||
| } | ||
|
|
||
| /** | ||
| * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction). | ||
| * | ||
| * <p>Given a Substrait function key (e.g., "concat:str_str") and output type, this method finds | ||
| * the corresponding Calcite {@link SqlOperator}. When multiple operators match, the output type | ||
| * is used to disambiguate. | ||
| * | ||
| * @param key the Substrait function key (function name with type signature) | ||
| * @param outputType the expected output type | ||
| * @return the matching {@link SqlOperator}, or empty if no match found | ||
| */ | ||
benbellick marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) { | ||
| Map<SqlOperator, TypeBasedResolver> resolver = getTypeBasedResolver(); | ||
| Collection<SqlOperator> operators = substraitFuncKeyToSqlOperatorMap.get(key); | ||
|
|
@@ -155,7 +210,7 @@ protected class FunctionFinder { | |
| private final String substraitName; | ||
| private final SqlOperator operator; | ||
| private final List<F> functions; | ||
| private final Map<String, F> directMap; | ||
| private final ListMultimap<String, F> directMap; | ||
| private final Optional<SingularArgumentMatcher<F>> singularInputType; | ||
| private final Util.IntRange argRange; | ||
|
|
||
|
|
@@ -168,7 +223,7 @@ public FunctionFinder(String substraitName, SqlOperator operator, List<F> functi | |
| functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(), | ||
| functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt()); | ||
| this.singularInputType = getSingularInputType(functions); | ||
| ImmutableMap.Builder<String, F> directMap = ImmutableMap.builder(); | ||
| ImmutableListMultimap.Builder<String, F> directMap = ImmutableListMultimap.builder(); | ||
| for (F func : functions) { | ||
| String key = func.key(); | ||
| directMap.put(key, func); | ||
|
|
@@ -342,13 +397,29 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Converts a Calcite call to a Substrait function invocation (Calcite → Substrait direction). | ||
| * | ||
| * <p>This method tries to find a matching Substrait function for the given Calcite call using | ||
| * direct signature matching, type coercion, and least-restrictive type resolution. | ||
| * | ||
| * <p>If multiple registered function extensions have the same name and signature, the last one | ||
| * in the list passed into the constructor will be matched. | ||
| * | ||
| * @param call the Calcite call to match | ||
| * @param topLevelConverter function to convert RexNode operands to Substrait Expressions | ||
| * @return the matched Substrait function binding, or empty if no match found | ||
| */ | ||
| public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelConverter) { | ||
|
|
||
| /* | ||
| * Here the RexLiteral with an Enum value is mapped to String Literal. | ||
| * Not enough context here to construct a substrait EnumArg. | ||
| * Once a FunctionVariant is resolved we can map the String Literal | ||
| * to a EnumArg. | ||
| * | ||
| * Note that if there are multiple registered function extensions which can match a particular Call, | ||
| * the last one added to the extension collection will be matched. | ||
|
Comment on lines
+420
to
+422
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering whether this comment should be also in a more prominent place for any consumers of isthmus and where they would find that. Would maybe at the top of the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ended up adding quite a bit of comments. Let me know if anything added is incorrect or just too much :)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks great, I was suggesting to Victor a couple weeks in another PR that we maybe should try to add javadocs whenever we touch code to slowly build up javadoc documentation so this is a great contribution |
||
| */ | ||
| List<RexNode> operandsList = call.getOperands().collect(Collectors.toList()); | ||
| List<Expression> operands = | ||
|
|
@@ -369,7 +440,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo | |
| .findFirst(); | ||
|
|
||
| if (directMatchKey.isPresent()) { | ||
| F variant = directMap.get(directMatchKey.get()); | ||
| List<F> variants = directMap.get(directMatchKey.get()); | ||
| if (variants.isEmpty()) { | ||
|
|
||
| return Optional.empty(); | ||
| } | ||
|
|
||
| F variant = variants.get(variants.size() - 1); | ||
| variant.validateOutputType(operands, outputType); | ||
| List<FunctionArg> funcArgs = | ||
| IntStream.range(0, operandsList.size()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| package io.substrait.isthmus; | ||
|
|
||
| import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; | ||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
|
||
| import io.substrait.expression.Expression; | ||
| import io.substrait.extension.SimpleExtension; | ||
| import io.substrait.isthmus.expression.AggregateFunctionConverter; | ||
| import io.substrait.isthmus.expression.ScalarFunctionConverter; | ||
| import io.substrait.isthmus.expression.WindowFunctionConverter; | ||
| import java.io.IOException; | ||
| import java.io.UncheckedIOException; | ||
| import java.util.List; | ||
| import java.util.Optional; | ||
| import org.apache.calcite.rex.RexBuilder; | ||
| import org.apache.calcite.rex.RexCall; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.sql.fun.SqlStdOperatorTable; | ||
| import org.apache.calcite.sql.fun.SqlTrimFunction.Flag; | ||
| import org.junit.jupiter.api.Test; | ||
|
|
||
| /** Tests to reproduce #562 */ | ||
| class DuplicateFunctionUrnTest extends PlanTestBase { | ||
|
|
||
| static final SimpleExtension.ExtensionCollection collection1; | ||
| static final SimpleExtension.ExtensionCollection collection2; | ||
| static final SimpleExtension.ExtensionCollection collection; | ||
|
|
||
| static { | ||
| try { | ||
| String extensions1 = asString("extensions/functions_duplicate_urn1.yaml"); | ||
| String extensions2 = asString("extensions/functions_duplicate_urn2.yaml"); | ||
| collection1 = | ||
| SimpleExtension.load("urn:extension:io.substrait:functions_string", extensions1); | ||
| collection2 = SimpleExtension.load("urn:extension:com.domain:string", extensions2); | ||
| collection = collection1.merge(collection2); | ||
|
|
||
| // Verify that the merged collection contains duplicate concat functions with different URNs | ||
| // This is a precondition for the tests - if this fails, the tests don't make sense | ||
| List<SimpleExtension.ScalarFunctionVariant> concatFunctions = | ||
| collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList(); | ||
|
|
||
| if (concatFunctions.size() != 2) { | ||
| throw new IllegalStateException( | ||
| "Expected 2 concat functions in merged collection, but found: " | ||
| + concatFunctions.size()); | ||
| } | ||
|
|
||
| String urn1 = concatFunctions.get(0).getAnchor().urn(); | ||
| String urn2 = concatFunctions.get(1).getAnchor().urn(); | ||
| if (urn1.equals(urn2)) { | ||
| throw new IllegalStateException( | ||
| "Expected different URNs for the two concat functions, but both were: " + urn1); | ||
| } | ||
| } catch (IOException e) { | ||
| throw new UncheckedIOException(e); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| void testDuplicateFunctionWithDifferentUrns() { | ||
| assertDoesNotThrow( | ||
| () -> new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory)); | ||
| } | ||
|
|
||
| @Test | ||
| void testDuplicateAggregateFunctionWithDifferentUrns() { | ||
| assertDoesNotThrow( | ||
| () -> new AggregateFunctionConverter(collection.aggregateFunctions(), typeFactory)); | ||
| } | ||
|
|
||
| @Test | ||
| void testDuplicateWindowFunctionWithDifferentUrns() { | ||
| assertDoesNotThrow( | ||
| () -> new WindowFunctionConverter(collection.windowFunctions(), typeFactory)); | ||
| } | ||
|
|
||
| @Test | ||
| void testMergeOrderDeterminesFunctionPrecedence() { | ||
| // This test verifies that when multiple extension collections contain functions with | ||
| // the same name and signature but different URNs, the merge order determines precedence. | ||
| // The FunctionConverter uses a "last-wins" strategy: the last function added to the | ||
| // extension collection will be matched when converting from Calcite to Substrait. | ||
|
|
||
| SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1); | ||
| ScalarFunctionConverter converterA = | ||
| new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory); | ||
| ScalarFunctionConverter converterB = | ||
| new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory); | ||
|
|
||
| RexBuilder rexBuilder = new RexBuilder(typeFactory); | ||
| RexCall concatCall = | ||
| (RexCall) | ||
| rexBuilder.makeCall( | ||
| SqlStdOperatorTable.CONCAT, | ||
| rexBuilder.makeLiteral("hello"), | ||
| rexBuilder.makeLiteral("world")); | ||
|
|
||
| // Create a simple topLevelConverter that converts literals to Substrait expressions | ||
| java.util.function.Function<RexNode, Expression> topLevelConverter = | ||
| rexNode -> { | ||
| org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; | ||
| return Expression.StrLiteral.builder() | ||
| .value(lit.getValueAs(String.class)) | ||
| .nullable(false) | ||
| .build(); | ||
| }; | ||
|
|
||
| Optional<Expression> exprA = converterA.convert(concatCall, topLevelConverter); | ||
| Optional<Expression> exprB = converterB.convert(concatCall, topLevelConverter); | ||
|
|
||
| Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); | ||
| Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); | ||
|
|
||
| assertEquals( | ||
| "extension:com.domain:string", | ||
| funcA.declaration().getAnchor().urn(), | ||
| "converterA should use last concat function (from collection2)"); | ||
|
|
||
| assertEquals( | ||
| "extension:io.substrait:functions_string", | ||
| funcB.declaration().getAnchor().urn(), | ||
| "converterB should use last concat function (from collection1)"); | ||
| } | ||
|
|
||
| @Test | ||
| void testLtrimMergeOrderWithDefaultExtensions() { | ||
| // This test verifies precedence between a custom ltrim (from collection2 with | ||
| // extension:com.domain:string) and the default extension catalog's ltrim | ||
| // (extension:io.substrait:functions_string). | ||
| // The FunctionConverter uses a "last-wins" strategy. | ||
|
|
||
| // Merge default extensions with collection2 - collection2's ltrim should be last | ||
| SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2); | ||
|
|
||
| // Merge collection2 with default extensions - default ltrim should be last | ||
| SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions); | ||
|
|
||
| ScalarFunctionConverter converterA = | ||
| new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory); | ||
| ScalarFunctionConverter converterB = | ||
| new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory); | ||
|
|
||
| // Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim | ||
| RexBuilder rexBuilder = new RexBuilder(typeFactory); | ||
| RexCall trimCall = | ||
| (RexCall) | ||
| rexBuilder.makeCall( | ||
| SqlStdOperatorTable.TRIM, | ||
| rexBuilder.makeFlag(Flag.LEADING), | ||
| rexBuilder.makeLiteral(" "), | ||
| rexBuilder.makeLiteral("test")); | ||
|
|
||
| java.util.function.Function<RexNode, Expression> topLevelConverter = | ||
| rexNode -> { | ||
| org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; | ||
| Object value = lit.getValue(); | ||
| if (value == null) { | ||
| return Expression.StrLiteral.builder().value("").nullable(true).build(); | ||
| } | ||
| // Convert any literal value to string | ||
| return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build(); | ||
| }; | ||
|
|
||
| Optional<Expression> exprA = converterA.convert(trimCall, topLevelConverter); | ||
| Optional<Expression> exprB = converterB.convert(trimCall, topLevelConverter); | ||
|
|
||
| Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); | ||
| // converterA should use collection2's custom ltrim (last) | ||
| assertEquals( | ||
| "extension:com.domain:string", | ||
| funcA.declaration().getAnchor().urn(), | ||
| "converterA should use last ltrim (custom from collection2)"); | ||
|
|
||
| Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); | ||
| // converterB should use default extensions' ltrim (last) | ||
| assertEquals( | ||
| "extension:io.substrait:functions_string", | ||
| funcB.declaration().getAnchor().urn(), | ||
| "converterB should use last ltrim (from default extensions)"); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| %YAML 1.2 | ||
| --- | ||
| urn: extension:io.substrait:functions_string | ||
|
|
||
| scalar_functions: | ||
| - name: "concat" | ||
| description: "concatenate strings" | ||
| impls: | ||
| - args: | ||
| - name: str1 | ||
| value: string | ||
| - name: str2 | ||
| value: string | ||
| variadic: | ||
| min: 0 | ||
| return: string |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unrelated to the PR technically, but I just thought that this was a clearer name.