From 88cd2aa92c43698f162e31910c4918666b778239 Mon Sep 17 00:00:00 2001 From: kaijianding Date: Thu, 16 Feb 2017 21:29:31 +0800 Subject: [PATCH] average aggregator in both ingestion phase and query phase --- docs/content/querying/aggregations.md | 14 + .../io/druid/jackson/AggregatorsModule.java | 7 + .../query/aggregation/avg/AvgAggregator.java | 109 +++++++ .../avg/AvgAggregatorCollector.java | 193 ++++++++++++ .../aggregation/avg/AvgAggregatorFactory.java | 293 ++++++++++++++++++ .../aggregation/avg/AvgBufferAggregator.java | 143 +++++++++ .../druid/query/aggregation/avg/AvgSerde.java | 121 ++++++++ .../segment/QueryableIndexStorageAdapter.java | 23 ++ .../avg/AvgAggregatorCollectorTest.java | 162 ++++++++++ .../aggregation/avg/AvgAggregatorTest.java | 146 +++++++++ .../aggregation/avg/AvgGroupByQueryTest.java | 163 ++++++++++ .../query/aggregation/avg/AvgSerdeTest.java | 46 +++ .../query/aggregation/avg/AvgTestHelper.java | 91 ++++++ .../avg/AvgTimeseriesQueryTest.java | 118 +++++++ .../aggregation/avg/AvgTopNQueryTest.java | 149 +++++++++ 15 files changed, 1778 insertions(+) create mode 100644 processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregator.java create mode 100644 processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorCollector.java create mode 100644 processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorFactory.java create mode 100644 processing/src/main/java/io/druid/query/aggregation/avg/AvgBufferAggregator.java create mode 100644 processing/src/main/java/io/druid/query/aggregation/avg/AvgSerde.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorCollectorTest.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorTest.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgGroupByQueryTest.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgSerdeTest.java create mode 100644 
processing/src/test/java/io/druid/query/aggregation/avg/AvgTestHelper.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgTimeseriesQueryTest.java create mode 100644 processing/src/test/java/io/druid/query/aggregation/avg/AvgTopNQueryTest.java diff --git a/docs/content/querying/aggregations.md b/docs/content/querying/aggregations.md index 77aec92786f9..ed36a8bbc2bc 100644 --- a/docs/content/querying/aggregations.md +++ b/docs/content/querying/aggregations.md @@ -76,6 +76,20 @@ Computes the sum of values as 64-bit floating point value. Similar to `longSum` { "type" : "longMax", "name" : <output_name>, "fieldName" : <metric_name> } ``` +### Average aggregator + +Computes the average of values as a 64-bit floating point value. + +```json +{ "type" : "avg", "name" : <output_name>, "fieldName" : <metric_name>, "inputType" : <input_type> } +``` + +* `name` – output name for the averaged value +* `fieldName` – name of the metric column to average over +* `inputType` – type of the `fieldName` column, one of `float`/`long`/`avg`; defaults to `float` + +If the `fieldName` column was already aggregated with the `avg` aggregator at ingestion time, set `inputType` to `avg` at query time. + ### First / Last aggregator First and Last aggregator cannot be used in ingestion spec, and should only be specified as part of queries. 
diff --git a/processing/src/main/java/io/druid/jackson/AggregatorsModule.java b/processing/src/main/java/io/druid/jackson/AggregatorsModule.java index 764ee6fcc19f..229da712bec0 100644 --- a/processing/src/main/java/io/druid/jackson/AggregatorsModule.java +++ b/processing/src/main/java/io/druid/jackson/AggregatorsModule.java @@ -35,6 +35,8 @@ import io.druid.query.aggregation.LongMinAggregatorFactory; import io.druid.query.aggregation.LongSumAggregatorFactory; import io.druid.query.aggregation.PostAggregator; +import io.druid.query.aggregation.avg.AvgAggregatorFactory; +import io.druid.query.aggregation.avg.AvgSerde; import io.druid.query.aggregation.cardinality.CardinalityAggregatorFactory; import io.druid.query.aggregation.first.DoubleFirstAggregatorFactory; import io.druid.query.aggregation.first.LongFirstAggregatorFactory; @@ -71,6 +73,10 @@ public AggregatorsModule() ComplexMetrics.registerSerde("preComputedHyperUnique", new PreComputedHyperUniquesSerde(HyperLogLogHash.getDefault())); } + if (ComplexMetrics.getSerdeForType("avg") == null) { + ComplexMetrics.registerSerde("avg", new AvgSerde()); + } + setMixInAnnotation(AggregatorFactory.class, AggregatorFactoryMixin.class); setMixInAnnotation(PostAggregator.class, PostAggregatorMixin.class); } @@ -84,6 +90,7 @@ public AggregatorsModule() @JsonSubTypes.Type(name = "doubleMin", value = DoubleMinAggregatorFactory.class), @JsonSubTypes.Type(name = "longMax", value = LongMaxAggregatorFactory.class), @JsonSubTypes.Type(name = "longMin", value = LongMinAggregatorFactory.class), + @JsonSubTypes.Type(name = "avg", value = AvgAggregatorFactory.class), @JsonSubTypes.Type(name = "javascript", value = JavaScriptAggregatorFactory.class), @JsonSubTypes.Type(name = "histogram", value = HistogramAggregatorFactory.class), @JsonSubTypes.Type(name = "hyperUnique", value = HyperUniquesAggregatorFactory.class), diff --git a/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregator.java 
b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregator.java new file mode 100644 index 000000000000..3f548b97eb4e --- /dev/null +++ b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregator.java @@ -0,0 +1,109 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package io.druid.query.aggregation.avg; + +import io.druid.query.aggregation.Aggregator; +import io.druid.segment.FloatColumnSelector; +import io.druid.segment.LongColumnSelector; +import io.druid.segment.ObjectColumnSelector; + +/** Aggregator that accumulates a running count and sum in an {@link AvgAggregatorCollector}; the nested subclasses differ only in the selector type each row is read from. + */ +public abstract class AvgAggregator implements Aggregator +{ + protected final AvgAggregatorCollector holder = new AvgAggregatorCollector(); + + @Override + public void reset() + { + holder.reset(); + } + + @Override + public Object get() + { + return holder; /* the live collector instance itself, not a copy */ + } + + @Override + public void close() + { + } + + @Override + public float getFloat() + { + return (float) holder.compute(); /* the average is computed in double precision, then narrowed */ + } + + @Override + public long getLong() + { + return (long) holder.compute(); /* the average is computed in double precision, then truncated */ + } + + public static final class FloatAvgAggregator extends AvgAggregator + { + private final FloatColumnSelector selector; + + public FloatAvgAggregator(FloatColumnSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate() + { + holder.add(selector.get()); + } + } + + public static final class LongAvgAggregator extends AvgAggregator + { + private final LongColumnSelector selector; + + public LongAvgAggregator(LongColumnSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate() + { + holder.add(selector.get()); + } + } + + public static final class ObjectAvgAggregator extends AvgAggregator + { + private final ObjectColumnSelector selector; + + public ObjectAvgAggregator(ObjectColumnSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate() + { + AvgAggregatorCollector.combineValues(holder, selector.get()); /* the selected object is passed as the right-hand collector and merged into holder */ + } + } +} diff --git a/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorCollector.java b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorCollector.java new file mode 100644 index 000000000000..5f4982ade66c --- /dev/null +++ b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorCollector.java @@ 
-0,0 +1,193 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.aggregation.avg; + +import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.primitives.Doubles; +import com.google.common.primitives.Longs; + +import java.nio.ByteBuffer; +import java.util.Comparator; +/** Mutable accumulator of a value count and a value sum; its byte form is the long count followed by the double sum. */ +public class AvgAggregatorCollector +{ + public static AvgAggregatorCollector from(ByteBuffer buffer) + { + return new AvgAggregatorCollector(buffer.getLong(), buffer.getDouble()); /* relative reads: advances the buffer's position */ + } + /** Orders collectors by count first and by sum second; the average itself is not compared. */ + public static final Comparator COMPARATOR = new Comparator() + { + @Override + public int compare(AvgAggregatorCollector o1, AvgAggregatorCollector o2) + { + int compare = Longs.compare(o1.count, o2.count); + if (compare != 0) { + return compare; + } + return Doubles.compare(o1.sum, o2.sum); + } + }; + /** Folds the count and sum of {@code rhs} into {@code lhs} and returns {@code lhs}; {@code lhs} is mutated, {@code rhs} is only read. */ + static Object combineValues(Object lhs, Object rhs) + { + final AvgAggregatorCollector holder1 = (AvgAggregatorCollector) lhs; + final AvgAggregatorCollector holder2 = (AvgAggregatorCollector) rhs; + + if (holder2.count == 0) { + return holder1; + } + if (holder1.count == 0) { + holder1.count = holder2.count; + holder1.sum = holder2.sum; + return holder1; 
+ } + + holder1.count += holder2.count; + holder1.sum += holder2.sum; + + return holder1; + } + + static int getMaxIntermediateSize() + { + return Longs.BYTES + Doubles.BYTES; + } + + long count; // number of elements + double sum; // sum of elements + + public AvgAggregatorCollector() + { + this(0, 0); + } + + public void reset() + { + count = 0; + sum = 0; + } + + public AvgAggregatorCollector(long count, double sum) + { + this.count = count; + this.sum = sum; + } + + public AvgAggregatorCollector add(float v) + { + count++; + sum += v; + return this; + } + + public AvgAggregatorCollector add(long v) + { + count++; + sum += v; + return this; + } + /** Returns {@code sum / count}; throws IllegalStateException when no value has been added. */ + public double compute() + { + if (count == 0) { + throw new IllegalStateException("should not be empty holder"); + } + return sum / count; + } + + @JsonValue + public byte[] toByteArray() + { + final ByteBuffer buffer = toByteBuffer(); + buffer.flip(); + byte[] theBytes = new byte[buffer.remaining()]; + buffer.get(theBytes); + + return theBytes; + } + /** Returns a newly allocated buffer whose position is left after the written count and sum; flip() it before reading. */ + public ByteBuffer toByteBuffer() + { + return ByteBuffer.allocate(Longs.BYTES + Doubles.BYTES) + .putLong(count) + .putDouble(sum); + } + + @VisibleForTesting + boolean equalsWithEpsilon(AvgAggregatorCollector o, double epsilon) + { + if (this == o) { + return true; + } + + if (count != o.count) { + return false; + } + if (Math.abs(sum - o.sum) > epsilon) { + return false; + } + + return true; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AvgAggregatorCollector that = (AvgAggregatorCollector) o; + + if (count != that.count) { + return false; + } + if (Double.compare(that.sum, sum) != 0) { + return false; + } + + return true; + } + + @Override + public int hashCode() + { + int result; + long temp; + result = (int) (count ^ (count >>> 32)); + temp = Double.doubleToLongBits(sum); + result = 31 * result + (int) (temp ^ (temp >>> 32)); + return result; + 
} + + @Override + public String toString() + { + return "AvgAggregatorCollector{" + + "count=" + count + + ", sum=" + sum + + '}'; + } +} diff --git a/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorFactory.java b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorFactory.java new file mode 100644 index 000000000000..99d6fe621fed --- /dev/null +++ b/processing/src/main/java/io/druid/query/aggregation/avg/AvgAggregatorFactory.java @@ -0,0 +1,293 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package io.druid.query.aggregation.avg; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeName; +import com.google.common.base.Preconditions; +import com.metamx.common.IAE; +import io.druid.common.utils.StringUtils; +import io.druid.query.aggregation.Aggregator; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.AggregatorFactoryNotMergeableException; +import io.druid.query.aggregation.Aggregators; +import io.druid.query.aggregation.BufferAggregator; +import io.druid.segment.ColumnSelectorFactory; +import io.druid.segment.FloatColumnSelector; +import io.druid.segment.LongColumnSelector; +import io.druid.segment.ObjectColumnSelector; +import org.apache.commons.codec.binary.Base64; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Factory for the {@code avg} aggregator. {@code inputType} selects how {@code fieldName} is read: {@code float} (the default when null), {@code long}, or {@code avg} for a column of {@link AvgAggregatorCollector} objects; matching is case-insensitive. + */ +@JsonTypeName("avg") +public class AvgAggregatorFactory extends AggregatorFactory +{ + public static FloatColumnSelector asFloatColumnSelector(final ObjectColumnSelector selector) + { + return new FloatColumnSelector() + { + @Override + public float get() + { + return (float) ((AvgAggregatorCollector) selector.get()).compute(); + } + }; + } + + public static LongColumnSelector asLongColumnSelector(final ObjectColumnSelector selector) + { + return new LongColumnSelector() + { + @Override + public long get() + { + return (long) ((AvgAggregatorCollector) selector.get()).compute(); + } + }; + } + + protected static final byte CACHE_TYPE_ID = 22; + + private final String name; + private final String fieldName; + private final String inputType; + + @JsonCreator + public AvgAggregatorFactory( + @JsonProperty("name") String name, + @JsonProperty("fieldName") String fieldName, + @JsonProperty("inputType") String inputType + ) + { + Preconditions.checkNotNull(name, "Must have a valid, non-null 
aggregator name"); + Preconditions.checkNotNull(fieldName, "Must have a valid, non-null fieldName"); + + this.name = name; + this.fieldName = fieldName; + this.inputType = inputType == null ? "float" : inputType; /* not validated here; an unknown value is rejected in factorize/factorizeBuffered */ + } + + public AvgAggregatorFactory(String name, String fieldName) + { + this(name, fieldName, null); + } + + @Override + public String getTypeName() + { + return "avg"; + } + + @Override + public int getMaxIntermediateSize() + { + return AvgAggregatorCollector.getMaxIntermediateSize(); + } + + @Override + public Aggregator factorize(ColumnSelectorFactory metricFactory) + { + ObjectColumnSelector selector = metricFactory.makeObjectColumnSelector(fieldName); + if (selector == null) { + return Aggregators.noopAggregator(); /* returned before inputType is checked */ + } + if ("float".equalsIgnoreCase(inputType)) { + return new AvgAggregator.FloatAvgAggregator( + metricFactory.makeFloatColumnSelector(fieldName) + ); + } else if ("long".equalsIgnoreCase(inputType)) { + return new AvgAggregator.LongAvgAggregator( + metricFactory.makeLongColumnSelector(fieldName) + ); + } else if ("avg".equalsIgnoreCase(inputType)) { + return new AvgAggregator.ObjectAvgAggregator(selector); + } + + throw new IAE( + "Incompatible type for metric[%s], expected a float, long or avg, got a %s", fieldName, inputType + ); + } + + @Override + public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) + { + ObjectColumnSelector selector = metricFactory.makeObjectColumnSelector(fieldName); + if (selector == null) { + return Aggregators.noopBufferAggregator(); /* returned before inputType is checked */ + } + + if ("float".equalsIgnoreCase(inputType)) { + return new AvgBufferAggregator.FloatAvgBufferAggregator( + metricFactory.makeFloatColumnSelector(fieldName) + ); + } else if ("long".equalsIgnoreCase(inputType)) { + return new AvgBufferAggregator.LongAvgBufferAggregator( + metricFactory.makeLongColumnSelector(fieldName) + ); + } else if ("avg".equalsIgnoreCase(inputType)) { + return new AvgBufferAggregator.ObjectAvgBufferAggregator(selector); + } + + 
throw new IAE( + "Incompatible type for metric[%s], expected a float, long or avg, got a %s", fieldName, inputType + ); + } + + @Override + public AggregatorFactory getCombiningFactory() + { + return new AvgAggregatorFactory(name, name, getTypeName()); /* reads column `name` with inputType "avg", i.e. as collector objects */ + } + + @Override + public List getRequiredColumns() + { + return Arrays.asList(new AvgAggregatorFactory(fieldName, fieldName, inputType)); + } + + @Override + public AggregatorFactory getMergingFactory(AggregatorFactory other) throws AggregatorFactoryNotMergeableException + { + if (Objects.equals(getName(), other.getName()) && this.getClass() == other.getClass()) { + return getCombiningFactory(); + } else { + throw new AggregatorFactoryNotMergeableException(this, other); + } + } + + @Override + public Comparator getComparator() + { + return AvgAggregatorCollector.COMPARATOR; + } + + @Override + public Object combine(Object lhs, Object rhs) + { + return AvgAggregatorCollector.combineValues(lhs, rhs); + } + + @Override + public Object finalizeComputation(Object object) + { + return ((AvgAggregatorCollector) object).compute(); + } + + @Override + public Object deserialize(Object object) + { + if (object instanceof byte[]) { + return AvgAggregatorCollector.from(ByteBuffer.wrap((byte[]) object)); + } else if (object instanceof ByteBuffer) { + return AvgAggregatorCollector.from((ByteBuffer) object); + } else if (object instanceof String) { + return AvgAggregatorCollector.from( + ByteBuffer.wrap(Base64.decodeBase64(StringUtils.toUtf8((String) object))) + ); + } + return object; /* any other type is handed back unchanged */ + } + + @JsonProperty + public String getFieldName() + { + return fieldName; + } + + @Override + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public String getInputType() + { + return inputType; + } + + @Override + public List requiredFields() + { + return Arrays.asList(fieldName); + } + + @Override + public byte[] getCacheKey() + { + byte[] fieldNameBytes = com.metamx.common.StringUtils.toUtf8(fieldName); + 
byte[] inputTypeBytes = com.metamx.common.StringUtils.toUtf8(inputType); + + return ByteBuffer.allocate(1 + fieldNameBytes.length + inputTypeBytes.length) + .put(CACHE_TYPE_ID) + .put(fieldNameBytes) + .put(inputTypeBytes) + .array(); /* key is type id + fieldName + inputType; name is not part of it */ + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AvgAggregatorFactory that = (AvgAggregatorFactory) o; + + if (name != null ? !name.equals(that.name) : that.name != null) { + return false; + } + if (fieldName != null ? !fieldName.equals(that.fieldName) : that.fieldName != null) { + return false; + } + return inputType != null ? inputType.equals(that.inputType) : that.inputType == null; + } + + @Override + public int hashCode() + { + int result = name != null ? name.hashCode() : 0; + result = 31 * result + (fieldName != null ? fieldName.hashCode() : 0); + result = 31 * result + (inputType != null ? inputType.hashCode() : 0); + return result; + } + + @Override + public String toString() + { + return "AvgAggregatorFactory{" + + "name='" + name + '\'' + + ", fieldName='" + fieldName + '\'' + + ", inputType='" + inputType + '\'' + + '}'; + } +} diff --git a/processing/src/main/java/io/druid/query/aggregation/avg/AvgBufferAggregator.java b/processing/src/main/java/io/druid/query/aggregation/avg/AvgBufferAggregator.java new file mode 100644 index 000000000000..9bf07d688bae --- /dev/null +++ b/processing/src/main/java/io/druid/query/aggregation/avg/AvgBufferAggregator.java @@ -0,0 +1,143 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.aggregation.avg; + +import com.google.common.primitives.Longs; +import io.druid.query.aggregation.BufferAggregator; +import io.druid.segment.FloatColumnSelector; +import io.druid.segment.LongColumnSelector; +import io.druid.segment.ObjectColumnSelector; + +import java.nio.ByteBuffer; + +/** Buffer-backed average aggregator. Each slot holds a long count at COUNT_OFFSET and a double sum at SUM_OFFSET, both relative to the slot position; the nested subclasses differ only in the selector type each row is read from. + */ +public abstract class AvgBufferAggregator implements BufferAggregator +{ + private static final int COUNT_OFFSET = 0; + private static final int SUM_OFFSET = Longs.BYTES; + + @Override + public void init(final ByteBuffer buf, final int position) + { + buf.putLong(position + COUNT_OFFSET, 0).putDouble(position + SUM_OFFSET, 0); + } + + @Override + public Object get(final ByteBuffer buf, final int position) + { + AvgAggregatorCollector holder = new AvgAggregatorCollector(); /* a fresh snapshot of the slot on every call */ + holder.count = buf.getLong(position + COUNT_OFFSET); + holder.sum = buf.getDouble(position + SUM_OFFSET); + return holder; + } + + @Override + public float getFloat(ByteBuffer buf, int position) + { + long count = buf.getLong(position + COUNT_OFFSET); + double sum = buf.getDouble(position + SUM_OFFSET); + return (float) (sum / count); /* divide in double precision and narrow the quotient, as AvgAggregator#getFloat does; NaN for a freshly initialised slot */ + } + + @Override + public long getLong(ByteBuffer buf, int position) + { + long count = buf.getLong(position + COUNT_OFFSET); + double sum = buf.getDouble(position + SUM_OFFSET); + return (long) (sum / count); /* truncate the quotient, not the sum, as AvgAggregator#getLong does; 0 for a freshly initialised slot instead of an integer division by zero */ + } + + @Override + public void close() + { + } + + public static final class FloatAvgBufferAggregator extends AvgBufferAggregator + { + private final FloatColumnSelector selector; + + public FloatAvgBufferAggregator(FloatColumnSelector 
selector) + { + this.selector = selector; + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + float v = selector.get(); + long count = buf.getLong(position + COUNT_OFFSET) + 1; + double sum = buf.getDouble(position + SUM_OFFSET) + v; + buf.putLong(position + COUNT_OFFSET, count); + buf.putDouble(position + SUM_OFFSET, sum); + } + } + + public static final class LongAvgBufferAggregator extends AvgBufferAggregator + { + private final LongColumnSelector selector; + + public LongAvgBufferAggregator(LongColumnSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + long v = selector.get(); + long count = buf.getLong(position + COUNT_OFFSET) + 1; + double sum = buf.getDouble(position + SUM_OFFSET) + v; + buf.putLong(position + COUNT_OFFSET, count); + buf.putDouble(position + SUM_OFFSET, sum); + } + } + + public static final class ObjectAvgBufferAggregator extends AvgBufferAggregator + { + private final ObjectColumnSelector selector; + + public ObjectAvgBufferAggregator(ObjectColumnSelector selector) + { + this.selector = selector; + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + AvgAggregatorCollector holder2 = (AvgAggregatorCollector) selector.get(); + + long count = buf.getLong(position + COUNT_OFFSET); + if (count == 0) { /* empty slot: copy the incoming collector instead of adding to it */ + buf.putLong(position + COUNT_OFFSET, holder2.count); + buf.putDouble(position + SUM_OFFSET, holder2.sum); + return; + } + + double sum = buf.getDouble(position + SUM_OFFSET); + + count += holder2.count; + sum += holder2.sum; + + buf.putLong(position + COUNT_OFFSET, count); + buf.putDouble(position + SUM_OFFSET, sum); + } + } +} diff --git a/processing/src/main/java/io/druid/query/aggregation/avg/AvgSerde.java b/processing/src/main/java/io/druid/query/aggregation/avg/AvgSerde.java new file mode 100644 index 000000000000..5e0ae3803f1f --- /dev/null +++ b/processing/src/main/java/io/druid/query/aggregation/avg/AvgSerde.java @@ -0,0 +1,121 @@ +/* + * Licensed to Metamarkets Group Inc. 
(Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.aggregation.avg; + +import com.google.common.collect.Ordering; +import io.druid.data.input.InputRow; +import io.druid.segment.column.ColumnBuilder; +import io.druid.segment.data.GenericIndexed; +import io.druid.segment.data.ObjectStrategy; +import io.druid.segment.serde.ComplexColumnPartSupplier; +import io.druid.segment.serde.ComplexMetricExtractor; +import io.druid.segment.serde.ComplexMetricSerde; + +import java.nio.ByteBuffer; +import java.util.List; + +/** Complex-metric serde for the {@code avg} type; the values it extracts, reads and writes are {@link AvgAggregatorCollector} objects. + */ +public class AvgSerde extends ComplexMetricSerde +{ + private static final Ordering comparator = + Ordering.from(AvgAggregatorCollector.COMPARATOR).nullsFirst(); + + @Override + public String getTypeName() + { + return "avg"; + } + + @Override + public ComplexMetricExtractor getExtractor() + { + return new ComplexMetricExtractor() + { + @Override + public Class extractedClass() + { + return AvgAggregatorCollector.class; + } + + @Override + public AvgAggregatorCollector extractValue(InputRow inputRow, String metricName) + { + Object rawValue = inputRow.getRaw(metricName); + + if (rawValue instanceof AvgAggregatorCollector) { + return (AvgAggregatorCollector) rawValue; /* already a collector: passed through as-is */ + } + AvgAggregatorCollector collector = new 
AvgAggregatorCollector(); + + List dimValues = inputRow.getDimension(metricName); + if (dimValues != null && dimValues.size() > 0) { + for (String dimValue : dimValues) { + float value = Float.parseFloat(dimValue); /* every value is parsed as a float, whatever the raw type was */ + collector.add(value); + } + } + return collector; /* an empty collector when the row has no value for metricName */ + } + }; + } + + @Override + public void deserializeColumn( + ByteBuffer byteBuffer, ColumnBuilder columnBuilder + ) + { + final GenericIndexed column = GenericIndexed.read(byteBuffer, getObjectStrategy()); + columnBuilder.setComplexColumn(new ComplexColumnPartSupplier(getTypeName(), column)); + } + + @Override + public ObjectStrategy getObjectStrategy() + { + return new ObjectStrategy() + { + @Override + public Class getClazz() + { + return AvgAggregatorCollector.class; + } + + @Override + public AvgAggregatorCollector fromByteBuffer(ByteBuffer buffer, int numBytes) + { + final ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); /* independent position/limit, so the caller's buffer is not advanced */ + readOnlyBuffer.limit(readOnlyBuffer.position() + numBytes); + return AvgAggregatorCollector.from(readOnlyBuffer); + } + + @Override + public byte[] toBytes(AvgAggregatorCollector collector) + { + return collector == null ? 
new byte[]{} : collector.toByteArray(); /* a null collector serializes to zero bytes */ + } + + @Override + public int compare(AvgAggregatorCollector o1, AvgAggregatorCollector o2) + { + return comparator.compare(o1, o2); + } + }; + } +} diff --git a/processing/src/main/java/io/druid/segment/QueryableIndexStorageAdapter.java b/processing/src/main/java/io/druid/segment/QueryableIndexStorageAdapter.java index dadf2539b5c4..c116cfd0f4b0 100644 --- a/processing/src/main/java/io/druid/segment/QueryableIndexStorageAdapter.java +++ b/processing/src/main/java/io/druid/segment/QueryableIndexStorageAdapter.java @@ -32,6 +32,7 @@ import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequences; import io.druid.query.QueryInterruptedException; +import io.druid.query.aggregation.avg.AvgAggregatorFactory; import io.druid.query.dimension.DimensionSpec; import io.druid.query.extraction.ExtractionFn; import io.druid.query.filter.BooleanFilter; @@ -677,6 +678,17 @@ public FloatColumnSelector makeFloatColumnSelector(String columnName) } } + if (cachedMetricVals == null) { + Column holder = index.getColumn(columnName); + if (holder!= null) { + boolean isAvg = holder.getCapabilities().getType() == ValueType.COMPLEX + && "avg".equals(holder.getComplexColumn().getTypeName()); + if (isAvg) { + final ObjectColumnSelector objectColumnSelector = makeObjectColumnSelector(columnName); + return AvgAggregatorFactory.asFloatColumnSelector(objectColumnSelector); + } + } + } if (cachedMetricVals == null) { return ZeroFloatColumnSelector.instance(); } @@ -711,6 +723,17 @@ public LongColumnSelector makeLongColumnSelector(String columnName) } } + if (cachedMetricVals == null) { + Column holder = index.getColumn(columnName); + if (holder != null) { + boolean isAvg = holder.getCapabilities().getType() == ValueType.COMPLEX + && "avg".equals(holder.getComplexColumn().getTypeName()); + if (isAvg) { + final ObjectColumnSelector objectColumnSelector = makeObjectColumnSelector(columnName); + return 
AvgAggregatorFactory.asLongColumnSelector(objectColumnSelector); + } + } + } if (cachedMetricVals == null) { return ZeroLongColumnSelector.instance(); } diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorCollectorTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorCollectorTest.java new file mode 100644 index 000000000000..6e053b9c478b --- /dev/null +++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorCollectorTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package io.druid.query.aggregation.avg;
+
+import com.google.common.collect.Lists;
+import com.metamx.common.Pair;
+import io.druid.segment.FloatColumnSelector;
+import io.druid.segment.ObjectColumnSelector;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Checks that {@link AvgAggregatorCollector} produces the arithmetic mean of its inputs, both when
+ * fed directly and when the inputs are spread over several partial collectors that are merged
+ * afterwards (on-heap via {@code combineValues}, off-heap via {@link AvgBufferAggregator}).
+ */
+public class AvgAggregatorCollectorTest
+{
+  private static final float[] market_upfront = new float[]{
+      800.0f, 800.0f, 826.0602f, 1564.6177f, 1006.4021f, 869.64374f, 809.04175f, 1458.4027f, 852.4375f, 879.9881f,
+      950.1468f, 712.7746f, 846.2675f, 682.8855f, 1109.875f, 594.3817f, 870.1159f, 677.511f, 1410.2781f, 1219.4321f,
+      979.306f, 1224.5016f, 1215.5898f, 716.6092f, 1301.0233f, 786.3633f, 989.9315f, 1609.0967f, 1023.2952f, 1367.6381f,
+      1627.598f, 810.8894f, 1685.5001f, 545.9906f, 1870.061f, 555.476f, 1643.3408f, 943.4972f, 1667.4978f, 913.5611f,
+      1218.5619f, 1273.7074f, 888.70526f, 1113.1141f, 864.5689f, 1308.582f, 785.07886f, 1363.6149f, 787.1253f,
+      826.0392f, 1107.2438f, 872.6257f, 1188.3693f, 911.9568f, 794.0988f, 1299.0933f, 1212.9283f, 901.3273f, 723.5143f,
+      1061.9734f, 602.97955f, 879.4061f, 724.2625f, 862.93134f, 1133.1351f, 948.65796f, 807.6017f, 914.525f, 1553.3485f,
+      1208.4567f, 679.6193f, 645.1777f, 1120.0887f, 1649.5333f, 1433.3988f, 1598.1793f, 1192.5631f, 1022.85455f,
+      1228.5024f, 1298.4158f, 1345.9644f, 1291.898f, 1306.4957f, 1287.7667f, 1631.5844f, 578.79596f, 1017.5732f,
+      782.0135f, 829.91626f, 1862.7379f, 873.3065f, 1427.0167f, 1430.2573f, 1101.9182f, 1166.1411f, 1004.94086f,
+      740.1837f, 865.7779f, 901.30756f, 691.9589f, 1674.3317f, 975.57794f, 1360.6948f, 755.89935f, 771.34845f,
+      869.30835f, 1095.6376f, 906.3738f, 988.8938f, 835.76263f, 776.70294f, 875.6834f, 1070.8363f, 835.46124f,
+      715.5161f, 755.64655f, 771.1005f, 764.50806f, 736.40924f, 884.8373f, 918.72284f, 893.98505f, 832.8749f,
+      850.995f, 767.9733f, 848.3399f, 878.6838f, 906.1019f, 1403.8302f, 936.4296f, 846.2884f, 856.4901f, 1032.2576f,
+      954.7542f, 1031.99f, 907.02155f, 1110.789f, 843.95215f, 1362.6506f, 884.8015f, 1684.2688f, 873.65204f, 855.7177f,
+      996.56415f, 1061.6786f, 962.2358f, 1019.8985f, 1056.4193f, 1198.7231f, 1108.1361f, 1289.0095f,
+      1069.4318f, 1001.13403f, 1030.4995f, 1734.2749f, 1063.2012f, 1447.3412f, 1234.2476f, 1144.3424f, 1049.7385f,
+      811.9913f, 768.4231f, 1151.0692f, 877.0794f, 1146.4231f, 902.6157f, 1355.8434f, 897.39343f, 1260.1431f, 762.8625f,
+      935.168f, 782.10785f, 996.2054f, 767.69214f, 1031.7415f, 775.9656f, 1374.9684f, 853.163f, 1456.6118f, 811.92523f,
+      989.0328f, 744.7446f, 1166.4012f, 753.105f, 962.7312f, 780.272f
+  };
+
+  private static final float[] market_total_market = new float[]{
+      1000.0f, 1000.0f, 1040.9456f, 1689.0128f, 1049.142f, 1073.4766f, 1007.36554f, 1545.7089f, 1016.9652f, 1077.6127f,
+      1075.0896f, 953.9954f, 1022.7833f, 937.06195f, 1156.7448f, 849.8775f, 1066.208f, 904.34064f, 1240.5255f,
+      1343.2325f, 1088.9431f, 1349.2544f, 1102.8667f, 939.2441f, 1109.8754f, 997.99457f, 1037.4495f, 1686.4197f,
+      1074.007f, 1486.2013f, 1300.3022f, 1021.3345f, 1314.6195f, 792.32605f, 1233.4489f, 805.9301f, 1184.9207f,
+      1127.231f, 1203.4656f, 1100.9048f, 1097.2112f, 1410.793f, 1033.4012f, 1283.166f, 1025.6333f, 1331.861f,
+      1039.5005f, 1332.4684f, 1011.20544f, 1029.9952f, 1047.2129f, 1057.08f, 1064.9727f, 1082.7277f, 971.0508f,
+      1320.6383f, 1070.1655f, 1089.6478f, 980.3866f, 1179.6959f, 959.2362f, 1092.417f, 987.0674f, 1103.4583f,
+      1091.2231f, 1199.6074f, 1044.3843f, 1183.2408f, 1289.0973f, 1360.0325f, 993.59125f, 1021.07117f, 1105.3834f,
+      1601.8295f, 1200.5272f, 1600.7233f, 1317.4584f, 1304.3262f, 1544.1082f, 1488.7378f, 1224.8271f, 1421.6487f,
+      1251.9062f, 1414.619f, 1350.1754f, 970.7283f, 1057.4272f, 1073.9673f, 996.4337f, 1743.9218f, 1044.5629f,
+      1474.5911f, 1159.2788f, 1292.5428f, 1124.2014f, 1243.354f, 1051.809f, 1143.0784f, 1097.4907f, 1010.3703f,
+      1326.8291f, 1179.8038f, 1281.6012f, 994.73126f, 1081.6504f, 1103.2397f, 1177.8584f, 1152.5477f, 1117.954f,
+      1084.3325f, 1029.8025f, 1121.3854f, 1244.85f, 1077.2794f, 1098.5432f, 998.65076f, 1088.8076f, 1008.74554f,
+      998.75397f, 1129.7233f, 1075.243f, 1141.5884f, 1037.3811f, 1099.1973f, 981.5773f, 1092.942f, 1072.2394f,
+      1154.4156f, 1311.1786f, 1176.6052f, 1107.2202f, 1102.699f, 1285.0901f, 1217.5475f, 1283.957f, 1178.8302f,
+      1301.7781f, 1119.2472f, 1403.3389f, 1156.6019f, 1429.5802f, 1137.8423f, 1124.9352f, 1256.4998f, 1217.8774f,
+      1247.8909f, 1185.71f, 1345.7817f, 1250.1667f, 1390.754f, 1224.1162f, 1361.0802f, 1190.9337f, 1310.7971f,
+      1466.2094f, 1366.4476f, 1314.8397f, 1522.0437f, 1193.5563f, 1321.375f, 1055.7837f, 1021.6387f, 1197.0084f,
+      1131.532f, 1192.1443f, 1154.2896f, 1272.6771f, 1141.5146f, 1190.8961f, 1009.36316f, 1006.9138f, 1032.5999f,
+      1137.3857f, 1030.0756f, 1005.25305f, 1030.0947f, 1112.7948f, 1113.3575f, 1153.9747f, 1069.6409f, 1016.13745f,
+      994.9023f, 1032.1543f, 999.5864f, 994.75275f, 1029.057f
+  };
+
+  /**
+   * For each sample array: computes the reference mean with a plain double sum, then checks that a
+   * single collector, a set of merged on-heap collectors, and a set of merged buffer aggregators
+   * all report the same mean within 0.001.
+   */
+  @Test
+  public void testAvg()
+  {
+    // Fixed seed: values are assigned to partitions at random, so pin the sequence to keep any
+    // failure reproducible from run to run.
+    Random random = new Random(0);
+    for (float[] values : Arrays.asList(market_upfront, market_total_market)) {
+      double sum = 0;
+      for (float f : values) {
+        sum += f;
+      }
+      final double avg = sum / values.length;
+
+      AvgAggregatorCollector holder = new AvgAggregatorCollector();
+      for (float f : values) {
+        holder.add(f);
+      }
+      Assert.assertEquals(avg, holder.compute(), 0.001);
+
+      for (int mergeOn : new int[]{2, 3, 5, 9}) {
+        // holders1: on-heap partial collectors; holders2: buffer aggregators paired with their buffers.
+        List<AvgAggregatorCollector> holders1 = Lists.newArrayListWithCapacity(mergeOn);
+        List<Pair<AvgBufferAggregator, ByteBuffer>> holders2 = Lists.newArrayListWithCapacity(mergeOn);
+
+        FloatHandOver valueHandOver = new FloatHandOver();
+        for (int i = 0; i < mergeOn; i++) {
+          holders1.add(new AvgAggregatorCollector());
+          holders2.add(Pair.<AvgBufferAggregator, ByteBuffer>of(
+              new AvgBufferAggregator.FloatAvgBufferAggregator(valueHandOver),
+              ByteBuffer.allocate(AvgAggregatorCollector.getMaxIntermediateSize())
+          ));
+        }
+        for (float f : values) {
+          // The buffer aggregators read the current value through the shared hand-over selector.
+          valueHandOver.v = f;
+          int index = random.nextInt(mergeOn);
+          holders1.get(index).add(f);
+          holders2.get(index).lhs.aggregate(holders2.get(index).rhs, 0);
+        }
+
+        // Merge the on-heap partials and verify the result; previously this merged value was
+        // computed but never checked.
+        AvgAggregatorCollector holder1 = holders1.get(0);
+        for (int i = 1; i < mergeOn; i++) {
+          holder1 = (AvgAggregatorCollector) AvgAggregatorCollector.combineValues(holder1, holders1.get(i));
+        }
+        Assert.assertEquals(avg, holder1.compute(), 0.001);
+
+        // Merge the buffer-backed partials through the object-typed buffer aggregator.
+        ObjectHandOver collectHandOver = new ObjectHandOver();
+        ByteBuffer buffer = ByteBuffer.allocate(AvgAggregatorCollector.getMaxIntermediateSize());
+        AvgBufferAggregator merger = new AvgBufferAggregator.ObjectAvgBufferAggregator(collectHandOver);
+        for (int i = 0; i < mergeOn; i++) {
+          collectHandOver.v = holders2.get(i).lhs.get(holders2.get(i).rhs, 0);
+          merger.aggregate(buffer, 0);
+        }
+        AvgAggregatorCollector holder2 = (AvgAggregatorCollector) merger.get(buffer, 0);
+        Assert.assertEquals(avg, holder2.compute(), 0.001);
+      }
+    }
+  }
+
+  /** Float selector whose current value is set directly by the test. */
+  private static class FloatHandOver implements FloatColumnSelector
+  {
+    float v;
+
+    @Override
+    public float get()
+    {
+      return v;
+    }
+  }
+
+  /** Object selector whose current value is set directly by the test. */
+  private static class ObjectHandOver implements ObjectColumnSelector
+  {
+    Object v;
+
+    @Override
+    public Class classOfObject()
+    {
+      return v == null ? Object.class : v.getClass();
+    }
+
+    @Override
+    public Object get()
+    {
+      return v;
+    }
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorTest.java
new file mode 100644
index 000000000000..86aef151ed61
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgAggregatorTest.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import io.druid.jackson.DefaultObjectMapper;
+import io.druid.query.aggregation.TestFloatColumnSelector;
+import io.druid.query.aggregation.TestObjectColumnSelector;
+import io.druid.segment.ColumnSelectorFactory;
+import org.easymock.EasyMock;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Exercises the on-heap and buffer-backed avg aggregators produced by an
+ * {@link AvgAggregatorFactory} deserialized from JSON, plus the factory's combine and
+ * equals/hashCode behaviour.
+ */
+public class AvgAggregatorTest
+{
+  private AvgAggregatorFactory aggFactory;
+  private ColumnSelectorFactory colSelectorFactory;
+  private TestFloatColumnSelector selector;
+
+  private final float[] values = {1.1f, 2.7f, 3.5f, 1.3f};
+
+  public AvgAggregatorTest() throws Exception
+  {
+    final String factoryJson = "{\"type\": \"avg\", \"name\": \"billy\", \"fieldName\": \"nilly\"}";
+    final DefaultObjectMapper mapper = new DefaultObjectMapper();
+    aggFactory = mapper.readValue(factoryJson, AvgAggregatorFactory.class);
+  }
+
+  /** Rebuilds the float selector over {@code values} and a mock factory that serves column "nilly". */
+  @Before
+  public void setup()
+  {
+    selector = new TestFloatColumnSelector(values);
+    colSelectorFactory = EasyMock.createMock(ColumnSelectorFactory.class);
+    EasyMock.expect(colSelectorFactory.makeObjectColumnSelector("nilly")).andReturn(new TestObjectColumnSelector(0.0f));
+    EasyMock.expect(colSelectorFactory.makeFloatColumnSelector("nilly")).andReturn(selector);
+    EasyMock.replay(colSelectorFactory);
+  }
+
+  /** After each aggregated row the collector must hold the row count and the running sum. */
+  @Test
+  public void testDoubleAvgAggregator()
+  {
+    final AvgAggregator avgAgg = (AvgAggregator) aggFactory.factorize(colSelectorFactory);
+    // Running sums of {1.1, 2.7, 3.5, 1.3}.
+    final double[] runningSums = {1.1d, 3.8d, 7.3d, 8.6d};
+
+    assertValues((AvgAggregatorCollector) avgAgg.get(), 0, 0d);
+    for (int step = 0; step < runningSums.length; step++) {
+      aggregate(selector, avgAgg);
+      assertValues((AvgAggregatorCollector) avgAgg.get(), step + 1, runningSums[step]);
+    }
+
+    avgAgg.reset();
+    assertValues((AvgAggregatorCollector) avgAgg.get(), 0, 0d);
+  }
+
+  /** Asserts the collector's count exactly and its sum within 0.0001. */
+  private void assertValues(AvgAggregatorCollector holder, long count, double sum)
+  {
+    Assert.assertEquals(count, holder.count);
+    Assert.assertEquals(sum, holder.sum, 0.0001);
+  }
+
+  /** Same progression as the on-heap test, but state lives in a byte buffer at position 0. */
+  @Test
+  public void testDoubleAvgBufferAggregator()
+  {
+    final AvgBufferAggregator bufferAgg = (AvgBufferAggregator) aggFactory.factorizeBuffered(colSelectorFactory);
+    final ByteBuffer state = ByteBuffer.wrap(new byte[aggFactory.getMaxIntermediateSize()]);
+    // Running sums of {1.1, 2.7, 3.5, 1.3}.
+    final double[] runningSums = {1.1d, 3.8d, 7.3d, 8.6d};
+
+    bufferAgg.init(state, 0);
+
+    assertValues((AvgAggregatorCollector) bufferAgg.get(state, 0), 0, 0d);
+    for (int step = 0; step < runningSums.length; step++) {
+      aggregate(selector, bufferAgg, state, 0);
+      assertValues((AvgAggregatorCollector) bufferAgg.get(state, 0), step + 1, runningSums[step]);
+    }
+  }
+
+  /** Combining two partial collectors must match a collector holding the total count and sum. */
+  @Test
+  public void testCombine()
+  {
+    final AvgAggregatorCollector left = new AvgAggregatorCollector().add(1.1f).add(2.7f);
+    final AvgAggregatorCollector right = new AvgAggregatorCollector().add(3.5f).add(1.3f);
+    final AvgAggregatorCollector expected = new AvgAggregatorCollector(4, 8.6d);
+
+    final AvgAggregatorCollector combined = (AvgAggregatorCollector) aggFactory.combine(left, right);
+    Assert.assertTrue(expected.equalsWithEpsilon(combined, 0.00001));
+  }
+
+  /** Factories with the same name and field are equal with equal hashes; different ones are not equal. */
+  @Test
+  public void testEqualsAndHashCode() throws Exception
+  {
+    final AvgAggregatorFactory first = new AvgAggregatorFactory("name1", "fieldName1");
+    final AvgAggregatorFactory sameAsFirst = new AvgAggregatorFactory("name1", "fieldName1");
+    final AvgAggregatorFactory different = new AvgAggregatorFactory("name2", "fieldName2");
+
+    Assert.assertEquals(first.hashCode(), sameAsFirst.hashCode());
+
+    Assert.assertTrue(first.equals(sameAsFirst));
+    Assert.assertFalse(first.equals(different));
+  }
+
+  /** Aggregates the selector's current row, then advances the selector. */
+  private void aggregate(TestFloatColumnSelector selector, AvgAggregator agg)
+  {
+    agg.aggregate();
+    selector.increment();
+  }
+
+  /** Buffer variant: aggregates the current row into {@code buff} at {@code position}, then advances. */
+  private void aggregate(
+      TestFloatColumnSelector selector,
+      AvgBufferAggregator agg,
+      ByteBuffer buff,
+      int position
+  )
+  {
+    agg.aggregate(buff, position);
+    selector.increment();
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgGroupByQueryTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgGroupByQueryTest.java
new file mode 100644
index 000000000000..96a7e2c5c65f
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgGroupByQueryTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.MoreExecutors;
+import io.druid.data.input.Row;
+import io.druid.query.QueryRunner;
+import io.druid.query.QueryRunnerTestHelper;
+import io.druid.query.aggregation.AggregatorFactory;
+import io.druid.query.aggregation.LongSumAggregatorFactory;
+import io.druid.query.dimension.DefaultDimensionSpec;
+import io.druid.query.dimension.DimensionSpec;
+import io.druid.query.groupby.GroupByQuery;
+import io.druid.query.groupby.GroupByQueryConfig;
+import io.druid.query.groupby.GroupByQueryRunnerFactory;
+import io.druid.query.groupby.GroupByQueryRunnerTest;
+import io.druid.query.groupby.GroupByQueryRunnerTestHelper;
+import io.druid.segment.TestHelper;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Runs groupBy queries that use the avg aggregator against every runner configuration supplied by
+ * {@link GroupByQueryRunnerTest#constructorFeeder()} and compares the per-day, per-quality results
+ * with fixed expected rows.
+ */
+@RunWith(Parameterized.class)
+public class AvgGroupByQueryTest
+{
+  private final GroupByQueryConfig config;
+  private final QueryRunner<Row> runner;
+  private final GroupByQueryRunnerFactory factory;
+  private final String testName;
+
+  @Parameterized.Parameters(name = "{0}")
+  public static Collection<?> constructorFeeder() throws IOException
+  {
+    return GroupByQueryRunnerTest.constructorFeeder();
+  }
+
+  public AvgGroupByQueryTest(
+      String testName,
+      GroupByQueryConfig config,
+      GroupByQueryRunnerFactory factory,
+      QueryRunner runner
+  )
+  {
+    this.testName = testName;
+    this.config = config;
+    this.factory = factory;
+    // Wrap the supplied runner in the factory's merge step, executed on the calling thread.
+    this.runner = factory.mergeRunners(MoreExecutors.sameThreadExecutor(), ImmutableList.<QueryRunner<Row>>of(runner));
+  }
+
+  /** GroupBy on "quality" with the avg aggregator as the only metric. */
+  @Test
+  public void testGroupByAvgOnly()
+  {
+    GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.dataSource)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird)
+        .setDimensions(Lists.<DimensionSpec>newArrayList(new DefaultDimensionSpec("quality", "alias")))
+        .setAggregatorSpecs(Arrays.<AggregatorFactory>asList(AvgTestHelper.indexAvgAggr))
+        .setGranularity(QueryRunnerTestHelper.dayGran)
+        .build();
+
+    AvgTestHelper.RowBuilder builder =
+        new AvgTestHelper.RowBuilder(new String[]{"alias", "index_var"});
+
+    List<Row> expectedResults = builder
+        .add("2011-04-01", "automotive", 135.885094d)
+        .add("2011-04-01", "business", 118.570340d)
+        .add("2011-04-01", "entertainment", 158.747224d)
+        .add("2011-04-01", "health", 120.134704d)
+        .add("2011-04-01", "mezzanine", 957.295563d)
+        .add("2011-04-01", "news", 121.583581d)
+        .add("2011-04-01", "premium", 966.932882d)
+        .add("2011-04-01", "technology", 78.622547d)
+        .add("2011-04-01", "travel", 119.922742d)
+
+        .add("2011-04-02", "automotive", 147.425935d)
+        .add("2011-04-02", "business", 112.987027d)
+        .add("2011-04-02", "entertainment", 166.016049d)
+        .add("2011-04-02", "health", 113.446008d)
+        .add("2011-04-02", "mezzanine", 816.276871d)
+        .add("2011-04-02", "news", 114.290141d)
+        .add("2011-04-02", "premium", 835.471716d)
+        .add("2011-04-02", "technology", 97.387433d)
+        .add("2011-04-02", "travel", 126.411364d)
+        .build();
+
+    Iterable<Row> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+    TestHelper.assertExpectedObjects(expectedResults, results, "");
+  }
+
+  /** GroupBy on "quality" with row count, avg and a long sum of "index" side by side. */
+  @Test
+  public void testGroupBy()
+  {
+    GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.dataSource)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird)
+        .setDimensions(Lists.<DimensionSpec>newArrayList(new DefaultDimensionSpec("quality", "alias")))
+        .setAggregatorSpecs(
+            Arrays.<AggregatorFactory>asList(
+                AvgTestHelper.rowsCount,
+                AvgTestHelper.indexAvgAggr,
+                new LongSumAggregatorFactory("idx", "index")
+            )
+        )
+        .setGranularity(QueryRunnerTestHelper.dayGran)
+        .build();
+
+    // Column order here only has to match the value order passed to add(); rows are map-based.
+    AvgTestHelper.RowBuilder builder =
+        new AvgTestHelper.RowBuilder(new String[]{"alias", "rows", "idx", "index_var"});
+
+    List<Row> expectedResults = builder
+        .add("2011-04-01", "automotive", 1L, 135L, 135.885094d)
+        .add("2011-04-01", "business", 1L, 118L, 118.570340d)
+        .add("2011-04-01", "entertainment", 1L, 158L, 158.747224d)
+        .add("2011-04-01", "health", 1L, 120L, 120.134704d)
+        .add("2011-04-01", "mezzanine", 3L, 2870L, 957.295563d)
+        .add("2011-04-01", "news", 1L, 121L, 121.583581d)
+        .add("2011-04-01", "premium", 3L, 2900L, 966.932882d)
+        .add("2011-04-01", "technology", 1L, 78L, 78.622547d)
+        .add("2011-04-01", "travel", 1L, 119L, 119.922742d)
+
+        .add("2011-04-02", "automotive", 1L, 147L, 147.425935d)
+        .add("2011-04-02", "business", 1L, 112L, 112.987027d)
+        .add("2011-04-02", "entertainment", 1L, 166L, 166.016049d)
+        .add("2011-04-02", "health", 1L, 113L, 113.446008d)
+        .add("2011-04-02", "mezzanine", 3L, 2447L, 816.276871d)
+        .add("2011-04-02", "news", 1L, 114L, 114.290141d)
+        .add("2011-04-02", "premium", 3L, 2505L, 835.471716d)
+        .add("2011-04-02", "technology", 1L, 97L, 97.387433d)
+        .add("2011-04-02", "travel", 1L, 126L, 126.411364d)
+        .build();
+
+    Iterable<Row> results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query);
+    TestHelper.assertExpectedObjects(expectedResults, results, "");
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgSerdeTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgSerdeTest.java
new file mode 100644
index 000000000000..29618dd9854a
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgSerdeTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import io.druid.segment.data.ObjectStrategy;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+/**
+ * Round-trips an {@link AvgAggregatorCollector} through the {@link ObjectStrategy} exposed by
+ * {@link AvgSerde}.
+ */
+public class AvgSerdeTest
+{
+  /**
+   * Starting from an empty collector and adding one random float per round, checks that the
+   * strategy's bytes equal the collector's own byte form and that decoding them yields an equal
+   * collector.
+   */
+  @Test
+  public void testSerde()
+  {
+    final Random rng = new Random();
+    final AvgAggregatorCollector collector = new AvgAggregatorCollector();
+    final ObjectStrategy strategy = new AvgSerde().getObjectStrategy();
+    Assert.assertEquals(AvgAggregatorCollector.class, strategy.getClazz());
+
+    for (int round = 0; round < 100; round++) {
+      final byte[] serialized = strategy.toBytes(collector);
+      Assert.assertArrayEquals(serialized, collector.toByteArray());
+
+      final Object decoded = strategy.fromByteBuffer(ByteBuffer.wrap(serialized), serialized.length);
+      Assert.assertEquals(collector, decoded);
+
+      // Grow the state so the next round serializes a different collector.
+      collector.add(rng.nextFloat());
+    }
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgTestHelper.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTestHelper.java
new file mode 100644
index 000000000000..6d9f6b80d727
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTestHelper.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import io.druid.data.input.MapBasedRow;
+import io.druid.data.input.Row;
+import io.druid.query.QueryRunnerTestHelper;
+import io.druid.query.aggregation.AggregatorFactory;
+import org.joda.time.DateTime;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Shared fixtures for the avg aggregator query tests: the avg aggregator over the index metric, an
+ * aggregator list that includes it, and a small builder for expected result rows.
+ */
+public class AvgTestHelper extends QueryRunnerTestHelper
+{
+  /** Output name of the avg aggregator used by the query tests. */
+  public static final String indexAvgMetric = "index_var";
+
+  public static final AvgAggregatorFactory indexAvgAggr = new AvgAggregatorFactory(
+      indexAvgMetric,
+      indexMetric
+  );
+
+  public static final List<AggregatorFactory> commonPlusVarAggregators = Arrays.<AggregatorFactory>asList(
+      rowsCount,
+      indexDoubleSum,
+      qualityUniques,
+      indexAvgAggr
+  );
+
+  /**
+   * Accumulates {@link MapBasedRow}s whose columns are given once, by name, at construction time.
+   */
+  public static class RowBuilder
+  {
+    private final String[] names;
+    private final List<Row> rows = Lists.newArrayList();
+
+    public RowBuilder(String[] names)
+    {
+      this.names = names;
+    }
+
+    /** Appends one row; {@code values} are matched positionally against the column names. */
+    public RowBuilder add(final String timestamp, Object... values)
+    {
+      rows.add(build(timestamp, values));
+      return this;
+    }
+
+    /** Returns a copy of the rows added so far and empties the builder for reuse. */
+    public List<Row> build()
+    {
+      final List<Row> built = Lists.newArrayList(rows);
+      rows.clear();
+      return built;
+    }
+
+    /**
+     * Builds a single row without adding it to the builder.
+     *
+     * @throws IllegalArgumentException if the number of values differs from the number of column names
+     */
+    public Row build(final String timestamp, Object... values)
+    {
+      Preconditions.checkArgument(
+          names.length == values.length,
+          "expected %s values but got %s",
+          names.length,
+          values.length
+      );
+
+      Map<String, Object> theVals = Maps.newHashMap();
+      for (int i = 0; i < values.length; i++) {
+        theVals.put(names[i], values[i]);
+      }
+      DateTime ts = new DateTime(timestamp);
+      return new MapBasedRow(ts, theVals);
+    }
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgTimeseriesQueryTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTimeseriesQueryTest.java
new file mode 100644
index 000000000000..65f748cad4d2
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTimeseriesQueryTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import com.google.common.collect.Lists;
+import io.druid.java.util.common.guava.Sequences;
+import io.druid.query.Druids;
+import io.druid.query.QueryRunner;
+import io.druid.query.Result;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.query.timeseries.TimeseriesQuery;
+import io.druid.query.timeseries.TimeseriesQueryRunnerTest;
+import io.druid.query.timeseries.TimeseriesResultValue;
+import io.druid.segment.TestHelper;
+import org.joda.time.DateTime;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Runs a timeseries query that includes the avg aggregator against every runner/ordering pair
+ * supplied by {@link TimeseriesQueryRunnerTest#constructorFeeder()}.
+ */
+@RunWith(Parameterized.class)
+public class AvgTimeseriesQueryTest
+{
+  @Parameterized.Parameters(name = "{0}:descending={1}")
+  public static Iterable<Object[]> constructorFeeder() throws IOException
+  {
+    return TimeseriesQueryRunnerTest.constructorFeeder();
+  }
+
+  private final QueryRunner runner;
+  private final boolean descending;
+
+  public AvgTimeseriesQueryTest(QueryRunner runner, boolean descending)
+  {
+    this.runner = runner;
+    this.descending = descending;
+  }
+
+  /** Daily timeseries with a null filter on the non-existent dimension "bobby". */
+  @Test
+  public void testTimeseriesWithNullFilterOnNonExistentDimension()
+  {
+    TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
+                                  .dataSource(AvgTestHelper.dataSource)
+                                  .granularity(AvgTestHelper.dayGran)
+                                  .filters("bobby", null)
+                                  .intervals(AvgTestHelper.firstToThird)
+                                  .aggregators(AvgTestHelper.commonPlusVarAggregators)
+                                  .postAggregators(
+                                      Arrays.<PostAggregator>asList(
+                                          AvgTestHelper.addRowsIndexConstant
+                                      )
+                                  )
+                                  .descending(descending)
+                                  .build();
+
+    List<Result<TimeseriesResultValue>> expectedResults = Arrays.asList(
+        new Result<>(
+            new DateTime("2011-04-01"),
+            new TimeseriesResultValue(
+                AvgTestHelper.of(
+                    "rows", 13L,
+                    "index", 6626.151596069336,
+                    "addRowsIndexConstant", 6640.151596069336,
+                    "uniques", AvgTestHelper.UNIQUES_9,
+                    "index_var", 509.70396892841046
+                )
+            )
+        ),
+        new Result<>(
+            new DateTime("2011-04-02"),
+            new TimeseriesResultValue(
+                AvgTestHelper.of(
+                    "rows", 13L,
+                    "index", 5833.2095947265625,
+                    "addRowsIndexConstant", 5847.2095947265625,
+                    "uniques", AvgTestHelper.UNIQUES_9,
+                    "index_var", 448.7084303635817
+                )
+            )
+        )
+    );
+
+    Iterable<Result<TimeseriesResultValue>> results = Sequences.toList(
+        runner.run(query, new HashMap<String, Object>()),
+        Lists.<Result<TimeseriesResultValue>>newArrayList()
+    );
+    assertExpectedResults(expectedResults, results);
+  }
+
+  /** Compares results, reversing the expected order first when the query runs descending. */
+  private <T> void assertExpectedResults(Iterable<Result<T>> expectedResults, Iterable<Result<T>> results)
+  {
+    if (descending) {
+      expectedResults = TestHelper.revert(expectedResults);
+    }
+    TestHelper.assertExpectedResults(expectedResults, results);
+  }
+}
diff --git a/processing/src/test/java/io/druid/query/aggregation/avg/AvgTopNQueryTest.java b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTopNQueryTest.java
new file mode 100644
index 000000000000..fee6d4a899af
--- /dev/null
+++ b/processing/src/test/java/io/druid/query/aggregation/avg/AvgTopNQueryTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to Metamarkets Group Inc. (Metamarkets) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. Metamarkets licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.druid.query.aggregation.avg;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import io.druid.java.util.common.guava.Sequence;
+import io.druid.query.QueryRunner;
+import io.druid.query.QueryRunnerTestHelper;
+import io.druid.query.Result;
+import io.druid.query.aggregation.AggregatorFactory;
+import io.druid.query.aggregation.DoubleMaxAggregatorFactory;
+import io.druid.query.aggregation.DoubleMinAggregatorFactory;
+import io.druid.query.aggregation.PostAggregator;
+import io.druid.query.topn.TopNQuery;
+import io.druid.query.topn.TopNQueryBuilder;
+import io.druid.query.topn.TopNQueryConfig;
+import io.druid.query.topn.TopNQueryQueryToolChest;
+import io.druid.query.topn.TopNQueryRunnerTest;
+import io.druid.query.topn.TopNResultValue;
+import io.druid.segment.TestHelper;
+import org.joda.time.DateTime;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Runs a topN query that includes the avg aggregator against every runner supplied by
+ * {@link TopNQueryRunnerTest#constructorFeeder()}.
+ */
+@RunWith(Parameterized.class)
+public class AvgTopNQueryTest
+{
+  @Parameterized.Parameters(name = "{0}")
+  public static Iterable<Object[]> constructorFeeder() throws IOException
+  {
+    return TopNQueryRunnerTest.constructorFeeder();
+  }
+
+  private final QueryRunner runner;
+
+  public AvgTopNQueryTest(
+      QueryRunner runner
+  )
+  {
+    this.runner = runner;
+  }
+
+  /** Top 3 markets by the uniques metric over the full interval, with avg reported as "index_var". */
+  @Test
+  public void testFullOnTopNOverUniques()
+  {
+    TopNQuery query = new TopNQueryBuilder()
+        .dataSource(QueryRunnerTestHelper.dataSource)
+        .granularity(QueryRunnerTestHelper.allGran)
+        .dimension(QueryRunnerTestHelper.marketDimension)
+        .metric(QueryRunnerTestHelper.uniqueMetric)
+        .threshold(3)
+        .intervals(QueryRunnerTestHelper.fullOnInterval)
+        .aggregators(
+            Lists.<AggregatorFactory>newArrayList(
+                Iterables.concat(
+                    AvgTestHelper.commonPlusVarAggregators,
+                    Lists.newArrayList(
+                        new DoubleMaxAggregatorFactory("maxIndex", "index"),
+                        new DoubleMinAggregatorFactory("minIndex", "index")
+                    )
+                )
+            )
+        )
+        .postAggregators(Arrays.<PostAggregator>asList(QueryRunnerTestHelper.addRowsIndexConstant))
+        .build();
+
+    List<Result<TopNResultValue>> expectedResults = Arrays.asList(
+        new Result<TopNResultValue>(
+            new DateTime("2011-01-12T00:00:00.000Z"),
+            new TopNResultValue(
+                Arrays.<Map<String, Object>>asList(
+                    ImmutableMap.<String, Object>builder()
+                                .put("market", "spot")
+                                .put("rows", 837L)
+                                .put("index", 95606.57232284546D)
+                                .put("addRowsIndexConstant", 96444.57232284546D)
+                                .put("uniques", QueryRunnerTestHelper.UNIQUES_9)
+                                .put("maxIndex", 277.2735290527344D)
+                                .put("minIndex", 59.02102279663086D)
+                                .put("index_var", 114.22529548727056D)
+                                .build(),
+                    ImmutableMap.<String, Object>builder()
+                                .put("market", "total_market")
+                                .put("rows", 186L)
+                                .put("index", 215679.82879638672D)
+                                .put("addRowsIndexConstant", 215866.82879638672D)
+                                .put("uniques", QueryRunnerTestHelper.UNIQUES_2)
+                                .put("maxIndex", 1743.9217529296875D)
+                                .put("minIndex", 792.3260498046875D)
+                                .put("index_var", 1159.5689720235846D)
+                                .build(),
+                    ImmutableMap.<String, Object>builder()
+                                .put("market", "upfront")
+                                .put("rows", 186L)
+                                .put("index", 192046.1060180664D)
+                                .put("addRowsIndexConstant", 192233.1060180664D)
+                                .put("uniques", QueryRunnerTestHelper.UNIQUES_2)
+                                .put("maxIndex", 1870.06103515625D)
+                                .put("minIndex", 545.9906005859375D)
+                                .put("index_var", 1032.5059463336904D)
+                                .build()
+                )
+            )
+        )
+    );
+    assertExpectedResults(expectedResults, query);
+  }
+
+  /**
+   * Runs {@code query} through the topN tool chest's merge step on top of the parameterized runner
+   * and compares the output with {@code expectedResults}.
+   *
+   * @return the result sequence that was checked
+   */
+  private Sequence<Result<TopNResultValue>> assertExpectedResults(
+      Iterable<Result<TopNResultValue>> expectedResults,
+      TopNQuery query
+  )
+  {
+    final TopNQueryQueryToolChest chest = new TopNQueryQueryToolChest(
+        new TopNQueryConfig(),
+        QueryRunnerTestHelper.NoopIntervalChunkingQueryRunnerDecorator()
+    );
+    final QueryRunner<Result<TopNResultValue>> mergeRunner = chest.mergeResults(runner);
+    final Sequence<Result<TopNResultValue>> retval = mergeRunner.run(query, ImmutableMap.<String, Object>of());
+    TestHelper.assertExpectedResults(expectedResults, retval);
+    return retval;
+  }
+
+}