From 6f5350d7fe5b4e873a5392e124b3640302321804 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Tue, 8 Jun 2021 17:30:10 -0700 Subject: [PATCH 01/17] support distinct count aggregation Signed-off-by: chloe-zh --- .../sql/analysis/ExpressionAnalyzer.java | 1 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 4 ++++ .../sql/ast/expression/AggregateFunction.java | 14 +++++++++++ .../org/opensearch/sql/expression/DSL.java | 20 ++++++++++++++++ .../aggregation/AggregationState.java | 3 +++ .../expression/aggregation/Aggregator.java | 18 +++++++++++++- .../expression/aggregation/AvgAggregator.java | 6 +++++ .../aggregation/CountAggregator.java | 16 ++++++++++++- .../expression/aggregation/MaxAggregator.java | 6 +++++ .../expression/aggregation/MinAggregator.java | 6 +++++ .../expression/aggregation/SumAggregator.java | 6 +++++ .../sql/analysis/ExpressionAnalyzerTest.java | 8 +++++++ .../aggregation/AggregationTest.java | 7 ++++++ .../aggregation/AvgAggregatorTest.java | 7 ++++++ .../aggregation/CountAggregatorTest.java | 7 ++++++ .../aggregation/MaxAggregatorTest.java | 7 ++++++ .../aggregation/MinAggregatorTest.java | 7 ++++++ .../aggregation/SumAggregatorTest.java | 7 ++++++ .../correctness/queries/aggregation.txt | 3 ++- .../dsl/MetricAggregationBuilder.java | 24 +++++++++++++++++++ .../dsl/MetricAggregationBuilderTest.java | 17 +++++++++++++ sql/src/main/antlr/OpenSearchSQLParser.g4 | 7 ++++-- .../sql/sql/parser/AstExpressionBuilder.java | 9 +++++++ .../sql/parser/AstAggregationBuilderTest.java | 14 +++++++++++ 24 files changed, 219 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0f207c03741..3cc1dc95278 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -160,6 +160,7 @@ public Expression 
visitAggregateFunction(AggregateFunction node, AnalysisContext Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( builtinFunctionName.get().getName(), Collections.singletonList(arg)); + aggregator.distinct(node.getDistinct()); if (node.getCondition() != null) { aggregator.condition(analyze(node.getCondition(), context)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 7400ae20e6f..be8f7095db5 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -214,6 +214,10 @@ public static UnresolvedExpression filteredAggregate( return new AggregateFunction(func, field, condition); } + public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) { + return new AggregateFunction(func, field, true); + } + public static Function function(String funcName, UnresolvedExpression... funcArgs) { return new Function(funcName, Arrays.asList(funcArgs)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 8753e35ed9f..d11fdca3ac0 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -46,6 +46,7 @@ public class AggregateFunction extends UnresolvedExpression { private final UnresolvedExpression field; private final List argList; private UnresolvedExpression condition; + private Boolean distinct = false; /** * Constructor. @@ -72,6 +73,19 @@ public AggregateFunction(String funcName, UnresolvedExpression field, this.condition = condition; } + /** + * Constructor. + * @param funcName function name. + * @param field {@link UnresolvedExpression}. + * @param distinct field is distinct. 
+ */ + public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { + this.funcName = funcName; + this.field = field; + this.argList = Collections.emptyList(); + this.distinct = distinct; + } + @Override public List getChild() { return Collections.singletonList(field); diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 31050afc871..93f86ca1f82 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -492,14 +492,26 @@ public Aggregator avg(Expression... expressions) { return aggregate(BuiltinFunctionName.AVG, expressions); } + public Aggregator distinctAvg(Expression... expressions) { + return avg(expressions).distinct(true); + } + public Aggregator sum(Expression... expressions) { return aggregate(BuiltinFunctionName.SUM, expressions); } + public Aggregator distinctSum(Expression... expressions) { + return sum(expressions).distinct(true); + } + public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator distinctCount(Expression... expressions) { + return count(expressions).distinct(true); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); @@ -519,10 +531,18 @@ public Aggregator min(Expression... expressions) { return aggregate(BuiltinFunctionName.MIN, expressions); } + public Aggregator distinctMin(Expression... expressions) { + return min(expressions).distinct(true); + } + public Aggregator max(Expression... expressions) { return aggregate(BuiltinFunctionName.MAX, expressions); } + public Aggregator distinctMax(Expression... expressions) { + return max(expressions).distinct(true); + } + private FunctionExpression function(BuiltinFunctionName functionName, Expression... 
expressions) { return (FunctionExpression) repository.compile( functionName.getName(), Arrays.asList(expressions)); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index b1c29cb4a7a..ed3eca77513 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -26,6 +26,7 @@ package org.opensearch.sql.expression.aggregation; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.storage.bindingtuple.BindingTuple; @@ -37,4 +38,6 @@ public interface AggregationState { * Get {@link ExprValue} result. */ ExprValue result(); + + Set distinctSet(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index 80944172ea1..1e9af97e8b9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -64,6 +64,12 @@ public abstract class Aggregator @Getter @Accessors(fluent = true) protected Expression condition; + @Setter + @Getter + @Accessors(fluent = true) + protected Boolean distinct = false; + + /** * Create an {@link AggregationState} which will be used for aggregation. 
@@ -89,7 +95,8 @@ public abstract class Aggregator */ public S iterate(BindingTuple tuple, S state) { ExprValue value = getArguments().get(0).valueOf(tuple); - if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { + if (value.isNull() || value.isMissing() || !conditionValue(tuple) + || (distinct && duplicated(value, state))) { return state; } return iterate(value, state); @@ -121,4 +128,13 @@ public boolean conditionValue(BindingTuple tuple) { return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); } + private Boolean duplicated(ExprValue value, S state) { + for (ExprValue exprValue : state.distinctSet()) { + if (exprValue.compareTo(value) == 0) { + return true; + } + } + return false; + } + } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index 0ec0a02a3c1..ca1c32c24f5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -80,5 +81,10 @@ protected static class AvgState implements AggregationState { public ExprValue result() { return count == 0 ? 
ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 3195bf39413..36f78c50765 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -28,8 +28,11 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.HashSet; import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -50,7 +53,7 @@ public CountAggregator.CountState create() { @Override protected CountState iterate(ExprValue value, CountState state) { - state.count++; + state.count(value); return state; } @@ -64,14 +67,25 @@ public String toString() { */ protected static class CountState implements AggregationState { private int count; + private final Set set = new HashSet<>(); CountState() { this.count = 0; } + public void count(ExprValue value) { + set.add(value); + count++; + } + @Override public ExprValue result() { return ExprValueUtils.integerValue(count); } + + @Override + public Set distinctSet() { + return set; + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java index 11ad63093db..9a1d31caad1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java @@ -30,6 +30,7 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; +import java.util.Set; 
import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -74,5 +75,10 @@ public void max(ExprValue value) { public ExprValue result() { return maxResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index 46f69129ed8..a40315c8c0a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -30,6 +30,7 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; +import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -79,5 +80,10 @@ public void min(ExprValue value) { public ExprValue result() { return minResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index e658d21471e..afdf61c5ad0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Locale; +import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -116,5 +117,10 @@ public void add(ExprValue value) { public ExprValue result() { return isEmptyCollection ? 
ExprNullValue.of() : sumResult; } + + @Override + public Set distinctSet() { + return Set.of(); + } } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index aa8d2b12dee..628842b4f09 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -292,6 +292,14 @@ public void aggregation_filter() { ); } + @Test + public void distinct_aggregation() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + AstDSL.distinctAggregate("count", qualifiedName("integer_value")) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index cc2825858a2..634a3a71920 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -116,6 +116,13 @@ public class AggregationTest extends ExpressionTestBase { "timestamp_value", "2040-01-01 07:00:00"))); + protected static List tuples_with_duplicates = + Arrays.asList( + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3))); + protected static List tuples_with_null_and_missing = Arrays.asList( ExprValueUtils.tupleValue( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java 
b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java index 494d3cfab2e..33ea4c91233 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -61,6 +61,13 @@ public void filtered_avg() { assertEquals(3.0, result.value()); } + @Test + public void distinct_avg() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctAvg(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator avg"); + } + @Test public void avg_with_missing() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 0fdadfc692c..26a53539ace 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -129,6 +129,13 @@ public void filtered_count() { assertEquals(3, result.value()); } + @Test + public void distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)), + tuples_with_duplicates); + assertEquals(3, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)), diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java index 5aa9d3a7473..20cde543141 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java @@ -116,6 +116,13 @@ public void filtered_max() { assertEquals(3, result.value()); } + @Test + public void distinct_max() 
{ + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctMax(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator max"); + } + @Test public void test_max_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java index 01e72b9cdac..e2927772621 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java @@ -116,6 +116,13 @@ public void filtered_min() { assertEquals(2, result.value()); } + @Test + public void distinct_min() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctMin(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator min"); + } + @Test public void test_min_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java index c0872ed4345..fdd24fb5b16 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java @@ -100,6 +100,13 @@ public void filtered_sum() { assertEquals(9, result.value()); } + @Test + public void distinct_sum() { + assertThrows(ExpressionEvaluationException.class, + () -> dsl.distinctSum(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), + "unsupported distinct aggregator sum"); + } + @Test public void sum_with_missing() { ExprValue result = diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 6c6e5b73a14..e7cd34451db 100644 --- 
a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -5,4 +5,5 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index f3807ae662f..fa06f01f518 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -35,6 +35,8 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; @@ -51,11 +53,15 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor { private final AggregationBuilderHelper> helper; + private 
final AggregationBuilderHelper cardinalityHelper; + private final AggregationBuilderHelper termsHelper; private final FilterQueryBuilder filterBuilder; public MetricAggregationBuilder( ExpressionSerializer serializer) { this.helper = new AggregationBuilderHelper<>(serializer); + this.cardinalityHelper = new AggregationBuilderHelper<>(serializer); + this.termsHelper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -78,8 +84,19 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node, Object context) { Expression expression = node.getArguments().get(0); Expression condition = node.getDelegated().condition(); + Boolean distinct = node.getDelegated().distinct(); String name = node.getName(); + if (distinct) { + switch (node.getFunctionName().getFunctionName()) { + case "count": + return make(AggregationBuilders.cardinality(name), expression); + default: + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + } + } + switch (node.getFunctionName().getFunctionName()) { case "avg": return make(AggregationBuilders.avg(name), expression, condition, name); @@ -108,6 +125,13 @@ private AggregationBuilder make(ValuesSourceAggregationBuilder builder, return aggregationBuilder; } + /** + * Make {@link CardinalityAggregationBuilder} for distinct count aggregations. + */ + private AggregationBuilder make(CardinalityAggregationBuilder builder, Expression expression) { + return cardinalityHelper.build(expression, builder::field, builder::script); + } + /** * Replace star or literal with OpenSearch metadata field "_index". 
Because: * 1) Analyzer already converts * to string literal, literal check here can handle diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index b956a2f5a07..c15cb152a38 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -32,12 +32,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; +import java.util.Collections; import java.util.List; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; @@ -185,6 +187,21 @@ void should_build_max_aggregation() { new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Collections.singletonList(named("count(distinct name)", new CountAggregator( + Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 
b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0ad08781bfe..51c558a68e9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -336,8 +336,11 @@ caseFuncAlternative ; aggregateFunction - : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall - | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET + #regularAggregateFunctionCall + | functionName=aggregationFunctionName LR_BRACKET DISTINCT functionArg RR_BRACKET + #distinctAggregateFunctionCall + | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index b1630aed509..d267a8df4fa 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -212,6 +212,15 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( visitFunctionArg(ctx.functionArg())); } + @Override + public UnresolvedExpression visitDistinctAggregateFunctionCall( + OpenSearchSQLParser.DistinctAggregateFunctionCallContext ctx) { + return new AggregateFunction( + ctx.functionName.getText(), + visitFunctionArg(ctx.functionArg()), + true); + } + @Override public UnresolvedExpression visitCountStarFunctionCall(CountStarFunctionCallContext ctx) { return new AggregateFunction("COUNT", AllFields.of()); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 1d9516f8162..8e7adaaf039 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -36,6 +36,7 
@@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -167,6 +168,19 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() { alias("AVG(age)", aggregate("AVG", qualifiedName("age")))))); } + @Test + void can_build_distinct_aggregator() { + assertThat( + buildAggregation("SELECT COUNT(DISTINCT name), AVG(DISTINCT balance) FROM test"), + allOf( + hasGroupByItems(), + hasAggregators( + alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( + "name"))), + alias("AVG(DISTINCT balance)", distinctAggregate("AVG", qualifiedName( + "balance")))))); + } + @Test void should_build_nothing_if_no_group_by_and_no_aggregators_in_select() { assertNull(buildAggregation("SELECT name FROM test")); From e30b685149e555985ddb09c89e521551c3a2c78c Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 9 Jun 2021 12:22:32 -0700 Subject: [PATCH 02/17] fixed tests Signed-off-by: chloe-zh --- .../aggregation/AggregationState.java | 4 ++- .../expression/aggregation/Aggregator.java | 2 +- .../expression/aggregation/AvgAggregator.java | 6 ---- .../aggregation/CountAggregator.java | 3 +- .../expression/aggregation/MaxAggregator.java | 6 ---- .../expression/aggregation/MinAggregator.java | 5 --- .../expression/aggregation/SumAggregator.java | 5 --- .../aggregation/AggregatorStateTest.java | 35 +++++++++++++++++++ .../dsl/MetricAggregationBuilder.java | 3 +- .../dsl/MetricAggregationBuilderTest.java | 9 +++++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 + .../sql/ppl/parser/AstExpressionBuilder.java | 7 ++++ .../ppl/parser/AstExpressionBuilderTest.java | 30 ++++++++++++++++ 13 files changed, 89 
insertions(+), 27 deletions(-) create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index ed3eca77513..378490e7663 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -39,5 +39,7 @@ public interface AggregationState { */ ExprValue result(); - Set distinctSet(); + default Set distinctValues() { + return Set.of(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index 1e9af97e8b9..a0a8037751e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -129,7 +129,7 @@ public boolean conditionValue(BindingTuple tuple) { } private Boolean duplicated(ExprValue value, S state) { - for (ExprValue exprValue : state.distinctSet()) { + for (ExprValue exprValue : state.distinctValues()) { if (exprValue.compareTo(value) == 0) { return true; } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index ca1c32c24f5..0ec0a02a3c1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Locale; -import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -81,10 +80,5 @@ 
protected static class AvgState implements AggregationState { public ExprValue result() { return count == 0 ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 36f78c50765..975a39a8cce 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -28,7 +28,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; - import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -84,7 +83,7 @@ public ExprValue result() { } @Override - public Set distinctSet() { + public Set distinctValues() { return set; } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java index 9a1d31caad1..11ad63093db 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MaxAggregator.java @@ -30,7 +30,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; @@ -75,10 +74,5 @@ public void max(ExprValue value) { public ExprValue result() { return maxResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index a40315c8c0a..e9672475bca 100644 --- 
a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -80,10 +80,5 @@ public void min(ExprValue value) { public ExprValue result() { return minResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index afdf61c5ad0..8de5ffb7a2d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -117,10 +117,5 @@ public void add(ExprValue value) { public ExprValue result() { return isEmptyCollection ? ExprNullValue.of() : sumResult; } - - @Override - public Set distinctSet() { - return Set.of(); - } } } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java new file mode 100644 index 00000000000..338a254148f --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ * + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.data.model.ExprIntegerValue; + +public class AggregatorStateTest extends AggregationTest { + + @Test + void count_distinct_values() { + CountAggregator.CountState state = new CountAggregator.CountState(); + state.count(new ExprIntegerValue(1)); + assertFalse(state.distinctValues().isEmpty()); + } + + @Test + void default_distinct_values() { + AvgAggregator.AvgState state = new AvgAggregator.AvgState(); + assertTrue(state.distinctValues().isEmpty()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index fa06f01f518..e3b5be881ce 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -38,6 +38,7 @@ import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.LiteralExpression; @@ -92,7 +93,7 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node, case "count": return make(AggregationBuilders.cardinality(name), expression); default: - throw new IllegalStateException(String.format( + throw new 
ExpressionEvaluationException(String.format( "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index c15cb152a38..bacd5413b9d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -49,6 +49,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; @@ -202,6 +203,14 @@ void should_build_cardinality_aggregation() { Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); } + @Test + void should_throw_exception_for_unsupported_distinct_aggregator() { + assertThrows(ExpressionEvaluationException.class, + () -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator( + Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))), + "unsupported distinct aggregator avg"); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 77aecf5a44e..e8b54dab4da 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -135,6 +135,7 @@ statsAggTerm 
statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall | COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression? RT_PRTHS #distinctCountFunctionCall | percentileAggFunction #percentileAggFunctionCall ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 9fdf8d636d5..ef314072760 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -35,6 +35,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; @@ -203,6 +204,12 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex return new AggregateFunction("count", AllFields.of()); } + @Override + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { + return new AggregateFunction("count", + ctx.valueExpression() != null ? 
visit(ctx.valueExpression()) : AllFields.of(), true); + } + @Override public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) { return new AggregateFunction(ctx.PERCENTILE().getText(), visit(ctx.aggField), diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 07ad97401e7..6bbfda7aef0 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -37,6 +37,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultSortFieldArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultStatsArgs; +import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.equalTo; import static org.opensearch.sql.ast.dsl.AstDSL.eval; @@ -376,6 +377,35 @@ public void testCountFuncCallExpr() { )); } + @Test + public void testDistinctCount() { + assertEqual("source=t | stats distinct_count(a)", + agg( + relation("t"), + exprList( + alias("distinct_count(a)", + distinctAggregate("count", field("a")))), + emptyList(), + emptyList(), + defaultStatsArgs())); + + assertEqual("source=t | stats dc() by b", + agg( + relation("t"), + exprList( + alias( + "dc()", + distinctAggregate("count", AllFields.of()) + ) + ), + emptyList(), + exprList( + alias("b", field("b")) + ), + defaultStatsArgs() + )); + } + @Test public void testEvalFuncCallExpr() { assertEqual("source=t | eval f=abs(a)", From 866d71d7837478d29c7b465094594234e1d5ed73 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 12:29:50 -0700 Subject: [PATCH 03/17] Merge remote-tracking branch 'upstream/develop' into issue/#100 Signed-off-by: chloe-zh # Conflicts: # 
opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java --- .../aggregation/dsl/MetricAggregationBuilder.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index aa116877dfa..a065b131967 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -99,7 +99,10 @@ public Pair visitNamedAggregator( if (distinct) { switch (node.getFunctionName().getFunctionName()) { case "count": - return make(AggregationBuilders.cardinality(name), expression); + return make( + AggregationBuilders.cardinality(name), + expression, + new SingleValueParser(name)); default: throw new ExpressionEvaluationException(String.format( "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); @@ -171,6 +174,12 @@ private AggregationBuilder make(CardinalityAggregationBuilder builder, Expressio return cardinalityHelper.build(expression, builder::field, builder::script); } + private Pair make(CardinalityAggregationBuilder builder, + Expression expression, + MetricParser parser) { + return Pair.of(cardinalityHelper.build(expression, builder::field, builder::script), parser); + } + /** * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 
2) From 8a6ca202fbe7a0e1b66fe94619fa899bd99caf7b Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 12:30:34 -0700 Subject: [PATCH 04/17] update Signed-off-by: chloe-zh --- .../script/aggregation/dsl/MetricAggregationBuilder.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index a065b131967..9b6883e4ad4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -170,10 +170,6 @@ private Pair make( /** * Make {@link CardinalityAggregationBuilder} for distinct count aggregations. */ - private AggregationBuilder make(CardinalityAggregationBuilder builder, Expression expression) { - return cardinalityHelper.build(expression, builder::field, builder::script); - } - private Pair make(CardinalityAggregationBuilder builder, Expression expression, MetricParser parser) { From 43cbd17ca065644984648f97401c0e7a3a788758 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 9 Jun 2021 13:12:54 -0700 Subject: [PATCH 05/17] updated user doc Signed-off-by: chloe-zh --- docs/user/dql/aggregations.rst | 13 +++++++++++++ docs/user/ppl/cmd/stats.rst | 15 +++++++++++++++ .../dsl/MetricAggregationBuilder.java | 18 +++++++++--------- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 98b565e1ecd..3c8577fcde7 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,6 +135,19 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. 
``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +DISTINCT Aggregation +-------------------- + +To get the aggregation of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the aggregation function. Currently the distinct aggregation is only supported in ``COUNT`` aggregation. Example:: + + os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts; + fetched rows / total rows = 1/1 + +--------------------------+-----------------+ + | COUNT(DISTINCT gender) | COUNT(gender) | + |--------------------------+-----------------| + | 2 | 4 | + +--------------------------+-----------------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 3aca304fcd7..8a51811689a 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -134,3 +134,18 @@ PPL query:: | 36 | 32 | M | +------------+------------+----------+ +Example 7: Calculate the distinct count of a field +================================================== + +To get the count of distinct values of a field, you can use ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of gender field of all the accounts. 
+ +PPL query:: + + os> source=accounts | stats count(gender), distinct_count(gender); + fetched rows / total rows = 1/1 + +-----------------+--------------------------+ + | count(gender) | distinct_count(gender) | + |-----------------+--------------------------| + | 4 | 2 | + +-----------------+--------------------------+ + diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 9b6883e4ad4..84127d9a880 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -37,7 +37,6 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -58,15 +57,16 @@ public class MetricAggregationBuilder extends ExpressionNodeVisitor, Object> { - private final AggregationBuilderHelper> helper; - private final AggregationBuilderHelper cardinalityHelper; - private final AggregationBuilderHelper termsHelper; + private final AggregationBuilderHelper> valuesSourceAggHelper; + private final AggregationBuilderHelper cardinalityAggHelper; private final FilterQueryBuilder filterBuilder; + /** + * Constructor. 
+ */ public MetricAggregationBuilder(ExpressionSerializer serializer) { - this.helper = new AggregationBuilderHelper<>(serializer); - this.cardinalityHelper = new AggregationBuilderHelper<>(serializer); - this.termsHelper = new AggregationBuilderHelper<>(serializer); + this.valuesSourceAggHelper = new AggregationBuilderHelper<>(serializer); + this.cardinalityAggHelper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -158,7 +158,7 @@ private Pair make( String name, MetricParser parser) { ValuesSourceAggregationBuilder aggregationBuilder = - helper.build(expression, builder::field, builder::script); + valuesSourceAggHelper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( makeFilterAggregation(aggregationBuilder, condition, name), @@ -173,7 +173,7 @@ private Pair make( private Pair make(CardinalityAggregationBuilder builder, Expression expression, MetricParser parser) { - return Pair.of(cardinalityHelper.build(expression, builder::field, builder::script), parser); + return Pair.of(cardinalityAggHelper.build(expression, builder::field, builder::script), parser); } /** From 392c96ca0d4f1effd0b8b11dd266de7e2039c97a Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 21:09:14 -0700 Subject: [PATCH 06/17] Update: support only count for distinct aggregations Signed-off-by: chloe-zh --- .../sql/analysis/ExpressionAnalyzer.java | 4 +- .../org/opensearch/sql/ast/dsl/AstDSL.java | 7 +++- .../sql/ast/expression/AggregateFunction.java | 19 +++------ .../sql/data/model/ExprValueUtils.java | 8 ++++ .../aggregation/AggregationState.java | 5 --- .../expression/aggregation/Aggregator.java | 14 +------ .../aggregation/CountAggregator.java | 40 ++++++++++++++----- .../expression/aggregation/MinAggregator.java | 1 - .../expression/aggregation/SumAggregator.java | 1 - .../sql/analysis/ExpressionAnalyzerTest.java | 12 +++++- .../sql/data/model/ExprValueUtilsTest.java | 4 +- 
.../aggregation/AggregationTest.java | 8 ++-- .../aggregation/AggregatorStateTest.java | 35 ---------------- .../aggregation/CountAggregatorTest.java | 8 ++++ .../correctness/queries/aggregation.txt | 4 +- sql/src/main/antlr/OpenSearchSQLParser.g4 | 4 +- .../sql/sql/parser/AstExpressionBuilder.java | 10 ++--- .../sql/parser/AstAggregationBuilderTest.java | 16 +++++--- .../sql/parser/AstExpressionBuilderTest.java | 23 +++++++++++ 19 files changed, 121 insertions(+), 102 deletions(-) delete mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 3cc1dc95278..6de239bef15 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -161,8 +161,8 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext Aggregator aggregator = (Aggregator) repository.compile( builtinFunctionName.get().getName(), Collections.singletonList(arg)); aggregator.distinct(node.getDistinct()); - if (node.getCondition() != null) { - aggregator.condition(analyze(node.getCondition(), context)); + if (node.condition() != null) { + aggregator.condition(analyze(node.condition(), context)); } return aggregator; } else { diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index be8f7095db5..3b78483736c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -211,13 +211,18 @@ public static UnresolvedExpression aggregate( public static UnresolvedExpression filteredAggregate( String func, UnresolvedExpression field, UnresolvedExpression condition) { - return new AggregateFunction(func, field, condition); + return new 
AggregateFunction(func, field).condition(condition); } public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) { return new AggregateFunction(func, field, true); } + public static UnresolvedExpression filteredDistinctCount( + String func, UnresolvedExpression field, UnresolvedExpression condition) { + return new AggregateFunction(func, field, true).condition(condition); + } + public static Function function(String funcName, UnresolvedExpression... funcArgs) { return new Function(funcName, Arrays.asList(funcArgs)); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index d11fdca3ac0..96bd33f1c92 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -28,9 +28,12 @@ import java.util.Collections; import java.util.List; +import javax.annotation.Nullable; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.common.utils.StringUtils; @@ -45,6 +48,8 @@ public class AggregateFunction extends UnresolvedExpression { private final String funcName; private final UnresolvedExpression field; private final List argList; + @Setter + @Accessors(fluent = true) private UnresolvedExpression condition; private Boolean distinct = false; @@ -59,20 +64,6 @@ public AggregateFunction(String funcName, UnresolvedExpression field) { this.argList = Collections.emptyList(); } - /** - * Constructor. - * @param funcName function name. - * @param field {@link UnresolvedExpression}. - * @param condition condition in aggregation filter. 
- */ - public AggregateFunction(String funcName, UnresolvedExpression field, - UnresolvedExpression condition) { - this.funcName = funcName; - this.field = field; - this.argList = Collections.emptyList(); - this.condition = condition; - } - /** * Constructor. * @param funcName function name. diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java index e2c5fb6a39f..b2172e54f16 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) { } } + public static Byte getByteValue(ExprValue exprValue) { + return exprValue.byteValue(); + } + + public static Short getShortValue(ExprValue exprValue) { + return exprValue.shortValue(); + } + public static Integer getIntegerValue(ExprValue exprValue) { return exprValue.integerValue(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java index 378490e7663..b1c29cb4a7a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregationState.java @@ -26,7 +26,6 @@ package org.opensearch.sql.expression.aggregation; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.storage.bindingtuple.BindingTuple; @@ -38,8 +37,4 @@ public interface AggregationState { * Get {@link ExprValue} result. 
*/ ExprValue result(); - - default Set distinctValues() { - return Set.of(); - } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java index a0a8037751e..5328e11aadd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/Aggregator.java @@ -69,8 +69,6 @@ public abstract class Aggregator @Accessors(fluent = true) protected Boolean distinct = false; - - /** * Create an {@link AggregationState} which will be used for aggregation. */ @@ -95,8 +93,7 @@ public abstract class Aggregator */ public S iterate(BindingTuple tuple, S state) { ExprValue value = getArguments().get(0).valueOf(tuple); - if (value.isNull() || value.isMissing() || !conditionValue(tuple) - || (distinct && duplicated(value, state))) { + if (value.isNull() || value.isMissing() || !conditionValue(tuple)) { return state; } return iterate(value, state); @@ -128,13 +125,4 @@ public boolean conditionValue(BindingTuple tuple) { return ExprValueUtils.getBooleanValue(condition.valueOf(tuple)); } - private Boolean duplicated(ExprValue value, S state) { - for (ExprValue exprValue : state.distinctValues()) { - if (exprValue.compareTo(value) == 0) { - return true; - } - } - return false; - } - } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 975a39a8cce..34d064fe46d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -26,6 +26,16 @@ package org.opensearch.sql.expression.aggregation; +import static org.opensearch.sql.data.model.ExprValueUtils.getBooleanValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getByteValue; 
+import static org.opensearch.sql.data.model.ExprValueUtils.getCollectionValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getDoubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getFloatValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getIntegerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getLongValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getShortValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getStringValue; +import static org.opensearch.sql.data.model.ExprValueUtils.getTupleValue; import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.HashSet; @@ -52,7 +62,7 @@ public CountAggregator.CountState create() { @Override protected CountState iterate(ExprValue value, CountState state) { - state.count(value); + state.count(value, distinct); return state; } @@ -66,25 +76,35 @@ public String toString() { */ protected static class CountState implements AggregationState { private int count; - private final Set set = new HashSet<>(); + private final Set distinctValues = new HashSet<>(); CountState() { this.count = 0; } - public void count(ExprValue value) { - set.add(value); - count++; + public void count(ExprValue value, Boolean distinct) { + if (distinct) { + if (!duplicated(value)) { + distinctValues.add(value); + count++; + } + } else { + count++; + } } - @Override - public ExprValue result() { - return ExprValueUtils.integerValue(count); + private boolean duplicated(ExprValue value) { + for (ExprValue exprValue : distinctValues) { + if (value.compareTo(exprValue) == 0) { + return true; + } + } + return false; } @Override - public Set distinctValues() { - return set; + public ExprValue result() { + return ExprValueUtils.integerValue(count); } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java index 
e9672475bca..46f69129ed8 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/MinAggregator.java @@ -30,7 +30,6 @@ import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.List; -import java.util.Set; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java index 8de5ffb7a2d..e658d21471e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/SumAggregator.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Locale; -import java.util.Set; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 628842b4f09..06233fbc9b6 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -293,13 +293,23 @@ public void aggregation_filter() { } @Test - public void distinct_aggregation() { + public void distinct_count() { assertAnalyzeEqual( dsl.distinctCount(DSL.ref("integer_value", INTEGER)), AstDSL.distinctAggregate("count", qualifiedName("integer_value")) ); } + @Test + public void filtered_distinct_count() { + assertAnalyzeEqual( + dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), + AstDSL.filteredDistinctCount("count", 
qualifiedName("integer_value"), function( + ">", qualifiedName("integer_value"), intLiteral(1))) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java index a27d90f35d0..af2dbf22fc1 100644 --- a/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java +++ b/core/src/test/java/org/opensearch/sql/data/model/ExprValueUtilsTest.java @@ -96,8 +96,8 @@ public class ExprValueUtilsTest { Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues)); private static List> numberValueExtractor = Arrays.asList( - ExprValue::byteValue, - ExprValue::shortValue, + ExprValueUtils::getByteValue, + ExprValueUtils::getShortValue, ExprValueUtils::getIntegerValue, ExprValueUtils::getLongValue, ExprValueUtils::getFloatValue, diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java index 634a3a71920..2cce9018bf8 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregationTest.java @@ -118,10 +118,10 @@ public class AggregationTest extends ExpressionTestBase { protected static List tuples_with_duplicates = Arrays.asList( - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3))); + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 4d)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 3d)), + 
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, "double_value", 2d)), + ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, "double_value", 1d))); protected static List tuples_with_null_and_missing = Arrays.asList( diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java deleted file mode 100644 index 338a254148f..00000000000 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AggregatorStateTest.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - * - */ - -package org.opensearch.sql.expression.aggregation; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import org.junit.jupiter.api.Test; -import org.opensearch.sql.data.model.ExprIntegerValue; - -public class AggregatorStateTest extends AggregationTest { - - @Test - void count_distinct_values() { - CountAggregator.CountState state = new CountAggregator.CountState(); - state.count(new ExprIntegerValue(1)); - assertFalse(state.distinctValues().isEmpty()); - } - - @Test - void default_distinct_values() { - AvgAggregator.AvgState state = new AvgAggregator.AvgState(); - assertTrue(state.distinctValues().isEmpty()); - } -} diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java index 26a53539ace..73bb37a3daf 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java +++ 
b/core/src/test/java/org/opensearch/sql/expression/aggregation/CountAggregatorTest.java @@ -136,6 +136,14 @@ public void distinct_count() { assertEquals(3, result.value()); } + @Test + public void filtered_distinct_count() { + ExprValue result = aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)) + .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))), + tuples_with_duplicates); + assertEquals(2, result.value()); + } + @Test public void count_with_missing() { ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)), diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index e7cd34451db..4fb07e33055 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -6,4 +6,6 @@ SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) +SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 51c558a68e9..ec8ef8bb1a9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -338,9 +338,9 @@ caseFuncAlternative aggregateFunction : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET 
#regularAggregateFunctionCall - | functionName=aggregationFunctionName LR_BRACKET DISTINCT functionArg RR_BRACKET - #distinctAggregateFunctionCall | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall + | COUNT LR_BRACKET DISTINCT (functionArg | STAR) RR_BRACKET + #distinctCountFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index d267a8df4fa..62f80eab8e6 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -43,6 +43,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CountStarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; @@ -171,7 +172,7 @@ public UnresolvedExpression visitShowDescribePattern( public UnresolvedExpression visitFilteredAggregationFunctionCall( OpenSearchSQLParser.FilteredAggregationFunctionCallContext ctx) { AggregateFunction agg = (AggregateFunction) visit(ctx.aggregateFunction()); - return new AggregateFunction(agg.getFuncName(), agg.getField(), visit(ctx.filterClause())); + return agg.condition(visit(ctx.filterClause())); } @Override @@ -213,11 +214,10 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( } @Override - public UnresolvedExpression visitDistinctAggregateFunctionCall( - 
OpenSearchSQLParser.DistinctAggregateFunctionCallContext ctx) { + public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { return new AggregateFunction( - ctx.functionName.getText(), - visitFunctionArg(ctx.functionArg()), + ctx.COUNT().getText(), + ctx.functionArg() != null ? visitFunctionArg(ctx.functionArg()) : AllFields.of(), true); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 8e7adaaf039..437e8953fac 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -51,6 +51,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.UnresolvedPlan; @@ -171,14 +172,19 @@ void can_build_implicit_group_by_for_aggregator_in_having_clause() { @Test void can_build_distinct_aggregator() { assertThat( - buildAggregation("SELECT COUNT(DISTINCT name), AVG(DISTINCT balance) FROM test"), + buildAggregation("SELECT COUNT(DISTINCT name) FROM test group by age"), allOf( - hasGroupByItems(), + hasGroupByItems(alias("age", qualifiedName("age"))), hasAggregators( alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( - "name"))), - alias("AVG(DISTINCT balance)", distinctAggregate("AVG", qualifiedName( - "balance")))))); + "name")))))); + + assertThat( + buildAggregation("SELECT COUNT(DISTINCT *) FROM test"), + allOf( + hasGroupByItems(), + hasAggregators( + alias("COUNT(DISTINCT *)", distinctAggregate("COUNT", AllFields.of()))))); } @Test diff --git 
a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index a3c8494e7a7..8ddbe7feab6 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -57,6 +57,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -410,6 +411,28 @@ public void filteredAggregation() { ); } + @Test + public void distinctCount() { + assertEquals( + AstDSL.distinctAggregate("count", qualifiedName("name")), + buildExprAst("count(distinct name)") + ); + + assertEquals( + AstDSL.distinctAggregate("count", AllFields.of()), + buildExprAst("count(distinct *)") + ); + } + + @Test + public void filteredDistinctCount() { + assertEquals( + AstDSL.filteredDistinctCount("count", qualifiedName("name"), function( + ">", qualifiedName("age"), intLiteral(30))), + buildExprAst("count(distinct name) filter(where age > 30)") + ); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); From 078eae75e6647e3be5d7e5589d8099b954f599d1 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 22:09:09 -0700 Subject: [PATCH 07/17] Update doc; removed distinct star Signed-off-by: chloe-zh --- .../aggregation/AvgAggregatorTest.java | 7 ------- .../aggregation/MaxAggregatorTest.java | 7 ------- .../aggregation/MinAggregatorTest.java | 7 ------- .../aggregation/SumAggregatorTest.java | 7 ------- docs/user/dql/aggregations.rst | 19
++++++++++++++++--- sql/src/main/antlr/OpenSearchSQLParser.g4 | 3 +-- .../sql/sql/parser/AstExpressionBuilder.java | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 6 +----- 8 files changed, 19 insertions(+), 39 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java index 33ea4c91233..494d3cfab2e 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -61,13 +61,6 @@ public void filtered_avg() { assertEquals(3.0, result.value()); } - @Test - public void distinct_avg() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctAvg(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator avg"); - } - @Test public void avg_with_missing() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java index 20cde543141..5aa9d3a7473 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MaxAggregatorTest.java @@ -116,13 +116,6 @@ public void filtered_max() { assertEquals(3, result.value()); } - @Test - public void distinct_max() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctMax(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator max"); - } - @Test public void test_max_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java index e2927772621..01e72b9cdac 100644 --- 
a/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/MinAggregatorTest.java @@ -116,13 +116,6 @@ public void filtered_min() { assertEquals(2, result.value()); } - @Test - public void distinct_min() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctMin(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator min"); - } - @Test public void test_min_null() { ExprValue result = diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java index fdd24fb5b16..c0872ed4345 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/SumAggregatorTest.java @@ -100,13 +100,6 @@ public void filtered_sum() { assertEquals(9, result.value()); } - @Test - public void distinct_sum() { - assertThrows(ExpressionEvaluationException.class, - () -> dsl.distinctSum(DSL.ref("integer_value", INTEGER)).valueOf(valueEnv()), - "unsupported distinct aggregator sum"); - } - @Test public void sum_with_missing() { ExprValue result = diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 3c8577fcde7..e332da7c144 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,10 +135,10 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. -DISTINCT Aggregation --------------------- +DISTINCT COUNT Aggregation +-------------------------- -To get the aggregation of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the aggregation function. 
Currently the distinct aggregation is only supported in ``COUNT`` aggregation. Example:: +To get the count of distinct values of a field, you can add a keyword ``DISTINCT`` before the field in the count aggregation. Example:: os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts; fetched rows / total rows = 1/1 @@ -247,3 +247,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w | 4 | 1 | +--------------+------------+ +Distinct count aggregate with FILTER +------------------------------------ + +The ``FILTER`` clause can also be used with distinct count to filter rows before counting the distinct values of a specific field. For example:: + + os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts + fetched rows / total rows = 1/1 + +------------------+ + | distinct_count | + |------------------| + | 3 | + +------------------+ + diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index ec8ef8bb1a9..05c1dffe9c9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -339,8 +339,7 @@ aggregateFunction : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall | COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall - | COUNT LR_BRACKET DISTINCT (functionArg | STAR) RR_BRACKET - #distinctCountFunctionCall + | COUNT LR_BRACKET DISTINCT functionArg RR_BRACKET #distinctCountFunctionCall ; filterClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 62f80eab8e6..8dda63b7505 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -217,7 +217,7 @@ public UnresolvedExpression visitRegularAggregateFunctionCall( public
UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { return new AggregateFunction( ctx.COUNT().getText(), - ctx.functionArg() != null ? visitFunctionArg(ctx.functionArg()) : AllFields.of(), + visitFunctionArg(ctx.functionArg()), true); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index 8ddbe7feab6..c7c0c9f6fd2 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -28,6 +28,7 @@ package org.opensearch.sql.sql.parser; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; @@ -417,11 +418,6 @@ public void distinctCount() { AstDSL.distinctAggregate("count", qualifiedName("name")), buildExprAst("count(distinct name)") ); - - assertEquals( - AstDSL.distinctAggregate("count", AllFields.of()), - buildExprAst("count(distinct *)") - ); } @Test From f10b28253b14e2f41474237764cfd752406213df Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 22:22:31 -0700 Subject: [PATCH 08/17] Removed unnecessary methods Signed-off-by: chloe-zh --- .../sql/ast/expression/AggregateFunction.java | 2 +- .../org/opensearch/sql/expression/DSL.java | 16 ---------------- .../dsl/MetricAggregationBuilder.java | 18 +++++++++--------- .../dsl/MetricAggregationBuilderTest.java | 3 +-- 4 files changed, 11 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java index 96bd33f1c92..e909c46ee7a 100644 --- 
a/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/AggregateFunction.java @@ -68,7 +68,7 @@ public AggregateFunction(String funcName, UnresolvedExpression field) { * Constructor. * @param funcName function name. * @param field {@link UnresolvedExpression}. - * @param distinct field is distinct. + * @param distinct whether distinct field is specified or not. */ public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) { this.funcName = funcName; diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 93f86ca1f82..50b10d55dd5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -492,18 +492,10 @@ public Aggregator avg(Expression... expressions) { return aggregate(BuiltinFunctionName.AVG, expressions); } - public Aggregator distinctAvg(Expression... expressions) { - return avg(expressions).distinct(true); - } - public Aggregator sum(Expression... expressions) { return aggregate(BuiltinFunctionName.SUM, expressions); } - public Aggregator distinctSum(Expression... expressions) { - return sum(expressions).distinct(true); - } - public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } @@ -531,18 +523,10 @@ public Aggregator min(Expression... expressions) { return aggregate(BuiltinFunctionName.MIN, expressions); } - public Aggregator distinctMin(Expression... expressions) { - return min(expressions).distinct(true); - } - public Aggregator max(Expression... expressions) { return aggregate(BuiltinFunctionName.MAX, expressions); } - public Aggregator distinctMax(Expression... expressions) { - return max(expressions).distinct(true); - } - private FunctionExpression function(BuiltinFunctionName functionName, Expression... 
expressions) { return (FunctionExpression) repository.compile( functionName.getName(), Arrays.asList(expressions)); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 84127d9a880..84f2b016343 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -32,6 +32,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; @@ -97,15 +98,14 @@ public Pair visitNamedAggregator( String name = node.getName(); if (distinct) { - switch (node.getFunctionName().getFunctionName()) { - case "count": - return make( - AggregationBuilders.cardinality(name), - expression, - new SingleValueParser(name)); - default: - throw new ExpressionEvaluationException(String.format( - "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); + if ("count".equals(node.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT))) { + return make( + AggregationBuilders.cardinality(name), + expression, + new SingleValueParser(name)); + } else { + throw new IllegalStateException(String.format( + "unsupported distinct aggregator %s", node.getFunctionName().getFunctionName())); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index e8f7fb79edb..e62a6c37c78 100644 --- 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -49,7 +49,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; @@ -205,7 +204,7 @@ void should_build_cardinality_aggregation() { @Test void should_throw_exception_for_unsupported_distinct_aggregator() { - assertThrows(ExpressionEvaluationException.class, + assertThrows(IllegalStateException.class, () -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator( Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))), "unsupported distinct aggregator avg"); From df81cfa5b8511f5b34e33df978334dc05cede808 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Thu, 10 Jun 2021 23:43:03 -0700 Subject: [PATCH 09/17] update Signed-off-by: chloe-zh --- .../expression/aggregation/NamedAggregator.java | 4 ++++ .../dsl/MetricAggregationBuilder.java | 14 ++++++++++++-- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../sql/ppl/parser/AstExpressionBuilder.java | 3 +-- .../sql/ppl/parser/AstExpressionBuilderTest.java | 16 ---------------- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java index a1bf2b99613..02e9c1e8296 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java +++ 
b/core/src/main/java/org/opensearch/sql/expression/aggregation/NamedAggregator.java @@ -54,6 +54,8 @@ public class NamedAggregator extends Aggregator { /** * NamedAggregator. + * The aggregator properties {@link #condition} and {@link #distinct} + * are inherited by named aggregator to avoid errors introduced by the property inconsistency. * * @param name name * @param delegated delegated @@ -64,6 +66,8 @@ public NamedAggregator( super(delegated.getFunctionName(), delegated.getArguments(), delegated.returnType); this.name = name; this.delegated = delegated; + this.distinct = delegated.distinct; + this.condition = delegated.condition != null ? delegated.condition : null; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 84f2b016343..9a4b5138ae4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -40,7 +40,6 @@ import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.LiteralExpression; @@ -102,6 +101,8 @@ public Pair visitNamedAggregator( return make( AggregationBuilders.cardinality(name), expression, + condition, + name, new SingleValueParser(name)); } else { throw new IllegalStateException(String.format( @@ -172,8 +173,17 @@ private Pair make( */ private Pair 
make(CardinalityAggregationBuilder builder, Expression expression, + Expression condition, + String name, MetricParser parser) { - return Pair.of(cardinalityAggHelper.build(expression, builder::field, builder::script), parser); + CardinalityAggregationBuilder aggregationBuilder = + cardinalityAggHelper.build(expression, builder::field, builder::script); + if (condition != null) { + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); + } + return Pair.of(aggregationBuilder, parser); } /** diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index e8b54dab4da..6581e7cdbbc 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -135,7 +135,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall | COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression? RT_PRTHS #distinctCountFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS #distinctCountFunctionCall | percentileAggFunction #percentileAggFunctionCall ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index ef314072760..7da4f90cf0c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -206,8 +206,7 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex @Override public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { - return new AggregateFunction("count", - ctx.valueExpression() != null ? 
visit(ctx.valueExpression()) : AllFields.of(), true); + return new AggregateFunction("count", visit(ctx.valueExpression()), true); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 6bbfda7aef0..b1e25420aa2 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -388,22 +388,6 @@ public void testDistinctCount() { emptyList(), emptyList(), defaultStatsArgs())); - - assertEqual("source=t | stats dc() by b", - agg( - relation("t"), - exprList( - alias( - "dc()", - distinctAggregate("count", AllFields.of()) - ) - ), - emptyList(), - exprList( - alias("b", field("b")) - ), - defaultStatsArgs() - )); } @Test From 94a045f742eeec9213f4c194347ff75ab34b75f2 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Fri, 11 Jun 2021 16:47:17 -0700 Subject: [PATCH 10/17] update Signed-off-by: chloe-zh --- .../src/test/resources/correctness/queries/aggregation.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index fa543c9c20b..b3dcc11bace 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,5 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) FROM 
opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file From 11a9758f611b7974ef9166799695bc1fbb295a39 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Mon, 14 Jun 2021 11:21:01 -0700 Subject: [PATCH 11/17] modified comparison test Signed-off-by: chloe-zh --- .../src/test/resources/correctness/queries/aggregation.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index b3dcc11bace..d3bc194e2e0 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,5 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) FROM opensearch_dashboards_sample_data_flights +SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) as distinct_count FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file From d5dc9eb93ea84810253ba1ab4a732359c73fff3e Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Tue, 15 Jun 2021 16:06:27 -0700 Subject: [PATCH 12/17] removed a comparison test and added it to aggregationIT Signed-off-by: chloe-zh --- .../java/org/opensearch/sql/sql/AggregationIT.java | 10 +++++++++- .../test/resources/correctness/queries/aggregation.txt | 1 - 2 files changed, 9 insertions(+), 2 deletions(-) diff --git 
a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 3cbb222afe1..33cddc6f1f9 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -30,7 +30,15 @@ protected void init() throws Exception { } @Test - void filteredAggregateWithSubquery() throws IOException { + void filteredAggregatePushedDown() throws IOException { + JSONObject response = executeQuery( + "SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK); + verifySchema(response, schema("COUNT(*)", null, "integer")); + verifyDataRows(response, rows(3)); + } + + @Test + void filteredAggregateNotPushedDown() throws IOException { JSONObject response = executeQuery( "SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK + ") AS a"); diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index d3bc194e2e0..0c0648a9371 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -11,5 +11,4 @@ SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin), COUNT(DISTINCT Dest) FROM opensearch_dashboards_sample_data_flights -SELECT COUNT(DISTINCT Origin) FILTER(WHERE AvgTicketPrice < 1000) as distinct_count FROM opensearch_dashboards_sample_data_flights SELECT COUNT(DISTINCT Origin) FROM (SELECT * FROM opensearch_dashboards_sample_data_flights) AS flights \ No newline at end of file From 684a7421c72b31ccf51e42396fa526878898b48a Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Tue, 15 Jun 2021 16:42:13 -0700 
Subject: [PATCH 13/17] added ppl IT test cases; added window function test cases Signed-off-by: chloe-zh --- .../java/org/opensearch/sql/ppl/StatsCommandIT.java | 13 +++++++++++++ .../test/resources/correctness/queries/window.txt | 3 +++ .../aggregation/dsl/MetricAggregationBuilder.java | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java index ff3ad2a6c8b..4a9603fe6bd 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException { verifyDataRows(response, rows(1000)); } + @Test + public void testStatsDistinctCount() throws IOException { + JSONObject response = + executeQuery(String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("distinct_count(gender)", null, "integer")); + verifyDataRows(response, rows(2)); + + response = + executeQuery(String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("dc(age)", null, "integer")); + verifyDataRows(response, rows(21)); + } + @Test public void testStatsMin() throws IOException { JSONObject response = executeQuery(String.format( diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index c3f27153229..07f74742323 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -5,6 +5,7 @@ SELECT DistanceMiles, ROW_NUMBER() OVER (ORDER BY DistanceMiles DESC) AS num FRO SELECT DistanceMiles, RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk 
FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, COUNT(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT DistanceMiles, COUNT(DISTINCT DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights @@ -24,6 +25,7 @@ SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY F SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, COUNT(DISTINCT day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce @@ -33,6 +35,7 @@ SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_ SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, COUNT(DISTINCT day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM 
opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 4641fd134ff..7a321b4fce2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -38,8 +38,8 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; -import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; From c750f5979096c730a87c5f46550e75d9854146c9 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 16 Jun 2021 11:09:57 -0700 Subject: [PATCH 14/17] moved distinct window function test cases to WindowsIT Signed-off-by: chloe-zh --- .../opensearch/sql/sql/WindowFunctionIT.java | 48 +++++++++++++++++++ .../resources/correctness/queries/window.txt | 3 -- 2 files changed, 48 insertions(+), 3 deletions(-) 
diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java index b92ca17238c..52373a72e32 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java @@ -29,6 +29,8 @@ import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRowsInOrder; + import org.json.JSONObject; import org.junit.Test; @@ -40,6 +42,7 @@ public class WindowFunctionIT extends SQLIntegTestCase { @Override protected void init() throws Exception { loadIndex(Index.BANK_WITH_NULL_VALUES); + loadIndex(Index.BANK); } @Test @@ -74,4 +77,49 @@ public void testOrderByNullLast() { rows(null, 7)); } + @Test + public void testDistinctCountOverNull() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER() " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRows(response, + rows("Duke Willmington", 2), + rows("Bond", 2), + rows("Bates", 2), + rows("Adams", 2), + rows("Ratliff", 2), + rows("Ayala", 2), + rows("Mcpherson", 2)); + } + + @Test + public void testDistinctCountOver() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRowsInOrder(response, + rows("Adams", 1), + rows("Ayala", 2), + rows("Bates", 2), + rows("Bond", 2), + rows("Duke Willmington", 2), + rows("Mcpherson", 2), + rows("Ratliff", 2)); + } + + @Test + public void testDistinctCountPartition() { + JSONObject response = new JSONObject(executeQuery( + "SELECT lastname, COUNT(DISTINCT gender) OVER(PARTITION BY gender ORDER BY lastname) " + + "FROM " + TestsConstants.TEST_INDEX_BANK, "jdbc")); + verifyDataRowsInOrder(response, + 
rows("Ayala", 1), + rows("Bates", 1), + rows("Mcpherson", 1), + rows("Adams", 1), + rows("Bond", 1), + rows("Duke Willmington", 1), + rows("Ratliff", 1)); + } + } diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index 07f74742323..c3f27153229 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -5,7 +5,6 @@ SELECT DistanceMiles, ROW_NUMBER() OVER (ORDER BY DistanceMiles DESC) AS num FRO SELECT DistanceMiles, RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, COUNT(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights -SELECT DistanceMiles, COUNT(DISTINCT DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights @@ -25,7 +24,6 @@ SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY F SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce -SELECT user, COUNT(DISTINCT day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, 
AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce @@ -35,7 +33,6 @@ SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_ SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce -SELECT user, COUNT(DISTINCT day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce From 9fa771d3389b87fc432b637e88320c2aeb8dd341 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 16 Jun 2021 14:05:30 -0700 Subject: [PATCH 15/17] added ut Signed-off-by: chloe-zh --- .../aggregation/CountAggregator.java | 10 ----- .../dsl/MetricAggregationBuilderTest.java | 37 ++++++++++++++++++- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index 34d064fe46d..b9653796678 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -26,16 +26,6 @@ package org.opensearch.sql.expression.aggregation; 
-import static org.opensearch.sql.data.model.ExprValueUtils.getBooleanValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getByteValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getCollectionValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getDoubleValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getFloatValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getIntegerValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getLongValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getShortValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getStringValue; -import static org.opensearch.sql.data.model.ExprValueUtils.getTupleValue; import static org.opensearch.sql.utils.ExpressionUtils.format; import java.util.HashSet; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 1e157139454..129814d45fb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -53,18 +53,21 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.MaxAggregator; import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; +import 
org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) class MetricAggregationBuilderTest { + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); @Mock private ExpressionSerializer serializer; @@ -271,7 +274,39 @@ void should_build_cardinality_aggregation() { + "}", buildQuery( Collections.singletonList(named("count(distinct name)", new CountAggregator( - Collections.singletonList(ref("name", STRING)), STRING).distinct(true))))); + Collections.singletonList(ref("name", STRING)), INTEGER).distinct(true))))); + } + + @Test + void should_build_filtered_cardinality_aggregation() { + assertEquals( + "{\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"filter\" : {\n" + + " \"range\" : {\n" + + " \"age\" : {\n" + + " \"from\" : 30,\n" + + " \"to\" : null,\n" + + " \"include_lower\" : false,\n" + + " \"include_upper\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + " },\n" + + " \"aggregations\" : {\n" + + " \"count(distinct name) filter(where age > 30)\" : {\n" + + " \"cardinality\" : {\n" + + " \"field\" : \"name\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + buildQuery(Collections.singletonList(named( + "count(distinct name) filter(where age > 30)", + new CountAggregator(Collections.singletonList(ref("name", STRING)), INTEGER) + .condition(dsl.greater(ref("age", INTEGER), literal(30))) + .distinct(true))))); } @Test From 5d42554fc70ed755c5d50d53c332cef71b70d897 Mon Sep 17 00:00:00 2001 From: Chloe Zhang Date: Wed, 16 Jun 2021 14:32:17 -0700 Subject: [PATCH 16/17] update Signed-off-by: chloe-zh --- .../sql/sql/parser/AstAggregationBuilderTest.java | 7 ------- 1 file changed, 7 deletions(-) diff --git 
a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java index 437e8953fac..44c84495c23 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -178,13 +178,6 @@ void can_build_distinct_aggregator() { hasAggregators( alias("COUNT(DISTINCT name)", distinctAggregate("COUNT", qualifiedName( "name")))))); - - assertThat( - buildAggregation("SELECT COUNT(DISTINCT *) FROM test"), - allOf( - hasGroupByItems(), - hasAggregators( - alias("COUNT(DISTINCT *)", distinctAggregate("COUNT", AllFields.of()))))); } @Test From 80a4c611b6931223e39ed2443a44e9338a3a6bba Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Thu, 17 Jun 2021 11:01:42 -0700 Subject: [PATCH 17/17] update Signed-off-by: chloe-zh --- .../sql/expression/aggregation/CountAggregator.java | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java index b9653796678..579622b546b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/CountAggregator.java @@ -74,7 +74,7 @@ protected static class CountState implements AggregationState { public void count(ExprValue value, Boolean distinct) { if (distinct) { - if (!duplicated(value)) { + if (!distinctValues.contains(value)) { distinctValues.add(value); count++; } @@ -83,15 +83,6 @@ public void count(ExprValue value, Boolean distinct) { } } - private boolean duplicated(ExprValue value) { - for (ExprValue exprValue : distinctValues) { - if (value.compareTo(exprValue) == 0) { - return true; - } - } - return false; - } - @Override public ExprValue result() { return 
ExprValueUtils.integerValue(count);