Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
}
return aggregator;
} else {
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,16 @@ public static UnresolvedExpression aggregate(

public static UnresolvedExpression filteredAggregate(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, condition);
return new AggregateFunction(func, field).condition(condition);
}

/**
 * Creates an unresolved aggregate function whose distinct flag is set.
 *
 * @param func aggregate function name
 * @param field expression the aggregate is applied to
 * @return an {@link AggregateFunction} constructed with {@code distinct = true}
 */
public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) {
  final boolean distinct = true;
  return new AggregateFunction(func, field, distinct);
}

/**
 * Creates an unresolved distinct aggregate function that also carries a filter condition.
 *
 * @param func aggregate function name
 * @param field expression the aggregate is applied to
 * @param condition filter condition attached to the aggregate
 * @return an {@link AggregateFunction} constructed with {@code distinct = true} and the given
 *     condition set on it
 */
public static UnresolvedExpression filteredDistinctCount(
    String func, UnresolvedExpression field, UnresolvedExpression condition) {
  final AggregateFunction distinctAggregate = new AggregateFunction(func, field, true);
  return distinctAggregate.condition(condition);
}

public static Function function(String funcName, UnresolvedExpression... funcArgs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@

import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.common.utils.StringUtils;

Expand All @@ -45,7 +48,10 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
@Setter
@Accessors(fluent = true)
private UnresolvedExpression condition;
private Boolean distinct = false;

/**
* Constructor.
Expand All @@ -62,14 +68,13 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
* Constructor.
* @param funcName function name.
* @param field {@link UnresolvedExpression}.
* @param condition condition in aggregation filter.
* @param distinct whether distinct field is specified or not.
*/
public AggregateFunction(String funcName, UnresolvedExpression field,
UnresolvedExpression condition) {
public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) {
this.funcName = funcName;
this.field = field;
this.argList = Collections.emptyList();
this.condition = condition;
this.distinct = distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) {
}
}

/**
 * Reads the byte value of an {@link ExprValue}.
 *
 * @param exprValue value to read from
 * @return the result of {@code exprValue.byteValue()}
 */
public static Byte getByteValue(ExprValue exprValue) {
  final Byte extracted = exprValue.byteValue();
  return extracted;
}

/**
 * Reads the short value of an {@link ExprValue}.
 *
 * @param exprValue value to read from
 * @return the result of {@code exprValue.shortValue()}
 */
public static Short getShortValue(ExprValue exprValue) {
  final Short extracted = exprValue.shortValue();
  return extracted;
}

/**
 * Reads the integer value of an {@link ExprValue}.
 *
 * @param exprValue value to read from
 * @return the result of {@code exprValue.integerValue()}
 */
public static Integer getIntegerValue(ExprValue exprValue) {
  final Integer extracted = exprValue.integerValue();
  return extracted;
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ public Aggregator count(Expression... expressions) {
return aggregate(BuiltinFunctionName.COUNT, expressions);
}

/**
 * Builds a count aggregator with its distinct flag switched on.
 *
 * @param expressions arguments forwarded to {@link #count(Expression...)}
 * @return the aggregator returned by {@code count(expressions)} after {@code distinct(true)}
 */
public Aggregator distinctCount(Expression... expressions) {
  final Aggregator countAggregator = count(expressions);
  return countAggregator.distinct(true);
}

/**
 * Builds the aggregator registered under {@code BuiltinFunctionName.VARSAMP}.
 *
 * @param expressions arguments passed to the aggregate function
 * @return the aggregator produced by {@code aggregate}
 */
public Aggregator varSamp(Expression... expressions) {
  final Aggregator varSampAggregator = aggregate(BuiltinFunctionName.VARSAMP, expressions);
  return varSampAggregator;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
@Accessors(fluent = true)
protected Expression condition;
@Setter
@Getter
@Accessors(fluent = true)
protected Boolean distinct = false;

/**
* Create an {@link AggregationState} which will be used for aggregation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand All @@ -50,7 +52,7 @@ public CountAggregator.CountState create() {

@Override
protected CountState iterate(ExprValue value, CountState state) {
state.count++;
state.count(value, distinct);
return state;
}

Expand All @@ -64,11 +66,23 @@ public String toString() {
*/
protected static class CountState implements AggregationState {
private int count;
private final Set<ExprValue> distinctValues = new HashSet<>();

/** Creates a state whose counter starts at zero. */
CountState() {
  count = 0;
}

/**
 * Records one more value in this state.
 *
 * <p>When {@code distinct} is {@code true}, the counter advances only the first time a value
 * is added to {@code distinctValues}; otherwise every call advances the counter and the set
 * is left untouched.
 *
 * @param value value being counted
 * @param distinct whether repeated values should be ignored; {@code null} is treated as
 *     {@code false} instead of failing on unboxing
 */
public void count(ExprValue value, Boolean distinct) {
  // Set#add returns true only when the element was not already present, so one call both
  // records the value and tells us whether it is new. The short-circuit keeps the set
  // unmodified in the non-distinct case.
  if (!Boolean.TRUE.equals(distinct) || distinctValues.add(value)) {
    count++;
  }
}

@Override
public ExprValue result() {
return ExprValueUtils.integerValue(count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public class NamedAggregator extends Aggregator<AggregationState> {

/**
* NamedAggregator.
* The aggregator properties {@link #condition} is inherited by named aggregator
* to avoid errors introduced by the property inconsistency.
* The aggregator properties {@link #condition} and {@link #distinct}
* are inherited by named aggregator to avoid errors introduced by the property inconsistency.
*
* @param name name
* @param delegated delegated
Expand All @@ -67,6 +67,7 @@ public NamedAggregator(
this.name = name;
this.delegated = delegated;
this.condition = delegated.condition;
this.distinct = delegated.distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,24 @@ public void variance_mapto_varPop() {
);
}

/** Analyzing a distinct "count" aggregate should yield the DSL distinct count aggregator. */
@Test
public void distinct_count() {
  final Expression expected = dsl.distinctCount(DSL.ref("integer_value", INTEGER));
  final UnresolvedExpression unresolved =
      AstDSL.distinctAggregate("count", qualifiedName("integer_value"));

  assertAnalyzeEqual(expected, unresolved);
}

/**
 * Analyzing a distinct "count" aggregate with a filter should yield the DSL distinct count
 * aggregator carrying the analyzed condition.
 */
@Test
public void filtered_distinct_count() {
  final Expression expected =
      dsl.distinctCount(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1)));
  final UnresolvedExpression unresolved =
      AstDSL.filteredDistinctCount(
          "count",
          qualifiedName("integer_value"),
          function(">", qualifiedName("integer_value"), intLiteral(1)));

  assertAnalyzeEqual(expected, unresolved);
}

/**
 * Runs the expression analyzer over an unresolved expression using this test's analysis
 * context.
 *
 * @param unresolvedExpression expression to analyze
 * @return the expression returned by the analyzer
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
  final Expression analyzed =
      expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
  return analyzed;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ public class ExprValueUtilsTest {
Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues));

private static List<Function<ExprValue, Object>> numberValueExtractor = Arrays.asList(
ExprValue::byteValue,
ExprValue::shortValue,
ExprValueUtils::getByteValue,
ExprValueUtils::getShortValue,
ExprValueUtils::getIntegerValue,
ExprValueUtils::getLongValue,
ExprValueUtils::getFloatValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ public class AggregationTest extends ExpressionTestBase {
"timestamp_value",
"2040-01-01 07:00:00")));

/**
 * Fixture of four tuples in which {@code integer_value} 1 occurs twice (values 1, 1, 2, 3),
 * while each {@code double_value} is different (4, 3, 2, 1).
 */
protected static List<ExprValue> tuples_with_duplicates =
Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 4d)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, "double_value", 3d)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, "double_value", 2d)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, "double_value", 1d)));

protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ public void filtered_count() {
assertEquals(3, result.value());
}

/** Distinct count over {@code integer_value} of the duplicate fixture is expected to be 3. */
@Test
public void distinct_count() {
  final ExprValue distinctCount =
      aggregation(
          dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
          tuples_with_duplicates);

  assertEquals(3, distinctCount.value());
}

/**
 * Distinct count over {@code integer_value} with the condition {@code double_value > 1}
 * applied to the duplicate fixture is expected to be 2.
 */
@Test
public void filtered_distinct_count() {
  final ExprValue filteredDistinctCount =
      aggregation(
          dsl.distinctCount(DSL.ref("integer_value", INTEGER))
              .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
          tuples_with_duplicates);

  assertEquals(2, filteredDistinctCount.value());
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down
26 changes: 26 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ Example::
| 2.8613807855648994 |
+--------------------+

DISTINCT COUNT Aggregation
--------------------------

To get the count of distinct values of a field, add the ``DISTINCT`` keyword before the field in the count aggregation. Example::

os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts;
fetched rows / total rows = 1/1
+--------------------------+-----------------+
| COUNT(DISTINCT gender) | COUNT(gender) |
|--------------------------+-----------------|
| 2 | 4 |
+--------------------------+-----------------+

HAVING Clause
=============

Expand Down Expand Up @@ -456,3 +469,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w
| 4 | 1 |
+--------------+------------+

Distinct count aggregate with FILTER
------------------------------------

The ``FILTER`` clause can also be used with a distinct count, in which case the filter is applied before the distinct values of the specified field are counted. For example::

os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts
fetched rows / total rows = 1/1
+------------------+
| distinct_count |
|------------------|
| 3 |
+------------------+

15 changes: 15 additions & 0 deletions docs/user/ppl/cmd/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,18 @@ PPL query::
| 36 | 32 | M |
+------------+------------+----------+

Example 7: Calculate the distinct count of a field
==================================================

To get the count of distinct values of a field, you can use the ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of the gender field across all the accounts.

PPL query::

os> source=accounts | stats count(gender), distinct_count(gender);
fetched rows / total rows = 1/1
+-----------------+--------------------------+
| count(gender) | distinct_count(gender) |
|-----------------+--------------------------|
| 4 | 2 |
+-----------------+--------------------------+

Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException {
verifyDataRows(response, rows(1000));
}

/**
 * Checks both spellings of the distinct count function, {@code distinct_count} and
 * {@code dc}, against the accounts test index.
 */
@Test
public void testStatsDistinctCount() throws IOException {
  // Long form: distinct_count(...)
  final String genderQuery =
      String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT);
  final JSONObject genderResponse = executeQuery(genderQuery);
  verifySchema(genderResponse, schema("distinct_count(gender)", null, "integer"));
  verifyDataRows(genderResponse, rows(2));

  // Short form: dc(...)
  final String ageQuery = String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT);
  final JSONObject ageResponse = executeQuery(ageQuery);
  verifySchema(ageResponse, schema("dc(age)", null, "integer"));
  verifyDataRows(ageResponse, rows(21));
}

@Test
public void testStatsMin() throws IOException {
JSONObject response = executeQuery(String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ protected void init() throws Exception {
}

@Test
void filteredAggregateWithSubquery() throws IOException {
void filteredAggregatePushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK);
verifySchema(response, schema("COUNT(*)", null, "integer"));
verifyDataRows(response, rows(3));
}

@Test
void filteredAggregateNotPushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK
+ ") AS a");
Expand Down
Loading