diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 851230a36..d738ef157 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -2,7 +2,8 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.Streams; @@ -45,6 +46,27 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * Abstract base class for converting between Calcite {@link SqlOperator}s and Substrait function + * invocations. + * + *

This class handles bidirectional conversion: + * + *

+ * + *

When multiple functions with the same name and signature are passed into the constructor, a + * last-wins precedence strategy is used for resolution. The last function in the input list + * takes precedence during Calcite to Substrait conversion. + * + * @param <F> the function type (ScalarFunctionVariant, AggregateFunctionVariant, etc.) + * @param <T> the return type for Calcite→Substrait conversion + * @param <C> the call type being converted + */ public abstract class FunctionConverter< F extends SimpleExtension.Function, T, C extends FunctionConverter.GenericCall> { @@ -57,10 +79,32 @@ public abstract class FunctionConverter< protected final Multimap substraitFuncKeyToSqlOperatorMap; + /** + * Creates a FunctionConverter with the given functions. + * + *

If there are multiple functions provided with the same name and signature (e.g., from + * different extension URNs), the last one in the list will be given precedence during Calcite to + * Substrait conversion. + * + * @param functions the list of function variants to register + * @param typeFactory the Calcite type factory + */ public FunctionConverter(List functions, RelDataTypeFactory typeFactory) { this(functions, Collections.EMPTY_LIST, typeFactory, TypeConverter.DEFAULT); } + /** + * Creates a FunctionConverter with the given functions and additional signatures. + * + *

If there are multiple functions provided with the same name and signature (e.g., from + * different extension URNs), the last one in the list will be given precedence during Calcite to + * Substrait conversion. + * + * @param functions the list of function variants to register + * @param additionalSignatures additional Calcite operator signatures to map + * @param typeFactory the Calcite type factory + * @param typeConverter the type converter to use + */ public FunctionConverter( List functions, List additionalSignatures, @@ -75,9 +119,9 @@ public FunctionConverter( this.typeFactory = typeFactory; this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create(); - ArrayListMultimap alm = ArrayListMultimap.create(); + ArrayListMultimap nameToFn = ArrayListMultimap.create(); for (F f : functions) { - alm.put(f.name().toLowerCase(Locale.ROOT), f); + nameToFn.put(f.name().toLowerCase(Locale.ROOT), f); } Multimap calciteOperators = @@ -87,21 +131,21 @@ public FunctionConverter( FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create)); IdentityHashMap matcherMap = new IdentityHashMap(); - for (String key : alm.keySet()) { + for (String key : nameToFn.keySet()) { Collection sigs = calciteOperators.get(key); if (sigs.isEmpty()) { LOGGER.atDebug().log("No binding for function: {}", key); } for (Sig sig : sigs) { - List implList = alm.get(key); + List implList = nameToFn.get(key); if (!implList.isEmpty()) { matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList)); } } } - for (Entry entry : alm.entries()) { + for (Entry entry : nameToFn.entries()) { String key = entry.getKey(); F func = entry.getValue(); for (FunctionMappings.Sig sig : calciteOperators.get(key)) { @@ -112,6 +156,17 @@ public FunctionConverter( this.signatures = matcherMap; } + /** + * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction). + * + *

Given a Substrait function key (e.g., "concat:str_str") and output type, this method finds + * the corresponding Calcite {@link SqlOperator}. When multiple operators match, the output type + * is used to disambiguate. + * + * @param key the Substrait function key (function name with type signature) + * @param outputType the expected output type + * @return the matching {@link SqlOperator}, or empty if no match found + */ public Optional getSqlOperatorFromSubstraitFunc(String key, Type outputType) { Map resolver = getTypeBasedResolver(); Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); @@ -155,7 +210,7 @@ protected class FunctionFinder { private final String substraitName; private final SqlOperator operator; private final List functions; - private final Map directMap; + private final ListMultimap directMap; private final Optional> singularInputType; private final Util.IntRange argRange; @@ -168,7 +223,7 @@ public FunctionFinder(String substraitName, SqlOperator operator, List functi functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(), functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt()); this.singularInputType = getSingularInputType(functions); - ImmutableMap.Builder directMap = ImmutableMap.builder(); + ImmutableListMultimap.Builder directMap = ImmutableListMultimap.builder(); for (F func : functions) { String key = func.key(); directMap.put(key, func); @@ -342,6 +397,19 @@ private Stream matchKeys(List rexOperands, List opTypes } } + /** + * Converts a Calcite call to a Substrait function invocation (Calcite → Substrait direction). + * + *

This method tries to find a matching Substrait function for the given Calcite call using + * direct signature matching, type coercion, and least-restrictive type resolution. + * + *

If multiple registered function extensions have the same name and signature, the last one + * in the list passed into the constructor will be matched. + * + * @param call the Calcite call to match + * @param topLevelConverter function to convert RexNode operands to Substrait Expressions + * @return the matched Substrait function binding, or empty if no match found + */ public Optional attemptMatch(C call, Function topLevelConverter) { /* @@ -349,6 +417,9 @@ public Optional attemptMatch(C call, Function topLevelCo * Not enough context here to construct a substrait EnumArg. * Once a FunctionVariant is resolved we can map the String Literal * to a EnumArg. + * + * Note that if there are multiple registered function extensions which can match a particular Call, + * the last one added to the extension collection will be matched. */ List operandsList = call.getOperands().collect(Collectors.toList()); List operands = @@ -369,7 +440,13 @@ public Optional attemptMatch(C call, Function topLevelCo .findFirst(); if (directMatchKey.isPresent()) { - F variant = directMap.get(directMatchKey.get()); + List variants = directMap.get(directMatchKey.get()); + if (variants.isEmpty()) { + + return Optional.empty(); + } + + F variant = variants.get(variants.size() - 1); variant.validateOutputType(operands, outputType); List funcArgs = IntStream.range(0, operandsList.size()) diff --git a/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java b/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java new file mode 100644 index 000000000..1d8cca4cc --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java @@ -0,0 +1,182 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.expression.Expression; +import io.substrait.extension.SimpleExtension; +import 
io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.fun.SqlTrimFunction.Flag; +import org.junit.jupiter.api.Test; + +/** Tests to reproduce #562 */ +class DuplicateFunctionUrnTest extends PlanTestBase { + + static final SimpleExtension.ExtensionCollection collection1; + static final SimpleExtension.ExtensionCollection collection2; + static final SimpleExtension.ExtensionCollection collection; + + static { + try { + String extensions1 = asString("extensions/functions_duplicate_urn1.yaml"); + String extensions2 = asString("extensions/functions_duplicate_urn2.yaml"); + collection1 = + SimpleExtension.load("urn:extension:io.substrait:functions_string", extensions1); + collection2 = SimpleExtension.load("urn:extension:com.domain:string", extensions2); + collection = collection1.merge(collection2); + + // Verify that the merged collection contains duplicate concat functions with different URNs + // This is a precondition for the tests - if this fails, the tests don't make sense + List concatFunctions = + collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList(); + + if (concatFunctions.size() != 2) { + throw new IllegalStateException( + "Expected 2 concat functions in merged collection, but found: " + + concatFunctions.size()); + } + + String urn1 = concatFunctions.get(0).getAnchor().urn(); + String urn2 = concatFunctions.get(1).getAnchor().urn(); + if (urn1.equals(urn2)) { + throw new IllegalStateException( + "Expected different URNs for the two concat functions, but both were: " + 
urn1); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Test + void testDuplicateFunctionWithDifferentUrns() { + assertDoesNotThrow( + () -> new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory)); + } + + @Test + void testDuplicateAggregateFunctionWithDifferentUrns() { + assertDoesNotThrow( + () -> new AggregateFunctionConverter(collection.aggregateFunctions(), typeFactory)); + } + + @Test + void testDuplicateWindowFunctionWithDifferentUrns() { + assertDoesNotThrow( + () -> new WindowFunctionConverter(collection.windowFunctions(), typeFactory)); + } + + @Test + void testMergeOrderDeterminesFunctionPrecedence() { + // This test verifies that when multiple extension collections contain functions with + // the same name and signature but different URNs, the merge order determines precedence. + // The FunctionConverter uses a "last-wins" strategy: the last function added to the + // extension collection will be matched when converting from Calcite to Substrait. 
+ + SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1); + ScalarFunctionConverter converterA = + new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory); + ScalarFunctionConverter converterB = + new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory); + + RexBuilder rexBuilder = new RexBuilder(typeFactory); + RexCall concatCall = + (RexCall) + rexBuilder.makeCall( + SqlStdOperatorTable.CONCAT, + rexBuilder.makeLiteral("hello"), + rexBuilder.makeLiteral("world")); + + // Create a simple topLevelConverter that converts literals to Substrait expressions + java.util.function.Function topLevelConverter = + rexNode -> { + org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; + return Expression.StrLiteral.builder() + .value(lit.getValueAs(String.class)) + .nullable(false) + .build(); + }; + + Optional exprA = converterA.convert(concatCall, topLevelConverter); + Optional exprB = converterB.convert(concatCall, topLevelConverter); + + Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); + Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); + + assertEquals( + "extension:com.domain:string", + funcA.declaration().getAnchor().urn(), + "converterA should use last concat function (from collection2)"); + + assertEquals( + "extension:io.substrait:functions_string", + funcB.declaration().getAnchor().urn(), + "converterB should use last concat function (from collection1)"); + } + + @Test + void testLtrimMergeOrderWithDefaultExtensions() { + // This test verifies precedence between a custom ltrim (from collection2 with + // extension:com.domain:string) and the default extension catalog's ltrim + // (extension:io.substrait:functions_string). + // The FunctionConverter uses a "last-wins" strategy. 
+ + // Merge default extensions with collection2 - collection2's ltrim should be last + SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2); + + // Merge collection2 with default extensions - default ltrim should be last + SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions); + + ScalarFunctionConverter converterA = + new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory); + ScalarFunctionConverter converterB = + new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory); + + // Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim + RexBuilder rexBuilder = new RexBuilder(typeFactory); + RexCall trimCall = + (RexCall) + rexBuilder.makeCall( + SqlStdOperatorTable.TRIM, + rexBuilder.makeFlag(Flag.LEADING), + rexBuilder.makeLiteral(" "), + rexBuilder.makeLiteral("test")); + + java.util.function.Function topLevelConverter = + rexNode -> { + org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; + Object value = lit.getValue(); + if (value == null) { + return Expression.StrLiteral.builder().value("").nullable(true).build(); + } + // Convert any literal value to string + return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build(); + }; + + Optional exprA = converterA.convert(trimCall, topLevelConverter); + Optional exprB = converterB.convert(trimCall, topLevelConverter); + + Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); + // converterA should use collection2's custom ltrim (last) + assertEquals( + "extension:com.domain:string", + funcA.declaration().getAnchor().urn(), + "converterA should use last ltrim (custom from collection2)"); + + Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); + // converterB should use default extensions' ltrim (last) + assertEquals( + 
"extension:io.substrait:functions_string", + funcB.declaration().getAnchor().urn(), + "converterB should use last ltrim (from default extensions)"); + } +} diff --git a/isthmus/src/test/resources/extensions/functions_duplicate_urn1.yaml b/isthmus/src/test/resources/extensions/functions_duplicate_urn1.yaml new file mode 100644 index 000000000..c7b9d98f8 --- /dev/null +++ b/isthmus/src/test/resources/extensions/functions_duplicate_urn1.yaml @@ -0,0 +1,16 @@ +%YAML 1.2 +--- +urn: extension:io.substrait:functions_string + +scalar_functions: + - name: "concat" + description: "concatenate strings" + impls: + - args: + - name: str1 + value: string + - name: str2 + value: string + variadic: + min: 0 + return: string diff --git a/isthmus/src/test/resources/extensions/functions_duplicate_urn2.yaml b/isthmus/src/test/resources/extensions/functions_duplicate_urn2.yaml new file mode 100644 index 000000000..85d9d052d --- /dev/null +++ b/isthmus/src/test/resources/extensions/functions_duplicate_urn2.yaml @@ -0,0 +1,25 @@ +%YAML 1.2 +--- +urn: extension:com.domain:string + +scalar_functions: + - name: "ltrim" + description: "left trim from custom domain" + impls: + - args: + - name: str + value: string + - name: chars + value: string + return: string + - name: "concat" + description: "concatenate strings from custom domain" + impls: + - args: + - name: str1 + value: string + - name: str2 + value: string + variadic: + min: 0 + return: string