Skip to content
35 changes: 32 additions & 3 deletions core/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,11 @@ protected ExprEval eval(double param)

class Round implements Function
{
//CHECKSTYLE.OFF: Regexp
private static final BigDecimal MAX_FINITE_VALUE = BigDecimal.valueOf(Double.MAX_VALUE);
private static final BigDecimal MIN_FINITE_VALUE = BigDecimal.valueOf(-1 * Double.MAX_VALUE);
//CHECKSTYLE.ON: Regexp

@Override
public String name()
{
Expand All @@ -705,15 +710,23 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
ExprEval value1 = args.get(0).eval(bindings);
if (value1.type() != ExprType.LONG && value1.type() != ExprType.DOUBLE) {
throw new IAE("The first argument to the function[%s] should be integer or double type but get the %s type", name(), value1.type());
throw new IAE(
"The first argument to the function[%s] should be integer or double type but got the type: %s",
name(),
value1.type()
);
}

if (args.size() == 1) {
return eval(value1);
} else {
ExprEval value2 = args.get(1).eval(bindings);
if (value2.type() != ExprType.LONG) {
throw new IAE("The second argument to the function[%s] should be integer type but get the %s type", name(), value2.type());
throw new IAE(
"The second argument to the function[%s] should be integer type but got the type: %s",
name(),
value2.type()
);
}
return eval(value1, value2.asInt());
}
Expand All @@ -737,11 +750,27 @@ private ExprEval eval(ExprEval param, int scale)
if (param.type() == ExprType.LONG) {
return ExprEval.of(BigDecimal.valueOf(param.asLong()).setScale(scale, RoundingMode.HALF_UP).longValue());
} else if (param.type() == ExprType.DOUBLE) {
return ExprEval.of(BigDecimal.valueOf(param.asDouble()).setScale(scale, RoundingMode.HALF_UP).doubleValue());
BigDecimal decimal = safeGetFromDouble(param.asDouble());
return ExprEval.of(decimal.setScale(scale, RoundingMode.HALF_UP).doubleValue());
} else {
return ExprEval.of(null);
}
}

/**
* Converts non-finite doubles to BigDecimal values instead of throwing a NumberFormatException.
*/
private static BigDecimal safeGetFromDouble(double val)
{
if (Double.isNaN(val)) {
return BigDecimal.ZERO;
} else if (val == Double.POSITIVE_INFINITY) {
return MAX_FINITE_VALUE;
} else if (val == Double.NEGATIVE_INFINITY) {
return MIN_FINITE_VALUE;
}
return BigDecimal.valueOf(val);
}
}

class Signum extends UnivariateMathFunction
Expand Down
157 changes: 150 additions & 7 deletions core/src/test/java/org/apache/druid/math/expr/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@
package org.apache.druid.math.expr;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Locale;
import java.util.Set;

public class FunctionTest extends InitializedNullHandlingTest
{
Expand All @@ -35,13 +41,23 @@ public class FunctionTest extends InitializedNullHandlingTest
@Before
public void setup()
{
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put("x", "foo");
builder.put("y", 2);
builder.put("z", 3.1);
builder.put("a", new String[] {"foo", "bar", "baz", "foobar"});
builder.put("b", new Long[] {1L, 2L, 3L, 4L, 5L});
builder.put("c", new Double[] {3.1, 4.2, 5.3});
ImmutableMap.Builder<String, Object> builder = ImmutableMap.<String, Object>builder()
.put("x", "foo")
.put("y", 2)
.put("z", 3.1)
.put("d", 34.56D)
.put("maxLong", Long.MAX_VALUE)
.put("minLong", Long.MIN_VALUE)
.put("f", 12.34F)
.put("nan", Double.NaN)
.put("inf", Double.POSITIVE_INFINITY)
.put("-inf", Double.NEGATIVE_INFINITY)
.put("o", 0)
.put("od", 0D)
.put("of", 0F)
.put("a", new String[] {"foo", "bar", "baz", "foobar"})
.put("b", new Long[] {1L, 2L, 3L, 4L, 5L})
.put("c", new Double[] {3.1, 4.2, 5.3});
bindings = Parser.withMap(builder.build());
}

Expand Down Expand Up @@ -320,6 +336,133 @@ public void testArrayPrepend()
assertArrayExpr("array_prepend(1, <DOUBLE>[])", new Double[]{1.0});
}

