Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions core/src/main/java/org/apache/druid/math/expr/ExprEval.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,18 @@ public abstract class ExprEval<T>
*/
public static ExprEval deserialize(ByteBuffer buffer, int position)
{
// | expression type (byte) | expression bytes |
ExprType type = ExprType.fromByte(buffer.get(position));
int offset = position + 1;
final ExprType type = ExprType.fromByte(buffer.get(position));
return deserialize(buffer, position + 1, type);
}

/**
* Deserialize an expression stored in a bytebuffer, e.g. for an agg.
*
* This should be refactored to be consolidated with some of the standard type handling of aggregators probably
*/
public static ExprEval deserialize(ByteBuffer buffer, int offset, ExprType type)
{
// | expression bytes |
switch (type) {
case LONG:
// | expression type (byte) | is null (byte) | long bytes |
Expand Down
18 changes: 16 additions & 2 deletions core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,23 @@ private void assertExpr(int position, ExprEval expected, int maxSizeBytes)
{
ExprEval.serialize(buffer, position, expected, maxSizeBytes);
if (ExprType.isArray(expected.type())) {
Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray());
Assert.assertArrayEquals(
expected.asArray(),
ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).asArray()
);
Assert.assertArrayEquals(
expected.asArray(),
ExprEval.deserialize(buffer, position).asArray()
);
} else {
Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value());
Assert.assertEquals(
expected.value(),
ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).value()
);
Assert.assertEquals(
expected.value(),
ExprEval.deserialize(buffer, position).value()
);
}
assertEstimatedBytes(expected, maxSizeBytes);
}
Expand Down
3 changes: 3 additions & 0 deletions docs/querying/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ Only the COUNT and ARRAY_AGG aggregations can accept the DISTINCT keyword.
|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A|
|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`|
|`BIT_AND(expr)`|Performs a bitwise AND operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`|
|`BIT_OR(expr)`|Performs a bitwise OR operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`|
|`BIT_XOR(expr)`|Performs a bitwise XOR operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`|

For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,19 @@ public class ExpressionLambdaAggregator implements Aggregator
private final Expr lambda;
private final ExpressionLambdaAggregatorInputBindings bindings;
private final int maxSizeBytes;
private boolean hasValue;

public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings, int maxSizeBytes)
public ExpressionLambdaAggregator(
final Expr lambda,
final ExpressionLambdaAggregatorInputBindings bindings,
final boolean isNullUnlessAggregated,
final int maxSizeBytes
)
{
this.lambda = lambda;
this.bindings = bindings;
this.maxSizeBytes = maxSizeBytes;
this.hasValue = !isNullUnlessAggregated;
}

@Override
Expand All @@ -43,13 +50,14 @@ public void aggregate()
final ExprEval<?> eval = lambda.eval(bindings);
ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
bindings.accumulate(eval);
hasValue = true;
}

