From 9e494a192de4f4ea3e52e45367b3ebe0ebbe5704 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Sat, 17 Oct 2020 17:21:06 +0530 Subject: [PATCH 1/8] First draft of grouping_id function --- .../query/aggregation/AggregatorUtil.java | 4 + .../GroupingAggregatorFactory.java | 258 ++++++++++++++++++ .../constant/LongConstantAggregator.java | 62 +++++ .../LongConstantBufferAggregator.java | 71 +++++ .../LongConstantVectorAggregator.java | 66 +++++ .../epinephelinae/GroupByRowProcessor.java | 7 +- .../epinephelinae/RowBasedGrouperHelper.java | 29 +- .../groupby/strategy/GroupByStrategyV2.java | 11 +- .../GroupingAggregatorFactoryTest.java | 174 ++++++++++++ .../builtin/GroupingSqlAggregator.java | 126 +++++++++ .../calcite/planner/DruidOperatorTable.java | 2 + 11 files changed, 799 insertions(+), 11 deletions(-) create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java create mode 100644 processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java create mode 100644 processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java index 2cc5f1b06662..3c7b8d474395 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorUtil.java @@ -134,6 +134,10 @@ public class 
AggregatorUtil public static final byte FLOAT_ANY_CACHE_TYPE_ID = 0x44; public static final byte STRING_ANY_CACHE_TYPE_ID = 0x45; + // GROUPING aggregator + public static final byte GROUPING_CACHE_TYPE_ID = 0x46; + + /** * returns the list of dependent postAggregators that should be calculated in order to calculate given postAgg * diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java new file mode 100644 index 000000000000..478eafd1da52 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.aggregation; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import org.apache.druid.annotations.EverythingIsNonnullByDefault; +import org.apache.druid.query.aggregation.constant.LongConstantAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantBufferAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantVectorAggregator; +import org.apache.druid.query.cache.CacheKeyBuilder; +import org.apache.druid.segment.ColumnInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.vector.VectorColumnSelectorFactory; +import org.apache.druid.utils.CollectionUtils; + +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +@EverythingIsNonnullByDefault +public class GroupingAggregatorFactory extends AggregatorFactory +{ + private static final Comparator VALUE_COMPARATOR = Long::compare; + private final String name; + private final List groupings; + private final long value; + @Nullable + private final Set keyDimensions; + + @JsonCreator + public GroupingAggregatorFactory( + @JsonProperty("name") String name, + @JsonProperty("groupings") List groupings + ) + { + this(name, groupings, null); + } + + @VisibleForTesting + GroupingAggregatorFactory( + String name, + List groupings, + @Nullable Set keyDimensions + ) + { + Preconditions.checkNotNull(name, "Must have a valid, non-null aggregator name"); + this.name = name; + this.groupings = groupings; + this.keyDimensions = keyDimensions; + value = groupingId(groupings, keyDimensions); + } + + @Override + public Aggregator factorize(ColumnSelectorFactory metricFactory) + { + return new 
LongConstantAggregator(value); + } + + @Override + public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) + { + return new LongConstantBufferAggregator(value); + } + + @Override + public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory) + { + return new LongConstantVectorAggregator(value); + } + + @Override + public boolean canVectorize(ColumnInspector columnInspector) + { + return true; + } + + /** + * Replace the param {@code keyDimensions} with the new set of key dimensions + */ + public GroupingAggregatorFactory withKeyDimensions(Set newKeyDimensions) + { + return new GroupingAggregatorFactory(name, groupings, newKeyDimensions); + } + + @Override + public Comparator getComparator() + { + return VALUE_COMPARATOR; + } + + @JsonProperty + public List getGroupings() + { + return groupings; + } + + @Override + @JsonProperty + public String getName() + { + return name; + } + + public long getValue() + { + return value; + } + + @Nullable + @Override + public Object combine(@Nullable Object lhs, @Nullable Object rhs) + { + return lhs; + } + + @Override + public AggregatorFactory getCombiningFactory() + { + return new GroupingAggregatorFactory(name, groupings, keyDimensions); + } + + @Override + public List getRequiredColumns() + { + return Collections.singletonList(new GroupingAggregatorFactory(name, groupings, keyDimensions)); + } + + @Override + public Object deserialize(Object object) + { + return object; + } + + @Nullable + @Override + public Object finalizeComputation(@Nullable Object object) + { + return object; + } + + @Override + public List requiredFields() + { + // The aggregator doesn't need to read any fields. 
+ return Collections.emptyList(); + } + + @Override + public ValueType getType() + { + return ValueType.LONG; + } + + @Override + public ValueType getFinalizedType() + { + return ValueType.LONG; + } + + @Override + public int getMaxIntermediateSize() + { + return Long.BYTES; // sized to match the LONG type reported by getType() + } + + @Override + public byte[] getCacheKey() + { + return new CacheKeyBuilder(AggregatorUtil.GROUPING_CACHE_TYPE_ID) + .appendStrings(groupings) + .build(); + } + + private long groupingId(List groupings, @Nullable Set keyDimensions) + { + Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions"); + // Integer.size is just a sanity check. In practice, it will be just few dimensions. + Preconditions.checkArgument( + groupings.size() < Integer.SIZE, + "Number of dimensions %s is more than supported %s", + groupings.size(), + Integer.SIZE - 1 + ); + long temp = 0L; + for (String groupingDimension : groupings) { + if (isDimensionIncluded(groupingDimension, keyDimensions)) { + temp = temp | 1L; + } + temp = temp << 1; + } + return temp >> 1; + } + + private boolean isDimensionIncluded(String dimToCheck, @Nullable Set keyDimensions) + { + if (null == keyDimensions) { + // All dimensions are included + return true; + } else { + return keyDimensions.contains(dimToCheck); + } + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GroupingAggregatorFactory factory = (GroupingAggregatorFactory) o; + return name.equals(factory.name) && + groupings.equals(factory.groupings) && + Objects.equals(keyDimensions, factory.keyDimensions); + } + + @Override + public int hashCode() + { + return Objects.hash(name, groupings, keyDimensions); + } + + @Override + public String toString() + { + return "GroupingAggregatorFactory{" + + "name='" + name + '\'' + + ", groupings=" + groupings + + ", keyDimensions=" + keyDimensions + + '}'; + } +} diff
--git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java new file mode 100644 index 000000000000..60b6eb3e63b8 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.druid.query.aggregation.Aggregator; + +public class LongConstantAggregator implements Aggregator +{ + private final long value; + + public LongConstantAggregator(long value) + { + this.value = value; + } + + @Override + public void aggregate() + { + // No-op + } + + @Override + public Object get() + { + return value; + } + + @Override + public float getFloat() + { + return (float) value; + } + + @Override + public long getLong() + { + return value; + } + + @Override + public void close() + { + + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java new file mode 100644 index 000000000000..bc06a2725ca7 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.druid.query.aggregation.BufferAggregator; + +import java.nio.ByteBuffer; + +public class LongConstantBufferAggregator implements BufferAggregator +{ + private final long value; + + public LongConstantBufferAggregator(long value) + { + this.value = value; + } + + @Override + public void init(ByteBuffer buf, int position) + { + // Since we always return a constant value despite what is in the buffer, there is no need to + // update the buffer at all + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + + } + + @Override + public Object get(ByteBuffer buf, int position) + { + return value; + } + + @Override + public float getFloat(ByteBuffer buf, int position) + { + return (float) value; + } + + @Override + public long getLong(ByteBuffer buf, int position) + { + return value; + } + + @Override + public void close() + { + + } +} diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java new file mode 100644 index 000000000000..f127decad009 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.druid.query.aggregation.VectorAggregator; + +import javax.annotation.Nullable; +import java.nio.ByteBuffer; + +public class LongConstantVectorAggregator implements VectorAggregator +{ + private final long value; + + public LongConstantVectorAggregator(long value) + { + this.value = value; + } + + @Override + public void init(ByteBuffer buf, int position) + { + // Since we always return a constant value despite what is in the buffer, there is no need to + // update the buffer at all + } + + @Override + public void aggregate(ByteBuffer buf, int position, int startRow, int endRow) + { + + } + + @Override + public void aggregate(ByteBuffer buf, int numRows, int[] positions, @Nullable int[] rows, int positionOffset) + { + + } + + @Override + public Object get(ByteBuffer buf, int position) + { + return value; + } + + @Override + public void close() + { + + } +} diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java index f86b2f0c3425..7e753af1331e 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByRowProcessor.java @@ -29,6 +29,7 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; import 
org.apache.druid.query.ResourceLimitExceededException; +import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.ResultRow; @@ -66,7 +67,7 @@ public interface ResultSupplier extends Closeable * @param dimensionsToInclude list of dimensions to include, or null to include all dimensions. Used by processing * of subtotals. If specified, the results will not necessarily be fully grouped. */ - Sequence results(@Nullable List dimensionsToInclude); + Sequence results(@Nullable List dimensionsToInclude); } private GroupByRowProcessor() @@ -140,7 +141,7 @@ public ByteBuffer get() return new ResultSupplier() { @Override - public Sequence results(@Nullable List dimensionsToInclude) + public Sequence results(@Nullable List dimensionsToInclude) { return getRowsFromGrouper(query, grouper, dimensionsToInclude); } @@ -156,7 +157,7 @@ public void close() throws IOException private static Sequence getRowsFromGrouper( final GroupByQuery query, final Grouper grouper, - @Nullable List dimensionsToInclude + @Nullable List dimensionsToInclude ) { return new BaseSequence<>( diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index c099eedd7617..98579009ec8c 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -44,6 +44,7 @@ import org.apache.druid.query.BaseQuery; import org.apache.druid.query.ColumnSelectorPlus; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; import org.apache.druid.query.dimension.ColumnSelectorStrategy; import 
org.apache.druid.query.dimension.ColumnSelectorStrategyFactory; import org.apache.druid.query.dimension.DimensionSpec; @@ -87,6 +88,7 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.function.ToLongFunction; +import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -534,21 +536,35 @@ public static CloseableGrouperIterator makeGrouperIterat public static CloseableGrouperIterator makeGrouperIterator( final Grouper grouper, final GroupByQuery query, - @Nullable final List dimsToInclude, + @Nullable final List dimsToInclude, final Closeable closeable ) { final boolean includeTimestamp = query.getResultRowHasTimestamp(); final BitSet dimsToIncludeBitSet = new BitSet(query.getDimensions().size()); final int resultRowDimensionStart = query.getResultRowDimensionStart(); + final BitSet groupingAggregatorsBitSet = new BitSet(query.getAggregatorSpecs().size()); + final Object[] groupingAggregatorValues = new Long[query.getAggregatorSpecs().size()]; if (dimsToInclude != null) { - for (String dimension : dimsToInclude) { - final int dimIndex = query.getResultRowSignature().indexOf(dimension); + for (DimensionSpec dimensionSpec : dimsToInclude) { + String outputName = dimensionSpec.getOutputName(); + final int dimIndex = query.getResultRowSignature().indexOf(outputName); if (dimIndex >= 0) { dimsToIncludeBitSet.set(dimIndex - resultRowDimensionStart); } } + + Set keyDimensionNames = dimsToInclude.stream().map(DimensionSpec::getDimension).collect(Collectors.toSet()); + for (int i = 0; i < query.getAggregatorSpecs().size(); i++) { + AggregatorFactory aggregatorFactory = query.getAggregatorSpecs().get(i); + if (aggregatorFactory instanceof GroupingAggregatorFactory) { + groupingAggregatorsBitSet.set(i); + groupingAggregatorValues[i] = ((GroupingAggregatorFactory) aggregatorFactory) + .withKeyDimensions(keyDimensionNames) + .getValue(); + } + } } return new CloseableGrouperIterator<>( @@ -576,7 +592,12 @@ public static 
CloseableGrouperIterator makeGrouperIterat // Add aggregations. final int resultRowAggregatorStart = query.getResultRowAggregatorStart(); for (int i = 0; i < entry.getValues().length; i++) { - resultRow.set(resultRowAggregatorStart + i, entry.getValues()[i]); + if (dimsToInclude != null && groupingAggregatorsBitSet.get(i)) { + resultRow.set(resultRowAggregatorStart + i, groupingAggregatorValues[i]); + } else { + resultRow.set(resultRowAggregatorStart + i, entry.getValues()[i]); + + } } return resultRow; diff --git a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java index e81eded2f9c0..5e7c8b2d6bf6 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java @@ -405,6 +405,8 @@ public Sequence processSubtotalsSpec( // Iterate through each subtotalSpec, build results for it and add to subtotalsResults for (List subtotalSpec : subtotals) { final ImmutableSet dimsInSubtotalSpec = ImmutableSet.copyOf(subtotalSpec); + // Dimension spec including dimension name and output name + final List subTotalDimensionSpec = new ArrayList<>(dimsInSubtotalSpec.size()); final List dimensions = query.getDimensions(); final List newDimensions = new ArrayList<>(); @@ -418,6 +420,7 @@ public Sequence processSubtotalsSpec( dimensionSpec.getOutputType() ) ); + subTotalDimensionSpec.add(dimensionSpec); } else { // Insert dummy dimension so all subtotals queries have ResultRows with the same shape. // Use a field name that does not appear in the main query result, to assure the result will be null. @@ -447,7 +450,7 @@ public Sequence processSubtotalsSpec( // Since subtotalSpec is a prefix of base query dimensions, so results from base query are also sorted // by subtotalSpec as needed by stream merging. 
subtotalsResults.add( - processSubtotalsResultAndOptionallyClose(() -> resultSupplierOneFinal, subtotalSpec, subtotalQuery, false) + processSubtotalsResultAndOptionallyClose(() -> resultSupplierOneFinal, subTotalDimensionSpec, subtotalQuery, false) ); } else { // Since subtotalSpec is not a prefix of base query dimensions, so results from base query are not sorted @@ -459,7 +462,7 @@ public Sequence processSubtotalsSpec( Supplier resultSupplierTwo = () -> GroupByRowProcessor.process( baseSubtotalQuery, subtotalQuery, - resultSupplierOneFinal.results(subtotalSpec), + resultSupplierOneFinal.results(subTotalDimensionSpec), configSupplier.get(), resource, spillMapper, @@ -468,7 +471,7 @@ public Sequence processSubtotalsSpec( ); subtotalsResults.add( - processSubtotalsResultAndOptionallyClose(resultSupplierTwo, subtotalSpec, subtotalQuery, true) + processSubtotalsResultAndOptionallyClose(resultSupplierTwo, subTotalDimensionSpec, subtotalQuery, true) ); } } @@ -486,7 +489,7 @@ public Sequence processSubtotalsSpec( private Sequence processSubtotalsResultAndOptionallyClose( Supplier baseResultsSupplier, - List dimsToInclude, + List dimsToInclude, GroupByQuery subtotalQuery, boolean closeOnSequenceRead ) diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java new file mode 100644 index 000000000000..5f9285e73371 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation; + +import com.google.common.collect.Sets; +import junitparams.converters.Nullable; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.query.aggregation.constant.LongConstantAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantBufferAggregator; +import org.apache.druid.query.aggregation.constant.LongConstantVectorAggregator; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +@RunWith(Enclosed.class) +public class GroupingAggregatorFactoryTest +{ + public static GroupingAggregatorFactory makeFactory(String[] groupings, @Nullable String[] keyDims) + { + GroupingAggregatorFactory factory = new GroupingAggregatorFactory("name", Arrays.asList(groupings)); + if (null != keyDims) { + factory = (GroupingAggregatorFactory) factory.withKeyDimensions(Sets.newHashSet(keyDims)); + } + return factory; + } + + public static class NewAggregatorTests + { + private ColumnSelectorFactory metricFactory; + + @Before + public void setup() + { + metricFactory = EasyMock.mock(ColumnSelectorFactory.class); + } + + @Test + public void testNewAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new 
String[]{"a", "b"}, new String[]{"a"}); + Aggregator aggregator = factory.factorize(metricFactory); + Assert.assertEquals(LongConstantAggregator.class, aggregator.getClass()); + Assert.assertEquals(2, aggregator.getLong()); + } + + @Test + public void testNewBufferAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + BufferAggregator aggregator = factory.factorizeBuffered(metricFactory); + Assert.assertEquals(LongConstantBufferAggregator.class, aggregator.getClass()); + Assert.assertEquals(2, aggregator.getLong(null, 0)); + } + + @Test + public void testNewVectorAggregator() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + Assert.assertTrue(factory.canVectorize(metricFactory)); + VectorAggregator aggregator = factory.factorizeVector(null); + Assert.assertEquals(LongConstantVectorAggregator.class, aggregator.getClass()); + Assert.assertEquals(2L, aggregator.get(null, 0)); + } + + @Test + public void testWithKeyDimensions() + { + GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); + Aggregator aggregator = factory.factorize(metricFactory); + Assert.assertEquals(2, aggregator.getLong()); + factory = (GroupingAggregatorFactory) factory.withKeyDimensions(Sets.newHashSet("b")); + aggregator = factory.factorize(metricFactory); + Assert.assertEquals(1, aggregator.getLong()); + } + } + + public static class GroupingDimensionsTest + { + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void testFactory_nullGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Must have a non-empty grouping dimensions"); + GroupingAggregatorFactory factory = new GroupingAggregatorFactory("name", null, Sets.newHashSet("b")); + } + + @Test + public void testFactory_emptyGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + 
exception.expectMessage("Must have a non-empty grouping dimensions"); + makeFactory(new String[0], null); + } + + @Test + public void testFactory_highNumberOfGroupingDimensions() + { + exception.expect(IllegalArgumentException.class); + exception.expectMessage(StringUtils.format( + "Number of dimensions %d is more than supported %d", + Integer.SIZE, + Integer.SIZE - 1 + )); + makeFactory(new String[Integer.SIZE], null); + } + } + + @RunWith(Parameterized.class) + public static class ValueTests + { + private final GroupingAggregatorFactory factory; + private final long value; + + public ValueTests(String[] groupings, @Nullable String[] keyDimensions, long value) + { + factory = makeFactory(groupings, keyDimensions); + this.value = value; + } + + @Parameterized.Parameters + public static Collection arguments() + { + String[] maxGroupingList = new String[Integer.SIZE - 1]; + for (int i = 0; i < maxGroupingList.length; i++) { + maxGroupingList[i] = String.valueOf(i); + } + return Arrays.asList(new Object[][]{ + {new String[]{"a", "b"}, new String[0], 0}, + {new String[]{"a", "b"}, null, 3}, + {new String[]{"a", "b"}, new String[]{"a"}, 2}, + {new String[]{"a", "b"}, new String[]{"b"}, 1}, + {new String[]{"a", "b"}, new String[]{"a", "b"}, 3}, + {new String[]{"b", "a"}, new String[]{"a"}, 1}, + {maxGroupingList, null, Integer.MAX_VALUE} + }); + } + + @Test + public void testValue() + { + Assert.assertEquals(value, factory.factorize(null).getLong()); + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java new file mode 100644 index 000000000000..bbff23f2375e --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/GroupingSqlAggregator.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.aggregation.builtin; + +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; +import org.apache.druid.segment.VirtualColumn; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class GroupingSqlAggregator implements SqlAggregator +{ + @Override + public SqlAggFunction calciteFunction() + { + return SqlStdOperatorTable.GROUPING; + } + + @Nullable + @Override + public 
Aggregation toDruidAggregation( + PlannerContext plannerContext, + RowSignature rowSignature, + VirtualColumnRegistry virtualColumnRegistry, + RexBuilder rexBuilder, + String name, + AggregateCall aggregateCall, + Project project, + List existingAggregations, + boolean finalizeAggregations + ) + { + List arguments = aggregateCall.getArgList() + .stream() + .map(i -> getColumnName( + plannerContext, + rowSignature, + project, + virtualColumnRegistry, + i + )) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + if (arguments.size() < aggregateCall.getArgList().size()) { + return null; + } + + for (Aggregation existing : existingAggregations) { + for (AggregatorFactory factory : existing.getAggregatorFactories()) { + if (!(factory instanceof GroupingAggregatorFactory)) { + continue; + } + GroupingAggregatorFactory groupingFactory = (GroupingAggregatorFactory) factory; + if (groupingFactory.getGroupings().equals(arguments) + && groupingFactory.getName().equals(name)) { + return Aggregation.create(groupingFactory); + } + } + } + AggregatorFactory factory = new GroupingAggregatorFactory(name, arguments); + return Aggregation.create(factory); + } + + @Nullable + private String getColumnName( + PlannerContext plannerContext, + RowSignature rowSignature, + Project project, + VirtualColumnRegistry virtualColumnRegistry, + int fieldNumber + ) + { + RexNode node = Expressions.fromFieldAccess(rowSignature, project, fieldNumber); + if (null == node) { + return null; + } + DruidExpression expression = Expressions.toDruidExpression(plannerContext, rowSignature, node); + if (null == expression) { + return null; + } + if (expression.isDirectColumnAccess()) { + return expression.getDirectColumn(); + } + + VirtualColumn virtualColumn = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( + plannerContext, + expression, + node.getType() + ); + return virtualColumn.getOutputName(); + } +} diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index e1fb0b48fe54..9fbc8dffb53a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -37,6 +37,7 @@ import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; +import org.apache.druid.sql.calcite.aggregation.builtin.GroupingSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MaxSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.MinSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.SumSqlAggregator; @@ -130,6 +131,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new MaxSqlAggregator()) .add(new SumSqlAggregator()) .add(new SumZeroSqlAggregator()) + .add(new GroupingSqlAggregator()) .build(); From 60f1041472a1b14cfd98359fcddca27e07bd9da0 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Mon, 19 Oct 2020 17:11:37 +0530 Subject: [PATCH 2/8] Add more tests and documentation --- docs/querying/aggregations.md | 22 ++++++ docs/querying/groupbyquery.md | 4 +- docs/querying/sql.md | 4 +- .../druid/jackson/AggregatorsModule.java | 4 +- .../GroupingAggregatorFactory.java | 42 ++++++++--- .../epinephelinae/RowBasedGrouperHelper.java | 2 + .../GroupingAggregatorFactoryTest.java | 10 +-- .../constant/LongConstantAggregatorTest.java | 63 +++++++++++++++++ .../LongConstantBufferAggregatorTest.java | 70 +++++++++++++++++++ .../LongConstantVectorAggregatorTest.java | 65 +++++++++++++++++ 10 files changed, 269 insertions(+), 17 deletions(-) create mode 100644 
processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java create mode 100644 processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java create mode 100644 processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java diff --git a/docs/querying/aggregations.md b/docs/querying/aggregations.md index 7da5b5c56bd5..bd75913c7594 100644 --- a/docs/querying/aggregations.md +++ b/docs/querying/aggregations.md @@ -426,3 +426,25 @@ This makes it possible to compute the results of a filtered and an unfiltered ag "aggregator" : } ``` + +### Grouping Aggregator + +A grouping aggregator can only be used as part of GroupBy queries which have a subtotal spec. It returns a number for +each output row that lets you infer whether a particular dimension is included in the sub-grouping used for that row. You can pass +a *non-empty* list of dimensions to this aggregator which *must* be a subset of dimensions that you are grouping on. +E.g if the aggregator has `["dim1", "dim2"]` as input dimensions and `[["dim1", "dim2"], ["dim1"], ["dim2"], []]` as subtotals, +following can be the possible output of the aggregator + +| subtotal used in query | Output | (bits representation) | +|------------------------|--------|-----------------------| +| `["dim1", "dim2"]` | 3 | (11) | +| `["dim1"]` | 2 | (10) | +| `["dim2"]` | 1 | (01) | +| `[]` | 0 | (00) | + +As illustrated in above example, output number can be though of as an unsigned n bit number where n is the number of dimensions passed to the aggregator. +The bit at position X is set in this number to 1 if a dimension at position X in input to aggregators is included in the sub-grouping. 
+ +```json +{ "type" : "grouping", "name" : , "groupings" : [] } +``` \ No newline at end of file diff --git a/docs/querying/groupbyquery.md b/docs/querying/groupbyquery.md index 652953490321..73ac973b6d54 100644 --- a/docs/querying/groupbyquery.md +++ b/docs/querying/groupbyquery.md @@ -226,7 +226,9 @@ The response for the query above would look something like: ] ``` -> Notice that dimensions that are not included in an individual subtotalsSpec grouping are returned with a `null` value. This response format represents a behavior change as of Apache Druid 0.18.0. In release 0.17.0 and earlier, such dimensions were entirely excluded from the result. +> Notice that dimensions that are not included in an individual subtotalsSpec grouping are returned with a `null` value. This response format represents a behavior change as of Apache Druid 0.18.0. +> In release 0.17.0 and earlier, such dimensions were entirely excluded from the result. If you were relying on this old behaviour to determine whether a particular dimension was not part of +> a subtotal grouping, you can now use [Grouping aggregator](aggregations.md#Grouping Aggregator) instead. ## Implementation details diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 702070d9ace3..ffbf382ff837 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -99,7 +99,8 @@ total. Finally, GROUP BY CUBE computes a grouping set for each combination of gr `GROUP BY CUBE (country, city)` is equivalent to `GROUP BY GROUPING SETS ( (country, city), (country), (city), () )`. Grouping columns that do not apply to a particular row will contain `NULL`. For example, when computing `GROUP BY GROUPING SETS ( (country, city), () )`, the grand total row corresponding to `()` will have `NULL` for the -"country" and "city" columns. +"country" and "city" columns. Column may also be `NULL` if it was `NULL` in the data itself. To differentiate such rows +, you can use `GROUPING` aggregation. 
When using GROUP BY GROUPING SETS, GROUP BY ROLLUP, or GROUP BY CUBE, be aware that results may not be generated in the order that you specify your grouping sets in the query. If you need results to be generated in a particular order, use @@ -337,6 +338,7 @@ Only the COUNT aggregation can accept DISTINCT. |`LATEST(expr, maxBytesPerString)`|Like `LATEST(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.| |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)| |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.| +|`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#Grouping Aggregator) on how to infer this number.| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx). 
diff --git a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java index a974edd764cb..795ea5b6d31c 100644 --- a/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java +++ b/processing/src/main/java/org/apache/druid/jackson/AggregatorsModule.java @@ -31,6 +31,7 @@ import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; import org.apache.druid.query.aggregation.FloatSumAggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; import org.apache.druid.query.aggregation.HistogramAggregatorFactory; import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; @@ -118,7 +119,8 @@ public AggregatorsModule() @JsonSubTypes.Type(name = "longAny", value = LongAnyAggregatorFactory.class), @JsonSubTypes.Type(name = "floatAny", value = FloatAnyAggregatorFactory.class), @JsonSubTypes.Type(name = "doubleAny", value = DoubleAnyAggregatorFactory.class), - @JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class) + @JsonSubTypes.Type(name = "stringAny", value = StringAnyAggregatorFactory.class), + @JsonSubTypes.Type(name = "grouping", value = GroupingAggregatorFactory.class) }) public interface AggregatorFactoryMixin { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index 478eafd1da52..31dfb8f20184 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -134,6 +134,9 @@ public long getValue() @Override public Object combine(@Nullable Object lhs, @Nullable Object rhs) { + if 
(null == lhs) { + return rhs; + } return lhs; } @@ -184,35 +187,56 @@ public ValueType getFinalizedType() @Override public int getMaxIntermediateSize() { - return Integer.BYTES; + return Long.BYTES; } @Override public byte[] getCacheKey() { - return new CacheKeyBuilder(AggregatorUtil.GROUPING_CACHE_TYPE_ID) - .appendStrings(groupings) - .build(); + CacheKeyBuilder keyBuilder = new CacheKeyBuilder(AggregatorUtil.GROUPING_CACHE_TYPE_ID) + .appendStrings(groupings); + if (null != keyDimensions) { + keyBuilder.appendStrings(keyDimensions); + } + return keyBuilder.build(); } + /** + * Gives the list of grouping dimensions, return a long value where each bit at position X in the returned value + * corresponds to the dimension in groupings at same position X. X is the position relative to the right end. if + * keyDimensions contain the grouping dimension at position X, the bit is set to 1 at position X, otherwise it is + * set to 0. An example adapted from Microsoft SQL documentation + * + * groupings keyDimensions value (3 least significant bits) value (long) + * a,b,c [a] 100 4 + * a,b,c [b] 010 2 + * a,b,c [c] 001 1 + * a,b,c [a,b] 110 6 + * a,b,c [a,c] 101 5 + * a,b,c [b,c] 011 3 + * a,b,c [a,b,c] 111 7 + * a,b,c [] 000 0 // None included + * a,b,c 111 7 // All included + */ private long groupingId(List groupings, @Nullable Set keyDimensions) { Preconditions.checkArgument(!CollectionUtils.isNullOrEmpty(groupings), "Must have a non-empty grouping dimensions"); - // Integer.size is just a sanity check. In practice, it will be just few dimensions. + // (Long.SIZE - 1) is just a sanity check. In practice, it will be just few dimensions. This limit + // also makes sure that values are always positive. 
Preconditions.checkArgument( - groupings.size() < Integer.SIZE, + groupings.size() < Long.SIZE, "Number of dimensions %s is more than supported %s", groupings.size(), - Integer.SIZE - 1 + Long.SIZE - 1 ); long temp = 0L; for (String groupingDimension : groupings) { + temp = temp << 1; if (isDimensionIncluded(groupingDimension, keyDimensions)) { temp = temp | 1L; } - temp = temp << 1; } - return temp >> 1; + return temp; } private boolean isDimensionIncluded(String dimToCheck, @Nullable Set keyDimensions) diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index 98579009ec8c..fd0ed2efaa16 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -555,6 +555,8 @@ public static CloseableGrouperIterator makeGrouperIterat } } + // KeyDimensionNames are the column names of dimensions. 
Its required since aggregators are not aware of + // output column names Set keyDimensionNames = dimsToInclude.stream().map(DimensionSpec::getDimension).collect(Collectors.toSet()); for (int i = 0; i < query.getAggregatorSpecs().size(); i++) { AggregatorFactory aggregatorFactory = query.getAggregatorSpecs().get(i); diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java index 5f9285e73371..5a733ce23b7f 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java @@ -128,10 +128,10 @@ public void testFactory_highNumberOfGroupingDimensions() exception.expect(IllegalArgumentException.class); exception.expectMessage(StringUtils.format( "Number of dimensions %d is more than supported %d", - Integer.SIZE, - Integer.SIZE - 1 + Long.SIZE, + Long.SIZE - 1 )); - makeFactory(new String[Integer.SIZE], null); + makeFactory(new String[Long.SIZE], null); } } @@ -150,7 +150,7 @@ public ValueTests(String[] groupings, @Nullable String[] keyDimensions, long val @Parameterized.Parameters public static Collection arguments() { - String[] maxGroupingList = new String[Integer.SIZE - 1]; + String[] maxGroupingList = new String[Long.SIZE - 1]; for (int i = 0; i < maxGroupingList.length; i++) { maxGroupingList[i] = String.valueOf(i); } @@ -161,7 +161,7 @@ public static Collection arguments() {new String[]{"a", "b"}, new String[]{"b"}, 1}, {new String[]{"a", "b"}, new String[]{"a", "b"}, 3}, {new String[]{"b", "a"}, new String[]{"a"}, 1}, - {maxGroupingList, null, Integer.MAX_VALUE} + {maxGroupingList, null, Long.MAX_VALUE} }); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java 
b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java new file mode 100644 index 000000000000..d4f8b02220bf --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantAggregatorTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class LongConstantAggregatorTest +{ + private long randomVal; + private LongConstantAggregator aggregator; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantAggregator(randomVal); + } + + @Test + public void testLong() + { + Assert.assertEquals(randomVal, aggregator.getLong()); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(); + Assert.assertEquals(randomVal, aggregator.getLong()); + } + + @Test + public void testFloat() + { + Assert.assertEquals((float) randomVal, aggregator.getFloat(), 0.0001f); + } + + @Test + public void testGet() + { + Assert.assertEquals(randomVal, aggregator.get()); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java new file mode 100644 index 000000000000..0608fca85b74 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregatorTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class LongConstantBufferAggregatorTest +{ + private long randomVal; + private LongConstantBufferAggregator aggregator; + private ByteBuffer byteBuffer; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantBufferAggregator(randomVal); + byteBuffer = EasyMock.mock(ByteBuffer.class); + EasyMock.replay(byteBuffer); + EasyMock.verifyUnexpectedCalls(byteBuffer); + } + + @Test + public void testLong() + { + Assert.assertEquals(randomVal, aggregator.getLong(byteBuffer, 0)); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(byteBuffer, 0); + Assert.assertEquals(randomVal, aggregator.getLong(byteBuffer, 0)); + } + + @Test + public void testFloat() + { + Assert.assertEquals((float) randomVal, aggregator.getFloat(byteBuffer, 0), 0.0001f); + } + + @Test + public void testGet() + { + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java new file mode 100644 index 000000000000..f62dd0369c61 --- /dev/null +++ 
b/processing/src/test/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregatorTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.aggregation.constant; + +import org.apache.commons.lang.math.RandomUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.ByteBuffer; + +public class LongConstantVectorAggregatorTest +{ + private long randomVal; + private LongConstantVectorAggregator aggregator; + private ByteBuffer byteBuffer; + + @Before + public void setup() + { + randomVal = RandomUtils.nextLong(); + aggregator = new LongConstantVectorAggregator(randomVal); + byteBuffer = EasyMock.mock(ByteBuffer.class); + EasyMock.replay(byteBuffer); + EasyMock.verifyUnexpectedCalls(byteBuffer); + } + + @Test + public void testAggregate() + { + aggregator.aggregate(byteBuffer, 0, 1, 10); + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } + + @Test + public void testAggregateWithIndirection() + { + aggregator.aggregate(byteBuffer, 2, new int[]{2, 3}, null, 0); + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } + + @Test + public void 
testGet() + { + Assert.assertEquals(randomVal, aggregator.get(byteBuffer, 0)); + } +} From b36161f92c3651c028294f3ab96b9153bb9747be Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Tue, 17 Nov 2020 13:32:29 +0530 Subject: [PATCH 3/8] Add calcite tests --- .../GroupingAggregatorFactory.java | 2 +- .../druid/sql/calcite/CalciteQueryTest.java | 142 ++++++++++++++++-- 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index 31dfb8f20184..be8f19f04ca2 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -202,7 +202,7 @@ public byte[] getCacheKey() } /** - * Gives the list of grouping dimensions, return a long value where each bit at position X in the returned value + * Given the list of grouping dimensions, returns a long value where each bit at position X in the returned value * corresponds to the dimension in groupings at same position X. X is the position relative to the right end. if * keyDimensions contain the grouping dimension at position X, the bit is set to 1 at position X, otherwise it is * set to 0. 
An example adapted from Microsoft SQL documentation diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 5e116c7aaa6c..5fd43a5cca1a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -55,6 +55,7 @@ import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; +import org.apache.druid.query.aggregation.GroupingAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; import org.apache.druid.query.aggregation.LongMinAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; @@ -74,6 +75,7 @@ import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory; import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory; import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator; +import org.apache.druid.query.aggregation.post.ExpressionPostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; @@ -123,6 +125,7 @@ import org.junit.runner.RunWith; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -12028,7 +12031,7 @@ public void testGroupingSets() throws Exception cannotVectorize(); testQuery( - "SELECT dim2, gran, SUM(cnt)\n" + "SELECT dim2, gran, SUM(cnt), GROUPING(dim2, gran)\n" + "FROM (SELECT FLOOR(__time TO MONTH) AS gran, COALESCE(dim2, '') dim2, cnt FROM druid.foo) AS x\n" + "GROUP BY GROUPING SETS ( (dim2, gran), (dim2), (gran), () )", ImmutableList.of( @@ 
-12054,7 +12057,10 @@ public void testGroupingSets() throws Exception new DefaultDimensionSpec("v1", "d1", ValueType.LONG) ) ) - .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Arrays.asList("v0", "v1")) + )) .setSubtotalsSpec( ImmutableList.of( ImmutableList.of("d0", "d1"), @@ -12067,17 +12073,127 @@ public void testGroupingSets() throws Exception .build() ), ImmutableList.of( - new Object[]{"", timestamp("2000-01-01"), 2L}, - new Object[]{"", timestamp("2001-01-01"), 1L}, - new Object[]{"a", timestamp("2000-01-01"), 1L}, - new Object[]{"a", timestamp("2001-01-01"), 1L}, - new Object[]{"abc", timestamp("2001-01-01"), 1L}, - new Object[]{"", null, 3L}, - new Object[]{"a", null, 2L}, - new Object[]{"abc", null, 1L}, - new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L}, - new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L}, - new Object[]{NULL_STRING, null, 6L} + new Object[]{"", timestamp("2000-01-01"), 2L, 3L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"a", timestamp("2000-01-01"), 1L, 3L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"abc", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"", null, 3L, 2L}, + new Object[]{"a", null, 2L, 2L}, + new Object[]{"abc", null, 1L, 2L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, null, 6L, 0L} + ) + ); + } + + @Test + public void testGroupingAggregatorDifferentOrder() throws Exception + { + // Cannot vectorize due to virtual columns. 
+ cannotVectorize(); + + testQuery( + "SELECT dim2, gran, SUM(cnt), GROUPING(gran, dim2)\n" + + "FROM (SELECT FLOOR(__time TO MONTH) AS gran, COALESCE(dim2, '') dim2, cnt FROM druid.foo) AS x\n" + + "GROUP BY GROUPING SETS ( (dim2, gran), (dim2), (gran), () )", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim2\"),\"dim2\",'')", + ValueType.STRING + ), + expressionVirtualColumn( + "v1", + "timestamp_floor(\"__time\",'P1M',null,'UTC')", + ValueType.LONG + ) + ) + .setDimensions( + dimensions( + new DefaultDimensionSpec("v0", "d0"), + new DefaultDimensionSpec("v1", "d1", ValueType.LONG) + ) + ) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Arrays.asList("v1", "v0")) + )) + .setSubtotalsSpec( + ImmutableList.of( + ImmutableList.of("d0", "d1"), + ImmutableList.of("d0"), + ImmutableList.of("d1"), + ImmutableList.of() + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", timestamp("2000-01-01"), 2L, 3L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"a", timestamp("2000-01-01"), 1L, 3L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"abc", timestamp("2001-01-01"), 1L, 3L}, + new Object[]{"", null, 3L, 1L}, + new Object[]{"a", null, 2L, 1L}, + new Object[]{"abc", null, 1L, 1L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, null, 6L, 0L} + ) + ); + } + + @Test + public void testGroupingAggregatorWithPostAggregator() throws Exception + { + testQuery( + "SELECT dim2, SUM(cnt), GROUPING(dim2), \n" + + "CASE WHEN GROUPING(dim2) = 0 THEN 'ALL' ELSE 'INDIVIDUAL' END\n" + + "FROM druid.foo\n" + + 
"GROUP BY GROUPING SETS ( (dim2), () )", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + dimensions( + new DefaultDimensionSpec("dim2", "d0", ValueType.STRING) + ) + ) + .setAggregatorSpecs(aggregators( + new LongSumAggregatorFactory("a0", "cnt"), + new GroupingAggregatorFactory("a1", Collections.singletonList("dim2")) + )) + .setSubtotalsSpec( + ImmutableList.of( + ImmutableList.of("d0"), + ImmutableList.of() + ) + ) + .setPostAggregatorSpecs(Collections.singletonList(new ExpressionPostAggregator( + "p0", + "case_searched((\"a1\" == 0),'ALL','INDIVIDUAL')", + null, + ExprMacroTable.nil() + ))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"", 3L, 1L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 0L, "ALL"} ) ); } From 18ec231dc043f845b884ebb7f7d4948bf08b0f37 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Thu, 19 Nov 2020 12:48:33 +0530 Subject: [PATCH 4/8] Fix travis failures --- docs/querying/groupbyquery.md | 2 +- .../druid/sql/calcite/CalciteQueryTest.java | 24 ++++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/querying/groupbyquery.md b/docs/querying/groupbyquery.md index 73ac973b6d54..58d654ae699c 100644 --- a/docs/querying/groupbyquery.md +++ b/docs/querying/groupbyquery.md @@ -227,7 +227,7 @@ The response for the query above would look something like: ``` > Notice that dimensions that are not included in an individual subtotalsSpec grouping are returned with a `null` value. This response format represents a behavior change as of Apache Druid 0.18.0. -> In release 0.17.0 and earlier, such dimensions were entirely excluded from the result. 
If you were relying on this old behaviour to determine whether a particular dimension was not part of +> In release 0.17.0 and earlier, such dimensions were entirely excluded from the result. If you were relying on this old behavior to determine whether a particular dimension was not part of > a subtotal grouping, you can now use [Grouping aggregator](aggregations.md#Grouping Aggregator) instead. diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 5fd43a5cca1a..3a7f17bd9db9 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -12155,6 +12155,23 @@ public void testGroupingAggregatorDifferentOrder() throws Exception @Test public void testGroupingAggregatorWithPostAggregator() throws Exception { + List resultList; + if (NullHandling.sqlCompatible()) { + resultList = ImmutableList.of( + new Object[]{NULL_STRING, 2L, 1L, "INDIVIDUAL"}, + new Object[]{"", 1L, 1L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 0L, "ALL"} + ); + } else { + resultList = ImmutableList.of( + new Object[]{"", 3L, 1L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 0L, "ALL"} + ); + } testQuery( "SELECT dim2, SUM(cnt), GROUPING(dim2), \n" + "CASE WHEN GROUPING(dim2) = 0 THEN 'ALL' ELSE 'INDIVIDUAL' END\n" @@ -12189,12 +12206,7 @@ public void testGroupingAggregatorWithPostAggregator() throws Exception .setContext(QUERY_CONTEXT_DEFAULT) .build() ), - ImmutableList.of( - new Object[]{"", 3L, 1L, "INDIVIDUAL"}, - new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, - new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, - new Object[]{NULL_STRING, 6L, 0L, "ALL"} - ) + resultList ); } From 1ae99fc38e42da00853fd9343b777f95e793df89 Mon Sep 17 
00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Tue, 1 Dec 2020 13:23:12 +0530 Subject: [PATCH 5/8] bit of a change --- docs/querying/aggregations.md | 15 +++-- .../GroupingAggregatorFactory.java | 24 +++---- .../GroupingAggregatorFactoryTest.java | 29 ++++---- .../druid/sql/calcite/CalciteQueryTest.java | 66 +++++++++---------- 4 files changed, 68 insertions(+), 66 deletions(-) diff --git a/docs/querying/aggregations.md b/docs/querying/aggregations.md index bd75913c7594..8fd99824e0d6 100644 --- a/docs/querying/aggregations.md +++ b/docs/querying/aggregations.md @@ -437,13 +437,14 @@ following can be the possible output of the aggregator | subtotal used in query | Output | (bits representation) | |------------------------|--------|-----------------------| -| `["dim1", "dim2"]` | 3 | (11) | -| `["dim1"]` | 2 | (10) | -| `["dim2"]` | 1 | (01) | -| `[]` | 0 | (00) | - -As illustrated in above example, output number can be though of as an unsigned n bit number where n is the number of dimensions passed to the aggregator. -The bit at position X is set in this number to 1 if a dimension at position X in input to aggregators is included in the sub-grouping. +| `["dim1", "dim2"]` | 0 | (00) | +| `["dim1"]` | 1 | (01) | +| `["dim2"]` | 2 | (10) | +| `[]` | 3 | (11) | + +As illustrated in above example, output number can be thought of as an unsigned n bit number where n is the number of dimensions passed to the aggregator. +The bit at position X is set in this number to 0 if a dimension at position X in input to aggregators is included in the sub-grouping. Otherwise, this bit +is set to 1. 
```json { "type" : "grouping", "name" : , "groupings" : [] } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index be8f19f04ca2..a96bc45c9175 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -204,19 +204,19 @@ public byte[] getCacheKey() /** * Given the list of grouping dimensions, returns a long value where each bit at position X in the returned value * corresponds to the dimension in groupings at same position X. X is the position relative to the right end. if - * keyDimensions contain the grouping dimension at position X, the bit is set to 1 at position X, otherwise it is - * set to 0. An example adapted from Microsoft SQL documentation + * keyDimensions contain the grouping dimension at position X, the bit is set to 0 at position X, otherwise it is + * set to 1. 
An example adapted from Microsoft SQL documentation * * groupings keyDimensions value (3 least significant bits) value (long) - * a,b,c [a] 100 4 - * a,b,c [b] 010 2 - * a,b,c [c] 001 1 - * a,b,c [a,b] 110 6 - * a,b,c [a,c] 101 5 - * a,b,c [b,c] 011 3 - * a,b,c [a,b,c] 111 7 - * a,b,c [] 000 0 // None included - * a,b,c 111 7 // All included + * a,b,c [a] 011 3 + * a,b,c [b] 101 5 + * a,b,c [c] 110 6 + * a,b,c [a,b] 001 1 + * a,b,c [a,c] 010 2 + * a,b,c [b,c] 100 4 + * a,b,c [a,b,c] 000 0 + * a,b,c [] 111 7 // None included + * a,b,c 000 0 // All included */ private long groupingId(List groupings, @Nullable Set keyDimensions) { @@ -232,7 +232,7 @@ private long groupingId(List groupings, @Nullable Set keyDimensi long temp = 0L; for (String groupingDimension : groupings) { temp = temp << 1; - if (isDimensionIncluded(groupingDimension, keyDimensions)) { + if (!isDimensionIncluded(groupingDimension, keyDimensions)) { temp = temp | 1L; } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java index 5a733ce23b7f..0117bf1e82a5 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/GroupingAggregatorFactoryTest.java @@ -46,7 +46,7 @@ public static GroupingAggregatorFactory makeFactory(String[] groupings, @Nullabl { GroupingAggregatorFactory factory = new GroupingAggregatorFactory("name", Arrays.asList(groupings)); if (null != keyDims) { - factory = (GroupingAggregatorFactory) factory.withKeyDimensions(Sets.newHashSet(keyDims)); + factory = factory.withKeyDimensions(Sets.newHashSet(keyDims)); } return factory; } @@ -67,7 +67,7 @@ public void testNewAggregator() GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); Aggregator aggregator = factory.factorize(metricFactory); 
Assert.assertEquals(LongConstantAggregator.class, aggregator.getClass()); - Assert.assertEquals(2, aggregator.getLong()); + Assert.assertEquals(1, aggregator.getLong()); } @Test @@ -76,7 +76,7 @@ public void testNewBufferAggregator() GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); BufferAggregator aggregator = factory.factorizeBuffered(metricFactory); Assert.assertEquals(LongConstantBufferAggregator.class, aggregator.getClass()); - Assert.assertEquals(2, aggregator.getLong(null, 0)); + Assert.assertEquals(1, aggregator.getLong(null, 0)); } @Test @@ -86,7 +86,7 @@ public void testNewVectorAggregator() Assert.assertTrue(factory.canVectorize(metricFactory)); VectorAggregator aggregator = factory.factorizeVector(null); Assert.assertEquals(LongConstantVectorAggregator.class, aggregator.getClass()); - Assert.assertEquals(2L, aggregator.get(null, 0)); + Assert.assertEquals(1L, aggregator.get(null, 0)); } @Test @@ -94,10 +94,10 @@ public void testWithKeyDimensions() { GroupingAggregatorFactory factory = makeFactory(new String[]{"a", "b"}, new String[]{"a"}); Aggregator aggregator = factory.factorize(metricFactory); - Assert.assertEquals(2, aggregator.getLong()); - factory = (GroupingAggregatorFactory) factory.withKeyDimensions(Sets.newHashSet("b")); - aggregator = factory.factorize(metricFactory); Assert.assertEquals(1, aggregator.getLong()); + factory = factory.withKeyDimensions(Sets.newHashSet("b")); + aggregator = factory.factorize(metricFactory); + Assert.assertEquals(2, aggregator.getLong()); } } @@ -155,13 +155,14 @@ public static Collection arguments() maxGroupingList[i] = String.valueOf(i); } return Arrays.asList(new Object[][]{ - {new String[]{"a", "b"}, new String[0], 0}, - {new String[]{"a", "b"}, null, 3}, - {new String[]{"a", "b"}, new String[]{"a"}, 2}, - {new String[]{"a", "b"}, new String[]{"b"}, 1}, - {new String[]{"a", "b"}, new String[]{"a", "b"}, 3}, - {new String[]{"b", "a"}, new String[]{"a"}, 1}, - 
{maxGroupingList, null, Long.MAX_VALUE} + {new String[]{"a", "b"}, new String[0], 3}, + {new String[]{"a", "b"}, null, 0}, + {new String[]{"a", "b"}, new String[]{"a"}, 1}, + {new String[]{"a", "b"}, new String[]{"b"}, 2}, + {new String[]{"a", "b"}, new String[]{"a", "b"}, 0}, + {new String[]{"b", "a"}, new String[]{"a"}, 2}, + {maxGroupingList, null, 0}, + {maxGroupingList, new String[0], Long.MAX_VALUE} }); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 3a7f17bd9db9..9b16be36e884 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -12073,17 +12073,17 @@ public void testGroupingSets() throws Exception .build() ), ImmutableList.of( - new Object[]{"", timestamp("2000-01-01"), 2L, 3L}, - new Object[]{"", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"a", timestamp("2000-01-01"), 1L, 3L}, - new Object[]{"a", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"abc", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"", null, 3L, 2L}, - new Object[]{"a", null, 2L, 2L}, - new Object[]{"abc", null, 1L, 2L}, - new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 1L}, - new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 1L}, - new Object[]{NULL_STRING, null, 6L, 0L} + new Object[]{"", timestamp("2000-01-01"), 2L, 0L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2000-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"abc", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"", null, 3L, 1L}, + new Object[]{"a", null, 2L, 1L}, + new Object[]{"abc", null, 1L, 1L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 2L}, + new Object[]{NULL_STRING, null, 6L, 3L} ) ); } @@ -12137,17 +12137,17 @@ public void 
testGroupingAggregatorDifferentOrder() throws Exception .build() ), ImmutableList.of( - new Object[]{"", timestamp("2000-01-01"), 2L, 3L}, - new Object[]{"", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"a", timestamp("2000-01-01"), 1L, 3L}, - new Object[]{"a", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"abc", timestamp("2001-01-01"), 1L, 3L}, - new Object[]{"", null, 3L, 1L}, - new Object[]{"a", null, 2L, 1L}, - new Object[]{"abc", null, 1L, 1L}, - new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 2L}, - new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 2L}, - new Object[]{NULL_STRING, null, 6L, 0L} + new Object[]{"", timestamp("2000-01-01"), 2L, 0L}, + new Object[]{"", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2000-01-01"), 1L, 0L}, + new Object[]{"a", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"abc", timestamp("2001-01-01"), 1L, 0L}, + new Object[]{"", null, 3L, 2L}, + new Object[]{"a", null, 2L, 2L}, + new Object[]{"abc", null, 1L, 2L}, + new Object[]{NULL_STRING, timestamp("2000-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, timestamp("2001-01-01"), 3L, 1L}, + new Object[]{NULL_STRING, null, 6L, 3L} ) ); } @@ -12158,23 +12158,23 @@ public void testGroupingAggregatorWithPostAggregator() throws Exception List resultList; if (NullHandling.sqlCompatible()) { resultList = ImmutableList.of( - new Object[]{NULL_STRING, 2L, 1L, "INDIVIDUAL"}, - new Object[]{"", 1L, 1L, "INDIVIDUAL"}, - new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, - new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, - new Object[]{NULL_STRING, 6L, 0L, "ALL"} + new Object[]{NULL_STRING, 2L, 0L, "INDIVIDUAL"}, + new Object[]{"", 1L, 0L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 0L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 0L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 1L, "ALL"} ); } else { resultList = ImmutableList.of( - new Object[]{"", 3L, 1L, "INDIVIDUAL"}, - new Object[]{"a", 2L, 1L, "INDIVIDUAL"}, - new Object[]{"abc", 1L, 1L, "INDIVIDUAL"}, - new 
Object[]{NULL_STRING, 6L, 0L, "ALL"} + new Object[]{"", 3L, 0L, "INDIVIDUAL"}, + new Object[]{"a", 2L, 0L, "INDIVIDUAL"}, + new Object[]{"abc", 1L, 0L, "INDIVIDUAL"}, + new Object[]{NULL_STRING, 6L, 1L, "ALL"} ); } testQuery( "SELECT dim2, SUM(cnt), GROUPING(dim2), \n" - + "CASE WHEN GROUPING(dim2) = 0 THEN 'ALL' ELSE 'INDIVIDUAL' END\n" + + "CASE WHEN GROUPING(dim2) = 1 THEN 'ALL' ELSE 'INDIVIDUAL' END\n" + "FROM druid.foo\n" + "GROUP BY GROUPING SETS ( (dim2), () )", ImmutableList.of( @@ -12199,7 +12199,7 @@ public void testGroupingAggregatorWithPostAggregator() throws Exception ) .setPostAggregatorSpecs(Collections.singletonList(new ExpressionPostAggregator( "p0", - "case_searched((\"a1\" == 0),'ALL','INDIVIDUAL')", + "case_searched((\"a1\" == 1),'ALL','INDIVIDUAL')", null, ExprMacroTable.nil() ))) From 6dd363203dd13d21657ebf38af0a837df96218e4 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Fri, 4 Dec 2020 17:03:09 +0530 Subject: [PATCH 6/8] Add documentation --- .../GroupingAggregatorFactory.java | 36 ++++++++++++++++++- .../constant/LongConstantAggregator.java | 4 +++ .../LongConstantBufferAggregator.java | 3 ++ .../LongConstantVectorAggregator.java | 3 ++ .../epinephelinae/RowBasedGrouperHelper.java | 15 ++++++-- 5 files changed, 57 insertions(+), 4 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index a96bc45c9175..b55701dd40e3 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -41,6 +41,40 @@ import java.util.Objects; import java.util.Set; +/** + * This class implements {@code grouping_id} function to determine the grouping that a row is part of. 
Different rows + * in same result could have different grouping columns when subtotals are used. + * + * It takes following arguments + * - {@code name} - Name of aggregators + * - {code groupings} - List of dimensions that user is interested in tracking + * - {@code keyDimensions} - The list of grouping dimensions being included in the result row. This list is a subset of + * {@code groupings0}. This argument cannot be passed by the user. It is set by druid engine + * when a particular subtotal spec is being processed. Whenever druid engine processes a new + * subtotal spec, engine sets that subtotal spec as new {@code keyDimensions}. + * + * When key dimensions are updated, {@code value} is updated as well. How the value is determined is captured + * at {@link #groupingId(List, Set)}. + * + * since grouping_id has to be calculated only once, it could have been implemented as a virtual function or + * post-aggregator etc. We modelled it as an aggregation operator so that its output can be used in a post-aggregator. + * Calcite too models grouping_id as an aggregation operator. + * Since it is a non-trivial special aggregation, implementing it required changes in core druid engine to work. There + * were few approaches. We chose the approach that required least changes in core druid. + * Refer to https://github.com/apache/druid/pull/10518#discussion_r532941216 for more details. + * + * Currently, it works in following way + * - On data servers (no change), + * - this factory generates {@link LongConstantAggregator} / {@link LongConstantBufferAggregator} / {@link LongConstantVectorAggregator} + * with keyDimensions as null + * - The aggregators don't actually aggregate anything and their result is not actually used. We could have remove + * these aggregators on data servers but that will result in a signature mismatch on broker and data nodes. That would + * have required extra handling and would have been error-prone. 
+ * - On brokers + * - Results from data node is already re-processed for each subtotal spec. In this path, we also update the + * grouping id for each row. + * + */ @EverythingIsNonnullByDefault public class GroupingAggregatorFactory extends AggregatorFactory { @@ -205,7 +239,7 @@ public byte[] getCacheKey() * Given the list of grouping dimensions, returns a long value where each bit at position X in the returned value * corresponds to the dimension in groupings at same position X. X is the position relative to the right end. if * keyDimensions contain the grouping dimension at position X, the bit is set to 0 at position X, otherwise it is - * set to 1. An example adapted from Microsoft SQL documentation + * set to 1. * * groupings keyDimensions value (3 least significant bits) value (long) * a,b,c [a] 011 3 diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java index 60b6eb3e63b8..1fae2715b1d4 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantAggregator.java @@ -21,6 +21,10 @@ import org.apache.druid.query.aggregation.Aggregator; +/** + * This aggregator is a no-op aggregator with a fixed non-null output value. 
It can be used in scenarios where + * result is constant such as {@link org.apache.druid.query.aggregation.GroupingAggregatorFactory} + */ public class LongConstantAggregator implements Aggregator { private final long value; diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java index bc06a2725ca7..1ddf11b57d7b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantBufferAggregator.java @@ -23,6 +23,9 @@ import java.nio.ByteBuffer; +/** + * {@link BufferAggregator} variant of {@link LongConstantAggregator} + */ public class LongConstantBufferAggregator implements BufferAggregator { private final long value; diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java index f127decad009..4af4b8a9fe77 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/constant/LongConstantVectorAggregator.java @@ -24,6 +24,9 @@ import javax.annotation.Nullable; import java.nio.ByteBuffer; +/** + * {@link VectorAggregator} variant of {@link LongConstantAggregator} + */ public class LongConstantVectorAggregator implements VectorAggregator { private final long value; diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index fd0ed2efaa16..6eb18d1e305c 100644 --- 
a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -555,12 +555,20 @@ public static CloseableGrouperIterator makeGrouperIterat } } - // KeyDimensionNames are the column names of dimensions. Its required since aggregators are not aware of - // output column names - Set keyDimensionNames = dimsToInclude.stream().map(DimensionSpec::getDimension).collect(Collectors.toSet()); + /** + * KeyDimensionNames are the input column names of dimensions. Its required since aggregators are not aware of the + * output column names. + * As we exclude certain dimensions from the result row, the value for any grouping_id aggregators have to change + * to reflect the new grouping dimensions, that aggregation is being done upon. We will mark the indices which have + * grouping aggregators and update the value for each row at those indices. + */ + Set keyDimensionNames = dimsToInclude.stream() + .map(DimensionSpec::getDimension) + .collect(Collectors.toSet()); for (int i = 0; i < query.getAggregatorSpecs().size(); i++) { AggregatorFactory aggregatorFactory = query.getAggregatorSpecs().get(i); if (aggregatorFactory instanceof GroupingAggregatorFactory) { + groupingAggregatorsBitSet.set(i); groupingAggregatorValues[i] = ((GroupingAggregatorFactory) aggregatorFactory) .withKeyDimensions(keyDimensionNames) @@ -595,6 +603,7 @@ public static CloseableGrouperIterator makeGrouperIterat final int resultRowAggregatorStart = query.getResultRowAggregatorStart(); for (int i = 0; i < entry.getValues().length; i++) { if (dimsToInclude != null && groupingAggregatorsBitSet.get(i)) { + // Override with a new value, reflecting the new set of grouping dimensions resultRow.set(resultRowAggregatorStart + i, groupingAggregatorValues[i]); } else { resultRow.set(resultRowAggregatorStart + i, entry.getValues()[i]); From 5fc9a1b5f093baf75eecb4c46d5d24dc5b657bf6 
Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Sun, 6 Dec 2020 15:39:41 +0530 Subject: [PATCH 7/8] Fix typos --- .../aggregation/GroupingAggregatorFactory.java | 18 +++++++++--------- .../epinephelinae/RowBasedGrouperHelper.java | 12 +++++------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index b55701dd40e3..e8a205a8ab84 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -42,14 +42,14 @@ import java.util.Set; /** - * This class implements {@code grouping_id} function to determine the grouping that a row is part of. Different rows - * in same result could have different grouping columns when subtotals are used. + * This class implements {@code grouping_id} function to determine the grouping that a row is part of. Different result rows + * for a query could have different grouping columns when subtotals are used. * - * It takes following arguments + * This aggregator factory takes following arguments * - {@code name} - Name of aggregators - * - {code groupings} - List of dimensions that user is interested in tracking + * - {@code groupings} - List of dimensions that the user is interested in tracking * - {@code keyDimensions} - The list of grouping dimensions being included in the result row. This list is a subset of - * {@code groupings0}. This argument cannot be passed by the user. It is set by druid engine + * {@code groupings}. This argument cannot be passed by the user. It is set by druid engine * when a particular subtotal spec is being processed. 
Whenever druid engine processes a new * subtotal spec, engine sets that subtotal spec as new {@code keyDimensions}. * @@ -67,11 +67,11 @@ * - On data servers (no change), * - this factory generates {@link LongConstantAggregator} / {@link LongConstantBufferAggregator} / {@link LongConstantVectorAggregator} * with keyDimensions as null - * - The aggregators don't actually aggregate anything and their result is not actually used. We could have remove - * these aggregators on data servers but that will result in a signature mismatch on broker and data nodes. That would - * have required extra handling and would have been error-prone. + * - The aggregators don't actually aggregate anything and their result is not actually used. We could have removed + * these aggregators on data servers but that would result in a signature mismatch on broker and data nodes. That requires + * extra handling and is error-prone. * - On brokers - * - Results from data node is already re-processed for each subtotal spec. In this path, we also update the + * - Results from data nodes are already being re-processed for each subtotal spec. We made modifications in this path to update the * grouping id for each row. * */ diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index 6eb18d1e305c..9a87cb59b7a2 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -555,13 +555,11 @@ public static CloseableGrouperIterator makeGrouperIterat } } - /** - * KeyDimensionNames are the input column names of dimensions. Its required since aggregators are not aware of the - * output column names.
- * As we exclude certain dimensions from the result row, the value for any grouping_id aggregators have to change - * to reflect the new grouping dimensions, that aggregation is being done upon. We will mark the indices which have - * grouping aggregators and update the value for each row at those indices. - */ + // KeyDimensionNames are the input column names of dimensions. It's required since aggregators are not aware of the + // output column names. + // As we exclude certain dimensions from the result row, the values for any grouping aggregators have to change + // to reflect the new grouping dimensions that aggregation is being done upon. We will mark the indices which have + // grouping aggregators and update the value for each row at those indices. Set keyDimensionNames = dimsToInclude.stream() .map(DimensionSpec::getDimension) .collect(Collectors.toSet()); From 8678c78d770bce2dda47a4dc15589ead42243bc4 Mon Sep 17 00:00:00 2001 From: Abhishek Agarwal <1477457+abhishekagarwal87@users.noreply.github.com> Date: Mon, 7 Dec 2020 16:18:47 +0530 Subject: [PATCH 8/8] typo fix --- .../druid/query/aggregation/GroupingAggregatorFactory.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java index e8a205a8ab84..62fbb47aa095 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/GroupingAggregatorFactory.java @@ -42,7 +42,7 @@ import java.util.Set; /** - * This class implements {@code grouping_id} function to determine the grouping that a row is part of. Different result rows + * This class implements {@code grouping} function to determine the grouping that a row is part of.
Different result rows * for a query could have different grouping columns when subtotals are used. * * This aggregator factory takes following arguments @@ -56,9 +56,9 @@ * When key dimensions are updated, {@code value} is updated as well. How the value is determined is captured * at {@link #groupingId(List, Set)}. * - * since grouping_id has to be calculated only once, it could have been implemented as a virtual function or + * Since grouping has to be calculated only once, it could have been implemented as a virtual function or * post-aggregator etc. We modelled it as an aggregation operator so that its output can be used in a post-aggregator. - * Calcite too models grouping_id as an aggregation operator. + * Calcite too models the grouping function as an aggregation operator. * Since it is a non-trivial special aggregation, implementing it required changes in core druid engine to work. There * were few approaches. We chose the approach that required least changes in core druid. * Refer to https://github.com/apache/druid/pull/10518#discussion_r532941216 for more details.