@Test
public void testRoundWithNonNumericValuesShouldReturn0()
{
assertExpr("round(nan)", 0D);
assertExpr("round(nan, 5)", 0D);
//CHECKSTYLE.OFF: Regexp
assertExpr("round(inf)", Double.MAX_VALUE);
assertExpr("round(inf, 4)", Double.MAX_VALUE);
assertExpr("round(-inf)", -1 * Double.MAX_VALUE);
assertExpr("round(-inf, 3)", -1 * Double.MAX_VALUE);
assertExpr("round(-inf, -5)", -1 * Double.MAX_VALUE);
//CHECKSTYLE.ON: Regexp

// Calculations that result in non numeric numbers
assertExpr("round(0/od)", 0D);
assertExpr("round(od/od)", 0D);
//CHECKSTYLE.OFF: Regexp
assertExpr("round(1/od)", Double.MAX_VALUE);
assertExpr("round(-1/od)", -1 * Double.MAX_VALUE);
//CHECKSTYLE.ON: Regexp

assertExpr("round(0/of)", 0D);
assertExpr("round(of/of)", 0D);
//CHECKSTYLE.OFF: Regexp
assertExpr("round(1/of)", Double.MAX_VALUE);
assertExpr("round(-1/of)", -1 * Double.MAX_VALUE);
//CHECKSTYLE.ON: Regexp
}

@Test
public void testRoundWithLong()
{
assertExpr("round(y)", 2L);
assertExpr("round(y, 2)", 2L);
assertExpr("round(y, -1)", 0L);
}

@Test
public void testRoundWithDouble()
{
assertExpr("round(d)", 35D);
assertExpr("round(d, 2)", 34.56D);
assertExpr("round(d, y)", 34.56D);
assertExpr("round(d, 1)", 34.6D);
assertExpr("round(d, -1)", 30D);
}

@Test
public void testRoundWithFloat()
{
assertExpr("round(f)", 12D);
assertExpr("round(f, 2)", 12.34D);
assertExpr("round(f, y)", 12.34D);
assertExpr("round(f, 1)", 12.3D);
assertExpr("round(f, -1)", 10D);
}

@Test
public void testRoundWithExtremeNumbers()
{
assertExpr("round(maxLong)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
assertExpr("round(minLong)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(0, RoundingMode.HALF_UP).longValue());
// overflow
assertExpr("round(maxLong + 1, 1)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue());
// underflow
assertExpr("round(minLong - 1, -2)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue());

assertExpr("round(CAST(maxLong, 'DOUBLE') + 1, 1)", BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue());
assertExpr("round(CAST(minLong, 'DOUBLE') - 1, -2)", BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue());
}

@Test
public void testRoundWithInvalidFirstArgument()
{
Set<Pair<String, String>> invalidArguments = ImmutableSet.of(
Pair.of("b", "LONG_ARRAY"),
Pair.of("x", "STRING"),
Pair.of("c", "DOUBLE_ARRAY"),
Pair.of("a", "STRING_ARRAY")

);
for (Pair<String, String> argAndType : invalidArguments) {
try {
assertExpr(String.format(Locale.ENGLISH, "round(%s)", argAndType.lhs), null);
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals(
String.format(
Locale.ENGLISH,
"The first argument to the function[round] should be integer or double type but got the type: %s",
argAndType.rhs
),
e.getMessage()
);
}
}
}

@Test
public void testRoundWithInvalidSecondArgument()
{
Set<Pair<String, String>> invalidArguments = ImmutableSet.of(
Pair.of("1.2", "DOUBLE"),
Pair.of("x", "STRING"),
Pair.of("a", "STRING_ARRAY"),
Pair.of("c", "DOUBLE_ARRAY")

);
for (Pair<String, String> argAndType : invalidArguments) {
try {
assertExpr(String.format(Locale.ENGLISH, "round(d, %s)", argAndType.lhs), null);
Assert.fail("Did not throw IllegalArgumentException");
}
catch (IllegalArgumentException e) {
Assert.assertEquals(
String.format(
Locale.ENGLISH,
"The second argument to the function[round] should be integer type but got the type: %s",
argAndType.rhs
),
e.getMessage()
);
}
}
}

@Test
public void testGreatest()
{
Expand Down
2 changes: 1 addition & 1 deletion docs/misc/math-expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
|pow|pow(x, y) would return the value of the x raised to the power of y|
|remainder|remainder(x, y) would return the remainder operation on two arguments as prescribed by the IEEE 754 standard|
|rint|rint(x) would return value that is closest in value to x and is equal to a mathematical integer|
|round|round(x, y) would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.|
|round|round(x, y) would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points. If x is `NaN`, x will return 0. If x is infinity, x will be converted to the nearest finite double. |
|scalb|scalb(d, sf) would return d * 2^sf rounded as if performed by a single correctly rounded floating-point multiply to a member of the double value set|
|signum|signum(x) would return the signum function of the argument x|
|sin|sin(x) would return the trigonometric sine of an angle x|
Expand Down
2 changes: 1 addition & 1 deletion docs/querying/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ to FLOAT. At runtime, Druid will widen 32-bit floats to 64-bit for most expressi
|`SQRT(expr)`|Square root.|
|`TRUNCATE(expr[, digits])`|Truncate expr to a specific number of decimal digits. If digits is negative, then this truncates that many places to the left of the decimal point. Digits defaults to zero if not specified.|
|`TRUNC(expr[, digits])`|Synonym for `TRUNCATE`.|
|`ROUND(expr[, digits])`|`ROUND(x, y)` would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.|
|`ROUND(expr[, digits])`|`ROUND(x, y)` would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points. If `expr` evaluates to either `NaN`, `expr` will be converted to 0. If `expr` is infinity, `expr` will be converted to the nearest finite double. |
|`x + y`|Addition.|
|`x - y`|Subtraction.|
|`x * y`|Multiplication.|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.druid.query.groupby.having;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
import org.apache.druid.java.util.common.ISE;
Expand Down Expand Up @@ -89,10 +90,16 @@ static int compare(String aggregationName, Number value, Map<String, AggregatorF
}
}

private static int compareDoubleToLong(final double a, final long b)
@VisibleForTesting
static int compareDoubleToLong(final double a, final long b)
{
// Use BigDecimal when comparing integers vs floating points, a convenient way to handle all cases (like
// fractional values, values out of range of max long/max int) without worrying about them ourselves.
// The only edge case we need to handle is doubles that can not be converted to a BigDecimal, so fall back to using
// Double.compare
if (Double.isNaN(a) || Double.isInfinite(a)) {
return Double.compare(a, b);
}
return BigDecimal.valueOf(a).compareTo(BigDecimal.valueOf(b));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,64 @@ public void testLongRegex()
Assert.assertFalse(HavingSpecMetricComparator.LONG_PAT.matcher("").matches());
Assert.assertFalse(HavingSpecMetricComparator.LONG_PAT.matcher("xyz").matches());
}

@Test
public void testCompareDoubleToLongWithNanReturns1()
{
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.NaN, 1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.NaN, -1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.NaN, Long.MAX_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.NaN, Long.MIN_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.NaN, 0L));

Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.NaN, 1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.NaN, -1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.NaN, Long.MAX_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.NaN, Long.MIN_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.NaN, 0L));
}