@Nullable
@Override
public Object get()
{
return bindings.getAccumulator().value();
return hasValue ? bindings.getAccumulator().value() : null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Comparators;
Expand Down Expand Up @@ -74,6 +75,7 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
private final String foldExpressionString;
private final String initialValueExpressionString;
private final String initialCombineValueExpressionString;
private final boolean isNullUnlessAggregated;

private final String combineExpressionString;
@Nullable
Expand Down Expand Up @@ -105,6 +107,7 @@ public ExpressionLambdaAggregatorFactory(
@JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier,
@JsonProperty("initialValue") final String initialValue,
@JsonProperty("initialCombineValue") @Nullable final String initialCombineValue,
@JsonProperty("isNullUnlessAggregated") @Nullable final Boolean isNullUnlessAggregated,
@JsonProperty("fold") final String foldExpression,
@JsonProperty("combine") @Nullable final String combineExpression,
@JsonProperty("compare") @Nullable final String compareExpression,
Expand All @@ -121,6 +124,7 @@ public ExpressionLambdaAggregatorFactory(

this.initialValueExpressionString = initialValue;
this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue;
this.isNullUnlessAggregated = isNullUnlessAggregated == null ? NullHandling.sqlCompatible() : isNullUnlessAggregated;
this.foldExpressionString = foldExpression;
if (combineExpression != null) {
this.combineExpressionString = combineExpression;
Expand Down Expand Up @@ -195,6 +199,12 @@ public String getInitialCombineValueExpressionString()
return initialCombineValueExpressionString;
}

@JsonProperty("isNullUnlessAggregated")
public boolean getIsNullUnlessAggregated()
{
return isNullUnlessAggregated;
}

@JsonProperty("fold")
public String getFoldExpressionString()
{
Expand Down Expand Up @@ -249,6 +259,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory)
return new ExpressionLambdaAggregator(
thePlan.getExpression(),
thePlan.getBindings(),
isNullUnlessAggregated,
maxSizeBytes.getBytesInInt()
);
}
Expand All @@ -261,6 +272,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
thePlan.getExpression(),
thePlan.getInitialValue(),
thePlan.getBindings(),
isNullUnlessAggregated,
maxSizeBytes.getBytesInInt()
);
}
Expand Down Expand Up @@ -329,6 +341,7 @@ public AggregatorFactory getCombiningFactory()
accumulatorId,
initialValueExpressionString,
initialCombineValueExpressionString,
isNullUnlessAggregated,
foldExpressionString,
combineExpressionString,
compareExpressionString,
Expand All @@ -348,6 +361,7 @@ public List<AggregatorFactory> getRequiredColumns()
accumulatorId,
initialValueExpressionString,
initialCombineValueExpressionString,
isNullUnlessAggregated,
foldExpressionString,
combineExpressionString,
compareExpressionString,
Expand Down Expand Up @@ -407,6 +421,7 @@ public boolean equals(Object o)
&& foldExpressionString.equals(that.foldExpressionString)
&& initialValueExpressionString.equals(that.initialValueExpressionString)
&& initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString)
&& isNullUnlessAggregated == that.isNullUnlessAggregated
&& combineExpressionString.equals(that.combineExpressionString)
&& Objects.equals(compareExpressionString, that.compareExpressionString)
&& Objects.equals(finalizeExpressionString, that.finalizeExpressionString);
Expand All @@ -422,6 +437,7 @@ public int hashCode()
foldExpressionString,
initialValueExpressionString,
initialCombineValueExpressionString,
isNullUnlessAggregated,
combineExpressionString,
compareExpressionString,
finalizeExpressionString,
Expand All @@ -439,6 +455,7 @@ public String toString()
", foldExpressionString='" + foldExpressionString + '\'' +
", initialValueExpressionString='" + initialValueExpressionString + '\'' +
", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' +
", nullUnlessAggregated='" + isNullUnlessAggregated + '\'' +
", combineExpressionString='" + combineExpressionString + '\'' +
", compareExpressionString='" + compareExpressionString + '\'' +
", finalizeExpressionString='" + finalizeExpressionString + '\'' +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,73 +21,93 @@

import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;

import javax.annotation.Nullable;
import java.nio.ByteBuffer;

public class ExpressionLambdaBufferAggregator implements BufferAggregator
{
private static final short NOT_AGGREGATED_BIT = 1 << 7;
private static final short IS_AGGREGATED_MASK = 0x3F;
private final Expr lambda;
private final ExprEval<?> initialValue;
private final ExpressionLambdaAggregatorInputBindings bindings;
private final int maxSizeBytes;
private final boolean isNullUnlessAggregated;

public ExpressionLambdaBufferAggregator(
Expr lambda,
ExprEval<?> initialValue,
ExpressionLambdaAggregatorInputBindings bindings,
boolean isNullUnlessAggregated,
int maxSizeBytes
)
{
this.lambda = lambda;
this.initialValue = initialValue;
this.bindings = bindings;
this.isNullUnlessAggregated = isNullUnlessAggregated;
this.maxSizeBytes = maxSizeBytes;
}

@Override
public void init(ByteBuffer buf, int position)
{
ExprEval.serialize(buf, position, initialValue, maxSizeBytes);
// set a bit to indicate we haven't aggregated on top of expression type (not going to lie this could be nicer)
if (isNullUnlessAggregated) {
buf.put(position, (byte) (buf.get(position) | NOT_AGGREGATED_BIT));
}
}

@Override
public void aggregate(ByteBuffer buf, int position)
{
ExprEval<?> acc = ExprEval.deserialize(buf, position);
ExprEval<?> acc = ExprEval.deserialize(buf, position + 1, getType(buf, position));
bindings.setAccumulator(acc);
ExprEval<?> newAcc = lambda.eval(bindings);
ExprEval.serialize(buf, position, newAcc, maxSizeBytes);
// scrub not aggregated bit
buf.put(position, (byte) (buf.get(position) & IS_AGGREGATED_MASK));
}

@Nullable
@Override
public Object get(ByteBuffer buf, int position)
{
return ExprEval.deserialize(buf, position).value();
if (isNullUnlessAggregated && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) {
return null;
}
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).value();
}

@Override
public float getFloat(ByteBuffer buf, int position)
{
return (float) ExprEval.deserialize(buf, position).asDouble();
return (float) ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
}

@Override
public double getDouble(ByteBuffer buf, int position)
{
return ExprEval.deserialize(buf, position).asDouble();
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble();
}

@Override
public long getLong(ByteBuffer buf, int position)
{
return ExprEval.deserialize(buf, position).asLong();
return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asLong();
}

@Override
public void close()
{
// nothing to close
}

private static ExprType getType(ByteBuffer buf, int position)
{
return ExprType.fromByte((byte) (buf.get(position) & IS_AGGREGATED_MASK));
}
}
Loading