Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions processing/src/main/java/org/apache/druid/math/expr/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ default boolean areNumeric(List<Expr> args)
if (argType == null) {
continue;
}
numeric &= argType.isNumeric();
numeric = numeric && argType.isNumeric();
}
return numeric;
}
Expand Down Expand Up @@ -265,7 +265,7 @@ default boolean areSameTypes(List<Expr> args)
if (currentType == null) {
currentType = argType;
}
allSame &= Objects.equals(argType, currentType);
allSame = allSame && Objects.equals(argType, currentType);
}
return allSame;
}
Expand Down Expand Up @@ -302,7 +302,7 @@ default boolean areScalar(List<Expr> args)
if (argType == null) {
continue;
}
scalar &= argType.isPrimitive();
scalar = scalar && argType.isPrimitive();
}
return scalar;
}
Expand Down Expand Up @@ -330,7 +330,7 @@ default boolean canVectorize(List<Expr> args)
{
boolean canVectorize = true;
for (Expr arg : args) {
canVectorize &= arg.canVectorize(this);
canVectorize = canVectorize && arg.canVectorize(this);
}
return canVectorize;
}
Expand Down Expand Up @@ -498,7 +498,7 @@ public Set<String> getRequiredBindings()
/**
* Set of {@link IdentifierExpr#binding} which are used as scalar inputs to operators and functions.
*/
Set<String> getScalarBindings()
public Set<String> getScalarBindings()
{
return map(scalarVariables, IdentifierExpr::getBindingIfIdentifier);
}
Expand Down
63 changes: 48 additions & 15 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -1962,14 +1962,11 @@ public Set<Expr> getScalarInputs(List<Expr> args)
ExpressionType castTo = ExpressionType.fromString(
StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())
);
switch (castTo.getType()) {
case ARRAY:
return Collections.emptySet();
default:
return ImmutableSet.of(args.get(0));
if (!castTo.getType().isArray()) {
return ImmutableSet.of(args.get(0));
}
}
// unknown cast, can't safely assume either way
// either has array inputs or unknown inputs
return Collections.emptySet();
}

Expand All @@ -1980,16 +1977,11 @@ public Set<Expr> getArrayInputs(List<Expr> args)
ExpressionType castTo = ExpressionType.fromString(
StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())
);
switch (castTo.getType()) {
case LONG:
case DOUBLE:
case STRING:
return Collections.emptySet();
default:
return ImmutableSet.of(args.get(0));
if (castTo.getType().isArray()) {
return ImmutableSet.of(args.get(0));
}
}
// unknown cast, can't safely assume either way
// not an array, or unknown input types
return Collections.emptySet();
}

Expand Down Expand Up @@ -2087,6 +2079,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
{
return ExpressionTypeConversion.conditional(inspector, args.subList(1, 3));
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// could potentially look for constants in the return positions and examine type...
return Collections.emptySet();
}
}

