diff --git a/extensions-contrib/spectator-histogram/src/main/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramPercentileSqlAggregator.java b/extensions-contrib/spectator-histogram/src/main/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramPercentileSqlAggregator.java index 987700ebd1a6..0e3931706c14 100644 --- a/extensions-contrib/spectator-histogram/src/main/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramPercentileSqlAggregator.java +++ b/extensions-contrib/spectator-histogram/src/main/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramPercentileSqlAggregator.java @@ -42,6 +42,7 @@ import org.apache.druid.sql.calcite.aggregation.Aggregations; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; @@ -123,7 +124,8 @@ private Aggregation handleSinglePercentile( final List existingAggregations ) { - final double percentile = ((Number) RexLiteral.value(percentileArg)).doubleValue(); + final Object value = RexLiteral.value(percentileArg); + final double percentile = DruidSqlParserUtils.getNumericLiteral(value, NAME, "percentile").doubleValue(); final String histogramName = StringUtils.format("%s:agg", name); diff --git a/extensions-contrib/spectator-histogram/src/test/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramSqlAggregatorTest.java b/extensions-contrib/spectator-histogram/src/test/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramSqlAggregatorTest.java index 1d2d68092bab..a0d368d2d77c 100644 --- a/extensions-contrib/spectator-histogram/src/test/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramSqlAggregatorTest.java +++ b/extensions-contrib/spectator-histogram/src/test/java/org/apache/druid/spectator/histogram/sql/SpectatorHistogramSqlAggregatorTest.java @@ -29,6 +29,7 @@ import org.apache.druid.data.input.impl.StringDimensionSchema; import org.apache.druid.data.input.impl.TimeAndDimsParseSpec; import org.apache.druid.data.input.impl.TimestampSpec; +import org.apache.druid.error.DruidException; import org.apache.druid.initialization.DruidModule; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.query.Druids; @@ -57,6 +58,7 @@ import org.apache.druid.sql.calcite.util.SqlTestFramework.StandardComponentSupplier; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; +import org.junit.Assert; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -551,4 +553,21 @@ public void testSpectatorFunctionsOnNullHistogram() ImmutableList.of(new Object[]{null, null, null}) ); } + + @Test + public void testSpectatorPercentileWithStringLiteral() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT SPECTATOR_PERCENTILE(histogram_metric, '99.99') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("must be a numeric literal")); + } + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java index 7d303d272740..442374501786 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctSqlAggregator.java @@ -110,4 +110,10 @@ protected Aggregation toAggregation( finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorFactory.getName()) : null ); } + + @Override + protected String getName() + { + return NAME; + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctUtf8SqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctUtf8SqlAggregator.java index 070fbd9f7337..0b82ad94c963 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctUtf8SqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchApproxCountDistinctUtf8SqlAggregator.java @@ -81,4 +81,10 @@ protected Aggregation toAggregation( finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorFactory.getName()) : null ); } + + @Override + protected String getName() + { + return NAME; + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java index 15221c0f6f81..b231e0504a02 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchBaseSqlAggregator.java @@ -40,6 +40,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; @@ -92,7 +93,7 @@ public Aggregation toDruidAggregation( return null; } - logK = ((Number) RexLiteral.value(logKarg)).intValue(); + logK = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(logKarg), getName(), "logK").intValue(); } else { logK = HllSketchAggregatorFactory.DEFAULT_LG_K; } @@ -204,6 +205,8 @@ protected abstract Aggregation toAggregation( AggregatorFactory aggregatorFactory ); + protected abstract String getName(); + private boolean isValidComplexInputType(ColumnType columnType) { return HllSketchMergeAggregatorFactory.TYPE.equals(columnType) || diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java index 0e466bcaf0a0..3a3a35430537 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchObjectSqlAggregator.java @@ -72,4 +72,10 @@ protected Aggregation toAggregation( null ); } + + @Override + protected String getName() + { + return NAME; + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java index 08c7a1b123fd..a6604eb81f4f 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java @@ -40,6 +40,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; @@ -98,7 +99,7 @@ public Aggregation toDruidAggregation( return null; } - final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue(); + final float probability = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(probabilityArg), NAME, "probability").floatValue(); final int k; if (aggregateCall.getArgList().size() >= 3) { @@ -109,7 +110,7 @@ public Aggregation toDruidAggregation( return null; } - k = ((Number) RexLiteral.value(resolutionArg)).intValue(); + k = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(resolutionArg), NAME, "k").intValue(); } else { k = DoublesSketchAggregatorFactory.DEFAULT_K; } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java index 5ecd289c7287..e36ff03084d4 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchApproxCountDistinctSqlAggregator.java @@ -100,4 +100,10 @@ protected Aggregation toAggregation( ) : null ); } + + @Override + protected String getName() + { + return NAME; + } } diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java index b5c30cec7403..b62ffc9f23ce 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchBaseSqlAggregator.java @@ -38,6 +38,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; @@ -90,7 +91,7 @@ public Aggregation toDruidAggregation( return null; } - sketchSize = ((Number) RexLiteral.value(sketchSizeArg)).intValue(); + sketchSize = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(sketchSizeArg), getName(), "size").intValue(); } else { sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE; } @@ -173,6 +174,8 @@ protected abstract Aggregation toAggregation( AggregatorFactory aggregatorFactory ); + protected abstract String getName(); + private boolean isValidComplexInputType(ColumnType columnType) { return SketchModule.THETA_SKETCH_TYPE.equals(columnType) || diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java index ac9cefd5f9ba..c4832b94056a 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchObjectSqlAggregator.java @@ -67,4 +67,10 @@ protected Aggregation toAggregation( null ); } + + @Override + protected String getName() + { + return NAME; + } } diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java index 0ee832ffe0ae..907324c97ce1 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.apache.druid.error.DruidException; import org.apache.druid.initialization.DruidModule; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; @@ -71,6 +72,7 @@ import org.apache.druid.sql.calcite.util.TestDataBuilder; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.LinearShardSpec; +import org.junit.Assert; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -1111,6 +1113,40 @@ public void testSuccessWithSmallMaxStreamLength() ); } + @Test + public void testApproxQuantileWithStringLiteral() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT APPROX_QUANTILE_DS(m1, '0.99') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("Cannot apply 'APPROX_QUANTILE_DS'")); + } + } + + @Test + public void testApproxQuantileWithStringResolution() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT APPROX_QUANTILE_DS(m1, 0.99, '128') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("Cannot apply 'APPROX_QUANTILE_DS'")); + } + } + private static PostAggregator makeFieldAccessPostAgg(String name) { return new FieldAccessPostAggregator(name, name); diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java index 209fe3500f40..821ed27bc640 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregator.java @@ -39,6 +39,7 @@ import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -88,7 +89,7 @@ public Aggregation toDruidAggregation( return null; } - final int maxNumEntries = ((Number) RexLiteral.value(maxNumEntriesOperand)).intValue(); + final int maxNumEntries = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxNumEntriesOperand), NAME, "maxNumEntries").intValue(); // Look for existing matching aggregatorFactory. for (final Aggregation existing : existingAggregations) { diff --git a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java index 41df080147b4..5dc39247753f 100644 --- a/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java +++ b/extensions-core/histogram/src/main/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregator.java @@ -43,6 +43,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker; import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; @@ -91,7 +92,7 @@ public Aggregation toDruidAggregation( return null; } - final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue(); + final float probability = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(probabilityArg), NAME, "probability").floatValue(); final int resolution; if (aggregateCall.getArgList().size() >= 3) { @@ -102,7 +103,7 @@ public Aggregation toDruidAggregation( return null; } - resolution = ((Number) RexLiteral.value(resolutionArg)).intValue(); + resolution = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(resolutionArg), NAME, "resolution").intValue(); } else { resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java index d20999d3afc4..ba910f3844b9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArrayConcatSqlAggregator.java @@ -43,6 +43,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -83,7 +84,7 @@ public Aggregation toDruidAggregation( // maxBytes must be a literal return null; } - maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); + maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), NAME, "maxBytes").intValue(); } final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index 1045a79870bb..3edd65c9d23b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -45,6 +45,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -86,7 +87,7 @@ public Aggregation toDruidAggregation( // maxBytes must be a literal return null; } - maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); + maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), NAME, "maxBytes").intValue(); } final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0)); if (arg == null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java index 117738f72dda..55df7c76b2bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java @@ -50,6 +50,7 @@ import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.InputAccessor; @@ -69,8 +70,9 @@ public class StringSqlAggregator implements SqlAggregator { private final SqlAggFunction function; + private static final String NAME = "STRING_AGG"; - public static final StringSqlAggregator STRING_AGG = new StringSqlAggregator(new StringAggFunction("STRING_AGG")); + public static final StringSqlAggregator STRING_AGG = new StringSqlAggregator(new StringAggFunction(NAME)); public static final StringSqlAggregator LISTAGG = new StringSqlAggregator(new StringAggFunction("LISTAGG")); public StringSqlAggregator(SqlAggFunction function) @@ -130,7 +132,7 @@ public Aggregation toDruidAggregation( // maxBytes must be a literal return null; } - maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); + maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), NAME, "maxBytes").intValue(); } final DruidExpression arg = arguments.get(0); @@ -214,7 +216,7 @@ public RelDataType inferReturnType(SqlOperatorBinding sqlOperatorBinding) throw SimpleSqlAggregator.badTypeException( columnName, - "STRING_AGG", + NAME, ((RowSignatures.ComplexSqlType) type).getColumnType() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParserUtils.java b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParserUtils.java index 38909ec91877..cf5cfdfd5886 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParserUtils.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParserUtils.java @@ -665,4 +665,51 @@ public static DruidException problemParsing(String message) { return InvalidSqlInput.exception(message); } + + /** + * Creates a DruidException for invalid SQL function parameter types. + * + * @param functionName the SQL function name (e.g., "SPECTATOR_PERCENTILE") + * @param parameterName the parameter name + * @param expectedType the expected type + * @param actualValue the value provided needed to determine type + * @return DruidException with INVALID_INPUT category and USER persona + */ + public static DruidException invalidParameterTypeException( + String functionName, + String parameterName, + String expectedType, + @Nullable Object actualValue + ) + { + final String actualType = actualValue == null ? "NULL" : actualValue.getClass().getSimpleName(); + return InvalidSqlInput.exception( + "%s parameter `%s` must be a %s literal, got %s", + functionName, + parameterName, + expectedType, + actualType + ); + } + + /** + * Validates and returns a numeric value from a RexLiteral, or throws invalidParameterTypeException if invalid. + * + * @param value the value extracted from RexLiteral.value() + * @param functionName the SQL function name + * @param parameterName the parameter name + * @return the value as a Number + * @throws DruidException if value is not a Number + */ + public static Number getNumericLiteral( + @Nullable Object value, + String functionName, + String parameterName + ) + { + if (!(value instanceof Number)) { + throw invalidParameterTypeException(functionName, parameterName, "numeric", value); + } + return (Number) value; + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 14ea8ddb4db3..120d2f6ab35a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -14377,6 +14377,57 @@ public void testStringAggMaxBytes() ); } + @Test + public void testStringAggWithStringMaxBytes() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT STRING_AGG(dim1, ',', 'abc') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal")); + } + } + + @Test + public void testArrayAggWithStringMaxBytes() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT ARRAY_AGG(dim1, 'abc') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal")); + } + } + + @Test + public void testArrayConcatAggWithStringMaxBytes() + { + // verify invalid queries return 400 (user error) + final String query = "SELECT ARRAY_CONCAT_AGG(MV_TO_ARRAY(dim3), 'abc') FROM foo"; + + try { + testQuery(query, ImmutableList.of(), ImmutableList.of()); + Assert.fail("Expected DruidException but query succeeded"); + } + catch (DruidException e) { + Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); + Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory()); + Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal")); + } + } + /** * see {@link TestDataBuilder#RAW_ROWS1_WITH_NUMERIC_DIMS} * for the input data source of this test