@Test
public void testCompareDoubleToLongWithInfinityReturns1()
{
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.POSITIVE_INFINITY, 1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.POSITIVE_INFINITY, -1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.POSITIVE_INFINITY, Long.MAX_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.POSITIVE_INFINITY, Long.MIN_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Double.POSITIVE_INFINITY, 0L));

Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.POSITIVE_INFINITY, 1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.POSITIVE_INFINITY, -1));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.POSITIVE_INFINITY, Long.MAX_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.POSITIVE_INFINITY, Long.MIN_VALUE));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(Float.POSITIVE_INFINITY, 0L));
}

@Test
public void testCompareDoubleToLongWithInfinityReturnsNegative1()
{
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Double.NEGATIVE_INFINITY, 1));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Double.NEGATIVE_INFINITY, -1));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Double.NEGATIVE_INFINITY, Long.MAX_VALUE));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Double.NEGATIVE_INFINITY, Long.MIN_VALUE));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Double.NEGATIVE_INFINITY, 0L));

Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Float.NEGATIVE_INFINITY, 1));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Float.NEGATIVE_INFINITY, -1));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Float.NEGATIVE_INFINITY, Long.MAX_VALUE));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Float.NEGATIVE_INFINITY, Long.MIN_VALUE));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(Float.NEGATIVE_INFINITY, 0L));
}

@Test
public void testCompareDoubleToLongWithNumbers()
{
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong(1 + 1e-6, 1));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong(1 - 1e-6, 1));
Assert.assertEquals(0, HavingSpecMetricComparator.compareDoubleToLong(10D, 10));
Assert.assertEquals(0, HavingSpecMetricComparator.compareDoubleToLong(0D, 0));
Assert.assertEquals(0, HavingSpecMetricComparator.compareDoubleToLong(-0D, 0));
Assert.assertEquals(1, HavingSpecMetricComparator.compareDoubleToLong((double) Long.MAX_VALUE + 1, Long.MAX_VALUE));
Assert.assertEquals(-1, HavingSpecMetricComparator.compareDoubleToLong((double) Long.MIN_VALUE - 1, Long.MIN_VALUE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,13 @@ public void reduce(
} else {
if (exprResult.type() == ExprType.LONG) {
bigDecimal = BigDecimal.valueOf(exprResult.asLong());

} else {
// if exprResult evaluates to Nan or infinity, this will throw a NumberFormatException.
// If you find yourself in such a position, consider casting the literal to a BIGINT so that
// the query can execute.
bigDecimal = BigDecimal.valueOf(exprResult.asDouble());
}

literal = rexBuilder.makeLiteral(bigDecimal, constExp.getType(), true);
}
} else if (sqlTypeName == SqlTypeName.ARRAY) {
Expand Down
Loading