Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
Expand Down Expand Up @@ -45,6 +46,27 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Abstract base class for converting between Calcite {@link SqlOperator}s and Substrait function
* invocations.
*
* <p>This class handles bidirectional conversion:
*
* <ul>
* <li><b>Calcite → Substrait:</b> Subclasses implement {@code convert()} methods to convert
* Calcite calls to Substrait function invocations
* <li><b>Substrait → Calcite:</b> {@link #getSqlOperatorFromSubstraitFunc} converts Substrait
* function keys to Calcite {@link SqlOperator}s
* </ul>
*
* <p>When multiple functions with the same name and signature are passed into the constructor, a
* <b>last-wins precedence strategy</b> is used for resolution. The last function in the input list
* takes precedence during Calcite to Substrait conversion.
*
* @param <F> the function type (ScalarFunctionVariant, AggregateFunctionVariant, etc.)
* @param <T> the return type for Calcite→Substrait conversion
* @param <C> the call type being converted
*/
public abstract class FunctionConverter<
F extends SimpleExtension.Function, T, C extends FunctionConverter.GenericCall> {

Expand All @@ -57,10 +79,32 @@ public abstract class FunctionConverter<

protected final Multimap<String, SqlOperator> substraitFuncKeyToSqlOperatorMap;

/**
 * Creates a FunctionConverter with the given functions, no additional Calcite operator signatures
 * and the default type converter ({@code TypeConverter.DEFAULT}).
 *
 * <p>If there are multiple functions provided with the same name and signature (e.g., from
 * different extension URNs), the last one in the list will be given precedence during Calcite to
 * Substrait conversion.
 *
 * @param functions the list of function variants to register
 * @param typeFactory the Calcite type factory
 */
public FunctionConverter(List<F> functions, RelDataTypeFactory typeFactory) {
  // Collections.emptyList() is inferred as List<FunctionMappings.Sig>, avoiding the unchecked
  // conversion that the raw-typed Collections.EMPTY_LIST constant would require.
  this(functions, Collections.emptyList(), typeFactory, TypeConverter.DEFAULT);
}

/**
* Creates a FunctionConverter with the given functions and additional signatures.
*
* <p>If there are multiple functions provided with the same name and signature (e.g., from
* different extension URNs), the last one in the list will be given precedence during Calcite to
* Substrait conversion.
*
* @param functions the list of function variants to register
* @param additionalSignatures additional Calcite operator signatures to map
* @param typeFactory the Calcite type factory
* @param typeConverter the type converter to use
*/
public FunctionConverter(
List<F> functions,
List<FunctionMappings.Sig> additionalSignatures,
Expand All @@ -75,9 +119,9 @@ public FunctionConverter(
this.typeFactory = typeFactory;
this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create();

ArrayListMultimap<String, F> alm = ArrayListMultimap.<String, F>create();
ArrayListMultimap<String, F> nameToFn = ArrayListMultimap.<String, F>create();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to the PR technically, but I just thought that this was a clearer name.

for (F f : functions) {
alm.put(f.name().toLowerCase(Locale.ROOT), f);
nameToFn.put(f.name().toLowerCase(Locale.ROOT), f);
}

Multimap<String, FunctionMappings.Sig> calciteOperators =
Expand All @@ -87,21 +131,21 @@ public FunctionConverter(
FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create));
IdentityHashMap<SqlOperator, FunctionFinder> matcherMap =
new IdentityHashMap<SqlOperator, FunctionFinder>();
for (String key : alm.keySet()) {
for (String key : nameToFn.keySet()) {
Collection<Sig> sigs = calciteOperators.get(key);
if (sigs.isEmpty()) {
LOGGER.atDebug().log("No binding for function: {}", key);
}

for (Sig sig : sigs) {
List<F> implList = alm.get(key);
List<F> implList = nameToFn.get(key);
if (!implList.isEmpty()) {
matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList));
}
}
}

for (Entry<String, F> entry : alm.entries()) {
for (Entry<String, F> entry : nameToFn.entries()) {
String key = entry.getKey();
F func = entry.getValue();
for (FunctionMappings.Sig sig : calciteOperators.get(key)) {
Expand All @@ -112,6 +156,17 @@ public FunctionConverter(
this.signatures = matcherMap;
}

/**
* Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction).
*
* <p>Given a Substrait function key (e.g., "concat:str_str") and output type, this method finds
* the corresponding Calcite {@link SqlOperator}. When multiple operators match, the output type
* is used to disambiguate.
*
* @param key the Substrait function key (function name with type signature)
* @param outputType the expected output type
* @return the matching {@link SqlOperator}, or empty if no match found
*/
public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
Map<SqlOperator, TypeBasedResolver> resolver = getTypeBasedResolver();
Collection<SqlOperator> operators = substraitFuncKeyToSqlOperatorMap.get(key);
Expand Down Expand Up @@ -155,7 +210,7 @@ protected class FunctionFinder {
private final String substraitName;
private final SqlOperator operator;
private final List<F> functions;
private final Map<String, F> directMap;
private final ListMultimap<String, F> directMap;
private final Optional<SingularArgumentMatcher<F>> singularInputType;
private final Util.IntRange argRange;

Expand All @@ -168,7 +223,7 @@ public FunctionFinder(String substraitName, SqlOperator operator, List<F> functi
functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(),
functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt());
this.singularInputType = getSingularInputType(functions);
ImmutableMap.Builder<String, F> directMap = ImmutableMap.builder();
ImmutableListMultimap.Builder<String, F> directMap = ImmutableListMultimap.builder();
for (F func : functions) {
String key = func.key();
directMap.put(key, func);
Expand Down Expand Up @@ -342,13 +397,29 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
}
}

/**
* Converts a Calcite call to a Substrait function invocation (Calcite → Substrait direction).
*
* <p>This method tries to find a matching Substrait function for the given Calcite call using
* direct signature matching, type coercion, and least-restrictive type resolution.
*
* <p>If multiple registered function extensions have the same name and signature, the last one
* in the list passed into the constructor will be matched.
*
* @param call the Calcite call to match
* @param topLevelConverter function to convert RexNode operands to Substrait Expressions
* @return the matched Substrait function binding, or empty if no match found
*/
public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelConverter) {

/*
* Here the RexLiteral with an Enum value is mapped to String Literal.
* Not enough context here to construct a substrait EnumArg.
* Once a FunctionVariant is resolved we can map the String Literal
* to a EnumArg.
*
* Note that if there are multiple registered function extensions which can match a particular Call,
* the last one added to the extension collection will be matched.
Comment on lines +420 to +422
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether this comment should also live in a more prominent place, where consumers of isthmus would be likely to find it. Would the top of the FunctionConverter class be a better spot?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up adding quite a few comments. Let me know if anything I added is incorrect or just too much :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. A couple of weeks ago, in another PR, I suggested to Victor that we should try to add Javadoc whenever we touch code, to slowly build up the documentation, so this is a great contribution.

*/
List<RexNode> operandsList = call.getOperands().collect(Collectors.toList());
List<Expression> operands =
Expand All @@ -369,7 +440,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
.findFirst();

if (directMatchKey.isPresent()) {
F variant = directMap.get(directMatchKey.get());
List<F> variants = directMap.get(directMatchKey.get());
if (variants.isEmpty()) {

return Optional.empty();
}

F variant = variants.get(variants.size() - 1);
variant.validateOutputType(operands, outputType);
List<FunctionArg> funcArgs =
IntStream.range(0, operandsList.size())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.expression.Expression;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlTrimFunction.Flag;
import org.junit.jupiter.api.Test;

/**
 * Tests to reproduce #562: building function converters from an extension collection in which the
 * same function name and signature is declared under more than one URN.
 *
 * <p>The tests check that the scalar, aggregate and window converters can be constructed from such
 * a collection, and that the function variant asserted to win during Calcite to Substrait
 * conversion is the one coming from the collection merged in last.
 */
class DuplicateFunctionUrnTest extends PlanTestBase {

  static final SimpleExtension.ExtensionCollection collection1;
  static final SimpleExtension.ExtensionCollection collection2;
  /** {@link #collection1} merged with {@link #collection2}, in that order. */
  static final SimpleExtension.ExtensionCollection collection;

  static {
    try {
      String extensions1 = asString("extensions/functions_duplicate_urn1.yaml");
      String extensions2 = asString("extensions/functions_duplicate_urn2.yaml");
      collection1 =
          SimpleExtension.load("urn:extension:io.substrait:functions_string", extensions1);
      collection2 = SimpleExtension.load("urn:extension:com.domain:string", extensions2);
      collection = collection1.merge(collection2);

      // Verify that the merged collection contains duplicate concat functions with different URNs.
      // This is a precondition for the tests - if this fails, the tests don't make sense.
      List<SimpleExtension.ScalarFunctionVariant> concatFunctions =
          collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList();

      if (concatFunctions.size() != 2) {
        throw new IllegalStateException(
            "Expected 2 concat functions in merged collection, but found: "
                + concatFunctions.size());
      }

      String urn1 = concatFunctions.get(0).getAnchor().urn();
      String urn2 = concatFunctions.get(1).getAnchor().urn();
      if (urn1.equals(urn2)) {
        throw new IllegalStateException(
            "Expected different URNs for the two concat functions, but both were: " + urn1);
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /** Constructing a scalar converter from a collection with duplicate functions must not throw. */
  @Test
  void testDuplicateFunctionWithDifferentUrns() {
    assertDoesNotThrow(
        () -> new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory));
  }

  /** Constructing an aggregate converter from the merged collection must not throw. */
  @Test
  void testDuplicateAggregateFunctionWithDifferentUrns() {
    assertDoesNotThrow(
        () -> new AggregateFunctionConverter(collection.aggregateFunctions(), typeFactory));
  }

  /** Constructing a window converter from the merged collection must not throw. */
  @Test
  void testDuplicateWindowFunctionWithDifferentUrns() {
    assertDoesNotThrow(
        () -> new WindowFunctionConverter(collection.windowFunctions(), typeFactory));
  }

  @Test
  void testMergeOrderDeterminesFunctionPrecedence() {
    // This test verifies that when multiple extension collections contain functions with
    // the same name and signature but different URNs, the merge order determines precedence.
    // The FunctionConverter uses a "last-wins" strategy: the last function added to the
    // extension collection will be matched when converting from Calcite to Substrait.

    SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1);
    ScalarFunctionConverter converterA =
        new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory);
    ScalarFunctionConverter converterB =
        new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory);

    RexBuilder rexBuilder = new RexBuilder(typeFactory);
    RexCall concatCall =
        (RexCall)
            rexBuilder.makeCall(
                SqlStdOperatorTable.CONCAT,
                rexBuilder.makeLiteral("hello"),
                rexBuilder.makeLiteral("world"));

    // Create a simple topLevelConverter that converts literals to Substrait expressions
    java.util.function.Function<RexNode, Expression> topLevelConverter =
        rexNode -> {
          org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
          return Expression.StrLiteral.builder()
              .value(lit.getValueAs(String.class))
              .nullable(false)
              .build();
        };

    String urnA = convertedFunctionUrn(converterA, concatCall, topLevelConverter);
    String urnB = convertedFunctionUrn(converterB, concatCall, topLevelConverter);

    assertEquals(
        "extension:com.domain:string",
        urnA,
        "converterA should use last concat function (from collection2)");

    assertEquals(
        "extension:io.substrait:functions_string",
        urnB,
        "converterB should use last concat function (from collection1)");
  }

  @Test
  void testLtrimMergeOrderWithDefaultExtensions() {
    // This test verifies precedence between a custom ltrim (from collection2 with
    // extension:com.domain:string) and the default extension catalog's ltrim
    // (extension:io.substrait:functions_string).
    // The FunctionConverter uses a "last-wins" strategy.

    // Merge default extensions with collection2 - collection2's ltrim should be last
    SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2);

    // Merge collection2 with default extensions - default ltrim should be last
    SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions);

    ScalarFunctionConverter converterA =
        new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory);
    ScalarFunctionConverter converterB =
        new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory);

    // Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim
    RexBuilder rexBuilder = new RexBuilder(typeFactory);
    RexCall trimCall =
        (RexCall)
            rexBuilder.makeCall(
                SqlStdOperatorTable.TRIM,
                rexBuilder.makeFlag(Flag.LEADING),
                rexBuilder.makeLiteral(" "),
                rexBuilder.makeLiteral("test"));

    java.util.function.Function<RexNode, Expression> topLevelConverter =
        rexNode -> {
          org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
          Object value = lit.getValue();
          if (value == null) {
            return Expression.StrLiteral.builder().value("").nullable(true).build();
          }
          // Convert any literal value to string
          return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build();
        };

    String urnA = convertedFunctionUrn(converterA, trimCall, topLevelConverter);
    String urnB = convertedFunctionUrn(converterB, trimCall, topLevelConverter);

    // converterA should use collection2's custom ltrim (last)
    assertEquals(
        "extension:com.domain:string",
        urnA,
        "converterA should use last ltrim (custom from collection2)");

    // converterB should use default extensions' ltrim (last)
    assertEquals(
        "extension:io.substrait:functions_string",
        urnB,
        "converterB should use last ltrim (from default extensions)");
  }

  /**
   * Converts {@code call} with {@code converter} and returns the URN of the Substrait function
   * declaration that was selected.
   *
   * <p>Fails with a descriptive {@link AssertionError} when the converter finds no match, rather
   * than letting an unchecked {@code Optional.get()} raise a bare {@code NoSuchElementException}.
   *
   * @param converter the converter under test
   * @param call the Calcite call to convert
   * @param topLevelConverter converts the call's operands to Substrait expressions
   * @return the URN of the anchor of the matched function declaration
   */
  private static String convertedFunctionUrn(
      ScalarFunctionConverter converter,
      RexCall call,
      java.util.function.Function<RexNode, Expression> topLevelConverter) {
    Expression converted =
        converter
            .convert(call, topLevelConverter)
            .orElseThrow(
                () ->
                    new AssertionError(
                        "Expected the call to be converted to a Substrait function, but no match"
                            + " was found: "
                            + call));
    Expression.ScalarFunctionInvocation invocation =
        (Expression.ScalarFunctionInvocation) converted;
    return invocation.declaration().getAnchor().urn();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
%YAML 1.2
---
# Extension definition that declares a scalar "concat" function under the
# extension:io.substrait:functions_string URN.
# NOTE(review): presumably one of the extensions/functions_duplicate_urn*.yaml
# fixtures loaded by DuplicateFunctionUrnTest - confirm against the file path.
# NOTE(review): the nesting indentation appears to have been lost in this view;
# verify the structure against the checked-in file before relying on it.
urn: extension:io.substrait:functions_string

scalar_functions:
- name: "concat"
description: "concatenate strings"
impls:
# One implementation: two string arguments (str1, str2), a variadic minimum
# of 0, and a string return type.
- args:
- name: str1
value: string
- name: str2
value: string
variadic:
min: 0
return: string
Loading