/**
Expand Down Expand Up @@ -2134,6 +2133,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
results.add(args.get(args.size() - 1));
return ExpressionTypeConversion.conditional(inspector, results);
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// could potentially look for constants in the return positions and examine type...
return Collections.emptySet();
}
}

/**
Expand Down Expand Up @@ -2181,6 +2187,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
results.add(args.get(args.size() - 1));
return ExpressionTypeConversion.conditional(inspector, results);
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// could potentially look for constants in the return positions and examine type...
return Collections.emptySet();
}
}

class NvlFunc implements Function
Expand Down Expand Up @@ -2222,6 +2235,13 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
{
return VectorProcessors.nvl(inspector, args.get(0), args.get(1));
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// output is same as input, doesn't matter the type
return Collections.emptySet();
}
}

class IsNullFunc implements Function
Expand Down Expand Up @@ -2263,6 +2283,13 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
{
return VectorProcessors.isNull(inspector, args.get(0));
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// null or not, doesnt matter if the inputs are arrays or scalars
return Collections.emptySet();
}
}

class IsNotNullFunc implements Function
Expand Down Expand Up @@ -2293,7 +2320,6 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
return ExpressionType.LONG;
}


@Override
public boolean canVectorize(Expr.InputBindingInspector inspector, List<Expr> args)
{
Expand All @@ -2305,6 +2331,13 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
{
return VectorProcessors.isNotNull(inspector, args.get(0));
}

@Override
public Set<Expr> getScalarInputs(List<Expr> args)
{
// null or not, doesnt matter if the inputs are arrays or scalars
return Collections.emptySet();
}
}

class ConcatFunc implements Function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public static ExpressionPlan plan(ColumnInspector inspector, Expr expression)
c -> !definitelyArray.contains(c)
&& definitelyMultiValued.contains(c)
&& !analysis.getArrayBindings().contains(c)
&& analysis.getScalarBindings().contains(c)
)
.collect(Collectors.toList());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,12 @@ public void testLiteralArraysExplicitDoubleParseException()
public void testFunctions()
{
validateParser("sqrt(x)", "(sqrt [x])", ImmutableList.of("x"));
validateParser("if(cond,then,else)", "(if [cond, then, else])", ImmutableList.of("cond", "else", "then"));
validateParser("if(cond,then,else)", "(if [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet());
validateParser("case_simple(cond,then,else)", "(case_simple [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet());
validateParser("case_searched(cond,then,else)", "(case_searched [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet());
validateParser("nvl(x, fallback)", "(nvl [x, fallback])", ImmutableList.of("x", "fallback"), Collections.emptySet(), Collections.emptySet());
validateParser("nvl(x, 1)", "(nvl [x, 1])", ImmutableList.of("x"), ImmutableSet.of(), Collections.emptySet());
validateParser("nvl(x, [1,2,3])", "(nvl [x, [1, 2, 3]])", ImmutableList.of("x"), Collections.emptySet(), ImmutableSet.of());
validateParser("cast(x, 'STRING')", "(cast [x, STRING])", ImmutableList.of("x"));
validateParser("cast(x, 'LONG')", "(cast [x, LONG])", ImmutableList.of("x"));
validateParser("cast(x, 'DOUBLE')", "(cast [x, DOUBLE])", ImmutableList.of("x"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.druid.sql.calcite.expression.builtin;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
Expand All @@ -29,7 +28,7 @@
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
Expand All @@ -39,45 +38,10 @@
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.joda.time.Period;

import java.util.Map;
import java.util.function.Function;

public class CastOperatorConversion implements SqlOperatorConversion
{
private static final Map<SqlTypeName, ExprType> EXPRESSION_TYPES;

static {
final ImmutableMap.Builder<SqlTypeName, ExprType> builder = ImmutableMap.builder();

for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) {
builder.put(type, ExprType.DOUBLE);
}

for (SqlTypeName type : SqlTypeName.INT_TYPES) {
builder.put(type, ExprType.LONG);
}

for (SqlTypeName type : SqlTypeName.STRING_TYPES) {
builder.put(type, ExprType.STRING);
}

// Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false).
builder.put(SqlTypeName.BOOLEAN, ExprType.LONG);

// Timestamps are treated as longs (millis since the epoch) in Druid expressions.
builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG);
builder.put(SqlTypeName.DATE, ExprType.LONG);

for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}

for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) {
builder.put(type, ExprType.LONG);
}

EXPRESSION_TYPES = builder.build();
}

@Override
public SqlOperator calciteOperator()
Expand All @@ -103,6 +67,7 @@ public DruidExpression toDruidExpression(
return null;
}


final SqlTypeName fromType = operand.getType().getSqlTypeName();
final SqlTypeName toType = rexNode.getType().getSqlTypeName();

Expand All @@ -118,28 +83,32 @@ public DruidExpression toDruidExpression(
} else {
// Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls
// downstream, Druid will try its best
final ExprType fromExprType = SqlTypeName.ANY.equals(fromType)
? EXPRESSION_TYPES.get(toType)
: EXPRESSION_TYPES.get(fromType);
final ExprType toExprType = SqlTypeName.ANY.equals(toType)
? EXPRESSION_TYPES.get(fromType)
: EXPRESSION_TYPES.get(toType);

if (fromExprType == null || toExprType == null) {

final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType());
final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType());

final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType)
? ExpressionType.fromColumnType(toDruidType)
: ExpressionType.fromColumnType(fromDruidType);
final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType)
? ExpressionType.fromColumnType(fromDruidType)
: ExpressionType.fromColumnType(toDruidType);

if (fromExpressionType == null || toExpressionType == null) {
// We have no runtime type for these SQL types.
return null;
}

final DruidExpression typeCastExpression;

if (fromExprType != toExprType) {
// Ignore casts for simple extractions (use Function.identity) since it is ok in many cases.
if (fromExpressionType.equals(toExpressionType)) {
// Ignore casts for simple extractions since it is ok in many cases.
typeCastExpression = operandExpression;
} else {
typeCastExpression = operandExpression.map(
Function.identity(),
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString())
expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString())
);
} else {
typeCastExpression = operandExpression;
}

if (toType == SqlTypeName.DATE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ public static boolean isLongType(SqlTypeName sqlTypeName)
return SqlTypeName.TIMESTAMP == sqlTypeName ||
SqlTypeName.DATE == sqlTypeName ||
SqlTypeName.BOOLEAN == sqlTypeName ||
SqlTypeName.INT_TYPES.contains(sqlTypeName);
SqlTypeName.INT_TYPES.contains(sqlTypeName) ||
SqlTypeName.DAY_INTERVAL_TYPES.contains(sqlTypeName) ||
SqlTypeName.YEAR_INTERVAL_TYPES.contains(sqlTypeName);
}

public static StringComparator getStringComparatorForRelDataType(RelDataType dataType)
Expand Down
Loading