diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index af52555793b7..a20863929875 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -694,6 +694,11 @@ protected ExprEval eval(double param) class Round implements Function { + //CHECKSTYLE.OFF: Regexp + private static final BigDecimal MAX_FINITE_VALUE = BigDecimal.valueOf(Double.MAX_VALUE); + private static final BigDecimal MIN_FINITE_VALUE = BigDecimal.valueOf(-1 * Double.MAX_VALUE); + //CHECKSTYLE.ON: Regexp + @Override public String name() { @@ -705,7 +710,11 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) { ExprEval value1 = args.get(0).eval(bindings); if (value1.type() != ExprType.LONG && value1.type() != ExprType.DOUBLE) { - throw new IAE("The first argument to the function[%s] should be integer or double type but get the %s type", name(), value1.type()); + throw new IAE( + "The first argument to the function[%s] should be integer or double type but got the type: %s", + name(), + value1.type() + ); } if (args.size() == 1) { @@ -713,7 +722,11 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } else { ExprEval value2 = args.get(1).eval(bindings); if (value2.type() != ExprType.LONG) { - throw new IAE("The second argument to the function[%s] should be integer type but get the %s type", name(), value2.type()); + throw new IAE( + "The second argument to the function[%s] should be integer type but got the type: %s", + name(), + value2.type() + ); } return eval(value1, value2.asInt()); } @@ -737,11 +750,27 @@ private ExprEval eval(ExprEval param, int scale) if (param.type() == ExprType.LONG) { return ExprEval.of(BigDecimal.valueOf(param.asLong()).setScale(scale, RoundingMode.HALF_UP).longValue()); } else if (param.type() == ExprType.DOUBLE) { - return ExprEval.of(BigDecimal.valueOf(param.asDouble()).setScale(scale, RoundingMode.HALF_UP).doubleValue()); + BigDecimal decimal = safeGetFromDouble(param.asDouble()); + return ExprEval.of(decimal.setScale(scale, RoundingMode.HALF_UP).doubleValue()); } else { return ExprEval.of(null); } } + + /** + * Converts non-finite doubles to BigDecimal values instead of throwing a NumberFormatException. + */ + private static BigDecimal safeGetFromDouble(double val) + { + if (Double.isNaN(val)) { + return BigDecimal.ZERO; + } else if (val == Double.POSITIVE_INFINITY) { + return MAX_FINITE_VALUE; + } else if (val == Double.NEGATIVE_INFINITY) { + return MIN_FINITE_VALUE; + } + return BigDecimal.valueOf(val); + } } class Signum extends UnivariateMathFunction diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index fc83b4f7dcdf..bd755ba7e0ec 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -20,13 +20,19 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.Locale; +import java.util.Set; public class FunctionTest extends InitializedNullHandlingTest { @@ -35,13 +41,23 @@ public class FunctionTest extends InitializedNullHandlingTest @Before public void setup() { - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("x", "foo"); - builder.put("y", 2); - builder.put("z", 3.1); - builder.put("a", new String[] {"foo", "bar", "baz", "foobar"}); - builder.put("b", new Long[] {1L, 2L, 3L, 4L, 5L}); - builder.put("c", new Double[] {3.1, 4.2, 5.3}); + ImmutableMap.Builder builder = ImmutableMap.builder() + .put("x", "foo") + .put("y", 2) + .put("z", 3.1) + .put("d", 34.56D) + .put("maxLong", Long.MAX_VALUE) + .put("minLong", Long.MIN_VALUE) + .put("f", 12.34F) + .put("nan", Double.NaN) + .put("inf", Double.POSITIVE_INFINITY) + .put("-inf", Double.NEGATIVE_INFINITY) + .put("o", 0) + .put("od", 0D) + .put("of", 0F) + .put("a", new String[] {"foo", "bar", "baz", "foobar"}) + .put("b", new Long[] {1L, 2L, 3L, 4L, 5L}) + .put("c", new Double[] {3.1, 4.2, 5.3}); bindings = Parser.withMap(builder.build()); } @@ -320,6 +336,133 @@ public void testArrayPrepend() assertArrayExpr("array_prepend(1, [])", new Double[]{1.0}); } + @Test + public void testRoundWithNonNumericValuesShouldReturn0() + { + assertExpr("round(nan)", 0D); + assertExpr("round(nan, 5)", 0D); + //CHECKSTYLE.OFF: Regexp + assertExpr("round(inf)", Double.MAX_VALUE); + assertExpr("round(inf, 4)", Double.MAX_VALUE); + assertExpr("round(-inf)", -1 * Double.MAX_VALUE); + assertExpr("round(-inf, 3)", -1 * Double.MAX_VALUE); + assertExpr("round(-inf, -5)", -1 * Double.MAX_VALUE); + //CHECKSTYLE.ON: Regexp + + // Calculations that result in non numeric numbers + assertExpr("round(0/od)", 0D); + assertExpr("round(od/od)", 0D); + //CHECKSTYLE.OFF: Regexp + assertExpr("round(1/od)", Double.MAX_VALUE); + assertExpr("round(-1/od)", -1 * Double.MAX_VALUE); + //CHECKSTYLE.ON: Regexp + + assertExpr("round(0/of)", 0D); + assertExpr("round(of/of)", 0D); + //CHECKSTYLE.OFF: Regexp + assertExpr("round(1/of)", Double.MAX_VALUE); + assertExpr("round(-1/of)", -1 * Double.MAX_VALUE); + //CHECKSTYLE.ON: Regexp + } + + @Test + public void testRoundWithLong() + { + assertExpr("round(y)", 2L); + assertExpr("round(y, 2)", 2L); + assertExpr("round(y, -1)", 0L); + } + + @Test + public void testRoundWithDouble() + { + assertExpr("round(d)", 35D); + assertExpr("round(d, 2)", 34.56D); + assertExpr("round(d, y)", 34.56D); + assertExpr("round(d, 1)", 34.6D); + assertExpr("round(d, -1)", 30D); + } + + @Test + public void testRoundWithFloat() + { + assertExpr("round(f)", 12D); + assertExpr("round(f, 2)", 12.34D); + assertExpr("round(f, y)", 12.34D); + assertExpr("round(f, 1)", 12.3D); + assertExpr("round(f, -1)", 10D); + } + + @Test + public void testRoundWithExtremeNumbers() + { + assertExpr("round(maxLong)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(0, RoundingMode.HALF_UP).longValue()); + assertExpr("round(minLong)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(0, RoundingMode.HALF_UP).longValue()); + // overflow + assertExpr("round(maxLong + 1, 1)", BigDecimal.valueOf(Long.MIN_VALUE).setScale(1, RoundingMode.HALF_UP).longValue()); + // underflow + assertExpr("round(minLong - 1, -2)", BigDecimal.valueOf(Long.MAX_VALUE).setScale(-2, RoundingMode.HALF_UP).longValue()); + + assertExpr("round(CAST(maxLong, 'DOUBLE') + 1, 1)", BigDecimal.valueOf(((double) Long.MAX_VALUE) + 1).setScale(1, RoundingMode.HALF_UP).doubleValue()); + assertExpr("round(CAST(minLong, 'DOUBLE') - 1, -2)", BigDecimal.valueOf(((double) Long.MIN_VALUE) - 1).setScale(-2, RoundingMode.HALF_UP).doubleValue()); + } + + @Test + public void testRoundWithInvalidFirstArgument() + { + Set> invalidArguments = ImmutableSet.of( + Pair.of("b", "LONG_ARRAY"), + Pair.of("x", "STRING"), + Pair.of("c", "DOUBLE_ARRAY"), + Pair.of("a", "STRING_ARRAY") + + ); + for (Pair argAndType : invalidArguments) { + try { + assertExpr(String.format(Locale.ENGLISH, "round(%s)", argAndType.lhs), null); + Assert.fail("Did not throw IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + Assert.assertEquals( + String.format( + Locale.ENGLISH, + "The first argument to the function[round] should be integer or double type but got the type: %s", + argAndType.rhs + ), + e.getMessage() + ); + } + } + } + + @Test + public void testRoundWithInvalidSecondArgument() + { + Set> invalidArguments = ImmutableSet.of( + Pair.of("1.2", "DOUBLE"), + Pair.of("x", "STRING"), + Pair.of("a", "STRING_ARRAY"), + Pair.of("c", "DOUBLE_ARRAY") + + ); + for (Pair argAndType : invalidArguments) { + try { + assertExpr(String.format(Locale.ENGLISH, "round(d, %s)", argAndType.lhs), null); + Assert.fail("Did not throw IllegalArgumentException"); + } + catch (IllegalArgumentException e) { + Assert.assertEquals( + String.format( + Locale.ENGLISH, + "The second argument to the function[round] should be integer type but got the type: %s", + argAndType.rhs + ), + e.getMessage() + ); + } + } + } + @Test public void testGreatest() { diff --git a/docs/misc/math-expr.md b/docs/misc/math-expr.md index 709715ac9c1e..dc356479ad58 100644 --- a/docs/misc/math-expr.md +++ b/docs/misc/math-expr.md @@ -141,7 +141,7 @@ See javadoc of java.lang.Math for detailed explanation for each function. |pow|pow(x, y) would return the value of the x raised to the power of y| |remainder|remainder(x, y) would return the remainder operation on two arguments as prescribed by the IEEE 754 standard| |rint|rint(x) would return value that is closest in value to x and is equal to a mathematical integer| -|round|round(x, y) would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.| +|round|round(x, y) would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points. If x is `NaN`, x will return 0. If x is infinity, x will be converted to the nearest finite double. | |scalb|scalb(d, sf) would return d * 2^sf rounded as if performed by a single correctly rounded floating-point multiply to a member of the double value set| |signum|signum(x) would return the signum function of the argument x| |sin|sin(x) would return the trigonometric sine of an angle x| diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 1f6e7e3ab716..acbc3ae0e726 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -287,7 +287,7 @@ to FLOAT. At runtime, Druid will widen 32-bit floats to 64-bit for most expressi |`SQRT(expr)`|Square root.| |`TRUNCATE(expr[, digits])`|Truncate expr to a specific number of decimal digits. If digits is negative, then this truncates that many places to the left of the decimal point. Digits defaults to zero if not specified.| |`TRUNC(expr[, digits])`|Synonym for `TRUNCATE`.| -|`ROUND(expr[, digits])`|`ROUND(x, y)` would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points.| +|`ROUND(expr[, digits])`|`ROUND(x, y)` would return the value of the x rounded to the y decimal places. While x can be an integer or floating-point number, y must be an integer. The type of the return value is specified by that of x. y defaults to 0 if omitted. When y is negative, x is rounded on the left side of the y decimal points. If `expr` evaluates to either `NaN`, `expr` will be converted to 0. If `expr` is infinity, `expr` will be converted to the nearest finite double. | |`x + y`|Addition.| |`x - y`|Subtraction.| |`x * y`|Multiplication.| diff --git a/processing/src/main/java/org/apache/druid/query/groupby/having/HavingSpecMetricComparator.java b/processing/src/main/java/org/apache/druid/query/groupby/having/HavingSpecMetricComparator.java index 88f50efb8488..e3f786e1f8aa 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/having/HavingSpecMetricComparator.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/having/HavingSpecMetricComparator.java @@ -19,6 +19,7 @@ package org.apache.druid.query.groupby.having; +import com.google.common.annotations.VisibleForTesting; import com.google.common.primitives.Doubles; import com.google.common.primitives.Longs; import org.apache.druid.java.util.common.ISE; @@ -89,10 +90,16 @@ static int compare(String aggregationName, Number value, Map