From 5bb6181f46c4443e23c8892a5f2efb8e0b326263 Mon Sep 17 00:00:00 2001 From: geonove Date: Wed, 21 May 2025 00:06:13 +0200 Subject: [PATCH 1/5] CountMinSketch initial commit --- .../java/org/apache/datasketches/count/CountMinSketch.java | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/main/java/org/apache/datasketches/count/CountMinSketch.java diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java new file mode 100644 index 000000000..31f0b9fd1 --- /dev/null +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -0,0 +1,5 @@ +package org.apache.datasketches.count; + +public class CountMinSketch { + +} From 799436bcbbb4163e6176ffed2e7446e77083e24c Mon Sep 17 00:00:00 2001 From: Andrea Novellini Date: Thu, 22 May 2025 01:13:36 +0200 Subject: [PATCH 2/5] Implement count min sketch1 --- .../datasketches/count/CountMinSketch.java | 295 +++++++++++++++++- 1 file changed, 294 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 31f0b9fd1..fc3dfde78 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -1,5 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.datasketches.count; +import org.apache.datasketches.common.Family; +import org.apache.datasketches.common.SketchesException; +import org.apache.datasketches.hash.MurmurHash3; +import org.apache.datasketches.tuple.Util; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Random; + + public class CountMinSketch { - + private final byte numHashes_; + private final int numBuckets_; + private final long seed_; + private final long[] hashSeeds_; + private final long[] sketchArray_; + private long totalWeight_; + + private static final int IS_EMPTY = 0; + + /** + * Creates a CountMin sketch with given number of hash functions and buckets, + * and a user-specified seed. + * + * @param numHashes The number of hash functions to apply to items + * @param numBuckets Array size for each of the hashing function + * @param seed The base hash seed + */ + CountMinSketch(final byte numHashes, final int numBuckets, final long seed) { + numHashes_ = numHashes; + numBuckets_ = numBuckets; + seed_ = seed; + hashSeeds_ = new long[numHashes]; + sketchArray_ = new long[numHashes * numBuckets]; + totalWeight_ = 0; + + if (numBuckets < 3) { + throw new SketchesException("Using fewer than 3 buckets incurs relative error greater than 1."); + } + + // This check is to ensure later compatibility with a Java implementation whose maximum size can only + // be 2^31-1. We check only against 2^30 for simplicity. 
+ if (numBuckets * numHashes >= 1 << 30) { + throw new SketchesException("These parameters generate a sketch that exceeds 2^30 elements. \n" + + "Try reducing either the number of buckets or the number of hash functions."); + } + + Random rand = new Random(); + for (int i = 0; i < numHashes; i++) { + hashSeeds_[i] = rand.nextLong(); + } + } + + private long[] getHashes(byte[] item) { + long[] updateLocations = new long[numHashes_]; + + for (int i = 0; i < numHashes_; i++) { + long[] index = MurmurHash3.hash(item, hashSeeds_[i]); + updateLocations[i] = i * (long)numBuckets_ + index[0] % numBuckets_; + } + + return updateLocations; + } + + public boolean isEmpty() { + return totalWeight_ == 0; + } + + public byte getNumHashes_() { + return numHashes_; + } + + public int getNumBuckets_() { + return numBuckets_; + } + + public long getSeed_() { + return seed_; + } + + public long getTotalWeight_() { + return totalWeight_; + } + + public double getRelativeError() { + return Math.exp(1.0) / (double)numBuckets_; + } + + public byte suggestNumHashes(double confidence) { + if (confidence < 0 || confidence > 1) { + throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive)."); + } + int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence))); + return (byte) Math.min(value, 127); + } + + public int suggestNumBuckets(double relativeError) { + return (int) Math.ceil(Math.exp(1.0) / relativeError); + } + + public void update(final long item, final long weight) { + byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); + update(longByte, weight); + } + + public void update(final String item, final long weight) { + if (item == null || item.isEmpty()) { + return; + } + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + update(strByte, weight); + } + + public void update(final byte[] item, final long weight) { + if (item.length == 0) { + return; + } + + totalWeight_ += weight > 0 ? 
weight : -weight; + long[] hashLocations = getHashes(item); + for (long h : hashLocations) { + sketchArray_[(int) h] += weight; + } + } + + public long getEstimate(final long item) { + byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); + return getEstimate(longByte); + } + + public long getEstimate(final String item) { + if (item == null || item.isEmpty()) { + return 0; + } + + final byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + return getEstimate(strByte); + } + + public long getEstimate(final byte[] item) { + if (item.length == 0) { + return 0; + } + + long[] hashLocations = getHashes(item); + long res = sketchArray_[(int) hashLocations[0]]; + for (long h : hashLocations) { + res = Math.min(res, sketchArray_[(int) h]); + } + + return res; + } + + public long getUpperBound(final long item) { + byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); + return getUpperBound(longByte); + } + + public long getUpperBound(final String item) { + if (item == null || item.isEmpty()) { + return 0; + } + + byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + return getUpperBound(strByte); + } + + public long getUpperBound(final byte[] item) { + if (item.length == 0) { + return 0; + } + + return getEstimate(item) + (long)(getRelativeError() * getTotalWeight_()); + } + + public long getLowerBound(final long item) { + byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); + return getLowerBound(longByte); + } + + public long getLowerBound(final String item) { + if (item == null || item.isEmpty()) { + return 0; + } + + byte[] strByte = item.getBytes(StandardCharsets.UTF_8); + return getLowerBound(strByte); + } + + public long getLowerBound(final byte[] item) { + return getEstimate(item); + } + + public void merge(final CountMinSketch other) { + if (this == other) { + throw new SketchesException("Cannot merge a sketch with itself"); + } + + boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() && + getNumHashes_() == 
other.getNumHashes_() && getSeed_() == other.getSeed_(); + + if (!acceptableConfig) { + throw new SketchesException("Incompatible sketch configuration."); + } + + for (int i = 0; i < sketchArray_.length; i++) { + sketchArray_[i] += other.sketchArray_[i]; + } + + totalWeight_ += other.getTotalWeight_(); + } + + public void serialize(ByteBuffer buf) { + // Long 0 + final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); + buf.put((byte) preambleLongs); + final int serialVersion = 1; + buf.put((byte) serialVersion); + final int familyId = Family.COUNTMIN.getID(); + buf.put((byte) familyId); + final int flagsByte = isEmpty() ? 1 << IS_EMPTY : 0; + buf.put((byte)flagsByte); + final int NULL_32 = 0; + buf.putInt(NULL_32); + + // Long 1 + buf.putInt(numBuckets_); + buf.putShort(numHashes_); + buf.putShort(Util.computeSeedHash(seed_)); + final byte NULL_8 = 0; + buf.put(NULL_8); + if (isEmpty()) { + return; + } + + buf.putLong(totalWeight_); + + for (long estimate: sketchArray_) { + buf.putLong(estimate); + } + } + + public static CountMinSketch deserialize(final byte[] b, final long seed) { + ByteBuffer buf = ByteBuffer.allocate(b.length); + buf.put(b); + + final byte preambleLongs = buf.get(); + final byte serialVersion = buf.get(); + final byte familyId = buf.get(); + final byte flagsByte = buf.get(); + final int NULL_32 = buf.getInt(); + + final int numBuckets = buf.getInt(); + final byte numHashes = buf.get(); + final short seedHash = buf.getShort(); + final byte NULL_8 = buf.get(); + + if (seedHash != Util.computeSeedHash(seed)) { + throw new SketchesException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " + + String.valueOf(Util.computeSeedHash(seed))); + } + + CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); + final boolean empty = (flagsByte & (1 << IS_EMPTY)) > 0; + if (empty) { + return cms; + } + + int i = 0; + while (buf.hasRemaining()) { + cms.sketchArray_[i] = buf.getLong(); + } + + return cms; + } } From 
5b5f7ace065657f1a1a9cc0153634aa514aeea75 Mon Sep 17 00:00:00 2001 From: Andrea Novellini Date: Sat, 24 May 2025 23:36:11 +0200 Subject: [PATCH 3/5] Add documentation --- .../datasketches/count/CountMinSketch.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index fc3dfde78..be2669ee9 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -84,30 +84,59 @@ private long[] getHashes(byte[] item) { return updateLocations; } + /** + * Checks if the CountMinSketch has processed any items. + * @return True if the sketch is empty, otherwise false. + */ public boolean isEmpty() { return totalWeight_ == 0; } + /** + * Returns the number of hash functions used in this sketch. + * @return The number of hash functions. + */ public byte getNumHashes_() { return numHashes_; } + /** + * Returns the number of buckets per hash function. + * @return The number of buckets. + */ public int getNumBuckets_() { return numBuckets_; } + /** + * Returns the hash seed used by this sketch. + * @return The seed value. + */ public long getSeed_() { return seed_; } + /** + * Returns the total weight of all items inserted into the sketch. + * @return The total weight. + */ public long getTotalWeight_() { return totalWeight_; } + /** + * Returns the relative error of the sketch. + * @return The relative error. + */ public double getRelativeError() { return Math.exp(1.0) / (double)numBuckets_; } + /** + * Suggests an appropriate number of hash functions to use for a given confidence level. + * @param confidence The desired confidence level between 0 and 1. + * @return Suggested number of hash functions. 
+ */ public byte suggestNumHashes(double confidence) { if (confidence < 0 || confidence > 1) { throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive)."); @@ -116,15 +145,30 @@ public byte suggestNumHashes(double confidence) { return (byte) Math.min(value, 127); } + /** + * Suggests an appropriate number of buckets per hash function for a given relative error. + * @param relativeError The desired relative error. + * @return Suggested number of buckets. + */ public int suggestNumBuckets(double relativeError) { return (int) Math.ceil(Math.exp(1.0) / relativeError); } + /** + * Updates the sketch with the provided item and weight. + * @param item The item to update. + * @param weight The weight of the item. + */ public void update(final long item, final long weight) { byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); update(longByte, weight); } + /** + * Updates the sketch with the provided item and weight. + * @param item The item to update. + * @param weight The weight of the item. + */ public void update(final String item, final long weight) { if (item == null || item.isEmpty()) { return; @@ -133,6 +177,11 @@ public void update(final String item, final long weight) { update(strByte, weight); } + /** + * Updates the sketch with the provided item and weight. + * @param item The item to update. + * @param weight The weight of the item. + */ public void update(final byte[] item, final long weight) { if (item.length == 0) { return; @@ -145,11 +194,21 @@ public void update(final byte[] item, final long weight) { } } + /** + * Returns the estimated frequency for the given item. + * @param item The item to estimate. + * @return Estimated frequency. + */ public long getEstimate(final long item) { byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); return getEstimate(longByte); } + /** + * Returns the estimated frequency for the given item. + * @param item The item to estimate. + * @return Estimated frequency. 
+ */ public long getEstimate(final String item) { if (item == null || item.isEmpty()) { return 0; @@ -159,6 +218,11 @@ public long getEstimate(final String item) { return getEstimate(strByte); } + /** + * Returns the estimated frequency for the given item. + * @param item The item to estimate. + * @return Estimated frequency. + */ public long getEstimate(final byte[] item) { if (item.length == 0) { return 0; @@ -173,11 +237,21 @@ public long getEstimate(final byte[] item) { return res; } + /** + * Returns the upper bound of the estimated frequency for the given item. + * @param item The item to estimate. + * @return Upper bound of estimated frequency. + */ public long getUpperBound(final long item) { byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); return getUpperBound(longByte); } + /** + * Returns the upper bound of the estimated frequency for the given item. + * @param item The item to estimate. + * @return Upper bound of estimated frequency. + */ public long getUpperBound(final String item) { if (item == null || item.isEmpty()) { return 0; @@ -187,6 +261,11 @@ public long getUpperBound(final String item) { return getUpperBound(strByte); } + /** + * Returns the upper bound of the estimated frequency for the given item. + * @param item The item to estimate. + * @return Upper bound of estimated frequency. + */ public long getUpperBound(final byte[] item) { if (item.length == 0) { return 0; @@ -195,11 +274,21 @@ public long getUpperBound(final byte[] item) { return getEstimate(item) + (long)(getRelativeError() * getTotalWeight_()); } + /** + * Returns the lower bound of the estimated frequency for the given item. + * @param item The item to estimate. + * @return Lower bound of estimated frequency. + */ public long getLowerBound(final long item) { byte[] longByte = ByteBuffer.allocate(8).putLong(item).array(); return getLowerBound(longByte); } + /** + * Returns the lower bound of the estimated frequency for the given item. 
+ * @param item The item to estimate. + * @return Lower bound of estimated frequency. + */ public long getLowerBound(final String item) { if (item == null || item.isEmpty()) { return 0; @@ -209,10 +298,19 @@ public long getLowerBound(final String item) { return getLowerBound(strByte); } + /** + * Returns the lower bound of the estimated frequency for the given item. + * @param item The item to estimate. + * @return Lower bound of estimated frequency. + */ public long getLowerBound(final byte[] item) { return getEstimate(item); } + /** + * Merges another CountMinSketch into this one. The sketches must have the same configuration. + * @param other The other sketch to merge. + */ public void merge(final CountMinSketch other) { if (this == other) { throw new SketchesException("Cannot merge a sketch with itself"); @@ -232,6 +330,10 @@ public void merge(final CountMinSketch other) { totalWeight_ += other.getTotalWeight_(); } + /** + * Serializes the sketch into the provided ByteBuffer. + * @param buf The ByteBuffer to write into. + */ public void serialize(ByteBuffer buf) { // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); @@ -262,6 +364,12 @@ public void serialize(ByteBuffer buf) { } } + /** + * Deserializes a CountMinSketch from the provided byte array. + * @param b The byte array containing the serialized sketch. + * @param seed The seed used during serialization. + * @return The deserialized CountMinSketch. 
+ */ public static CountMinSketch deserialize(final byte[] b, final long seed) { ByteBuffer buf = ByteBuffer.allocate(b.length); buf.put(b); From b3d9f3fbb32756c87db7737ac7a61b9e8e28cf03 Mon Sep 17 00:00:00 2001 From: Andrea Novellini Date: Sat, 24 May 2025 23:44:32 +0200 Subject: [PATCH 4/5] Flag isEmpty to be enum --- .../apache/datasketches/count/CountMinSketch.java | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index be2669ee9..959486757 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -24,7 +24,6 @@ import org.apache.datasketches.hash.MurmurHash3; import org.apache.datasketches.tuple.Util; -import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Random; @@ -38,7 +37,14 @@ public class CountMinSketch { private final long[] sketchArray_; private long totalWeight_; - private static final int IS_EMPTY = 0; + + private enum Flag { + IS_EMPTY; + + int mask() { + return 1 << ordinal(); + } + } /** * Creates a CountMin sketch with given number of hash functions and buckets, @@ -342,7 +348,7 @@ public void serialize(ByteBuffer buf) { buf.put((byte) serialVersion); final int familyId = Family.COUNTMIN.getID(); buf.put((byte) familyId); - final int flagsByte = isEmpty() ? 1 << IS_EMPTY : 0; + final int flagsByte = isEmpty() ? 
Flag.IS_EMPTY.mask() : 0; buf.put((byte)flagsByte); final int NULL_32 = 0; buf.putInt(NULL_32); @@ -391,7 +397,7 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { } CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed); - final boolean empty = (flagsByte & (1 << IS_EMPTY)) > 0; + final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0; if (empty) { return cms; } From aec9a7ef467a72ea4a95f9967647887b8d8abb87 Mon Sep 17 00:00:00 2001 From: Andrea Novellini Date: Sun, 25 May 2025 02:59:09 +0200 Subject: [PATCH 5/5] Add tests --- .../datasketches/count/CountMinSketch.java | 53 ++-- .../count/CountMinSketchTest.java | 246 ++++++++++++++++++ 2 files changed, 277 insertions(+), 22 deletions(-) create mode 100644 src/test/java/org/apache/datasketches/count/CountMinSketchTest.java diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java index 959486757..36bea38cf 100644 --- a/src/main/java/org/apache/datasketches/count/CountMinSketch.java +++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java @@ -20,10 +20,12 @@ package org.apache.datasketches.count; import org.apache.datasketches.common.Family; +import org.apache.datasketches.common.SketchesArgumentException; import org.apache.datasketches.common.SketchesException; import org.apache.datasketches.hash.MurmurHash3; import org.apache.datasketches.tuple.Util; +import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Random; @@ -63,17 +65,17 @@ int mask() { totalWeight_ = 0; if (numBuckets < 3) { - throw new SketchesException("Using fewer than 3 buckets incurs relative error greater than 1."); + throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1."); } // This check is to ensure later compatibility with a Java implementation whose maximum size can only // be 2^31-1. 
We check only against 2^30 for simplicity. if (numBuckets * numHashes >= 1 << 30) { - throw new SketchesException("These parameters generate a sketch that exceeds 2^30 elements. \n" + + throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" + "Try reducing either the number of buckets or the number of hash functions."); } - Random rand = new Random(); + Random rand = new Random(seed); for (int i = 0; i < numHashes; i++) { hashSeeds_[i] = rand.nextLong(); } @@ -84,7 +86,7 @@ private long[] getHashes(byte[] item) { for (int i = 0; i < numHashes_; i++) { long[] index = MurmurHash3.hash(item, hashSeeds_[i]); - updateLocations[i] = i * (long)numBuckets_ + index[0] % numBuckets_; + updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_); } return updateLocations; @@ -143,7 +145,7 @@ public double getRelativeError() { * @param confidence The desired confidence level between 0 and 1. * @return Suggested number of hash functions. */ - public byte suggestNumHashes(double confidence) { + public static byte suggestNumHashes(double confidence) { if (confidence < 0 || confidence > 1) { throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive)."); } @@ -156,7 +158,10 @@ public byte suggestNumHashes(double confidence) { * @param relativeError The desired relative error. * @return Suggested number of buckets. */ - public int suggestNumBuckets(double relativeError) { + public static int suggestNumBuckets(double relativeError) { + if (relativeError < 0.) { + throw new SketchesException("Relative error must be at least 0."); + } return (int) Math.ceil(Math.exp(1.0) / relativeError); } @@ -340,33 +345,35 @@ public void merge(final CountMinSketch other) { * Serializes the sketch into the provided ByteBuffer. * @param buf The ByteBuffer to write into. 
*/ - public void serialize(ByteBuffer buf) { + public void serialize(ByteArrayOutputStream buf) { // Long 0 final int preambleLongs = Family.COUNTMIN.getMinPreLongs(); - buf.put((byte) preambleLongs); + buf.write((byte) preambleLongs); final int serialVersion = 1; - buf.put((byte) serialVersion); + buf.write((byte) serialVersion); final int familyId = Family.COUNTMIN.getID(); - buf.put((byte) familyId); + buf.write((byte) familyId); final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0; - buf.put((byte)flagsByte); + buf.write((byte)flagsByte); final int NULL_32 = 0; - buf.putInt(NULL_32); + buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array()); // Long 1 - buf.putInt(numBuckets_); - buf.putShort(numHashes_); - buf.putShort(Util.computeSeedHash(seed_)); + buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array()); + buf.write(numHashes_); + short hashSeed = Util.computeSeedHash(seed_); + buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array()); final byte NULL_8 = 0; - buf.put(NULL_8); + buf.write(NULL_8); if (isEmpty()) { return; } - buf.putLong(totalWeight_); + final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array(); + buf.writeBytes(totWeightByte); - for (long estimate: sketchArray_) { - buf.putLong(estimate); + for (long w: sketchArray_) { + buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array()); } } @@ -379,6 +386,7 @@ public void serialize(ByteBuffer buf) { public static CountMinSketch deserialize(final byte[] b, final long seed) { ByteBuffer buf = ByteBuffer.allocate(b.length); buf.put(b); + buf.flip(); final byte preambleLongs = buf.get(); final byte serialVersion = buf.get(); @@ -392,7 +400,7 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { final byte NULL_8 = buf.get(); if (seedHash != Util.computeSeedHash(seed)) { - throw new SketchesException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", " + throw new SketchesArgumentException("Incompatible seed 
hashes: " + String.valueOf(seedHash) + ", " + String.valueOf(Util.computeSeedHash(seed))); } @@ -401,9 +409,10 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) { if (empty) { return cms; } + long w = buf.getLong(); + cms.totalWeight_ = w; - int i = 0; - while (buf.hasRemaining()) { + for (int i = 0; i < cms.sketchArray_.length; i++) { cms.sketchArray_[i] = buf.getLong(); } diff --git a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java new file mode 100644 index 000000000..cbd2fde79 --- /dev/null +++ b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datasketches.count; + +import org.apache.datasketches.common.SketchesArgumentException; +import org.apache.datasketches.common.SketchesException; +import org.testng.annotations.Test; + +import java.io.ByteArrayOutputStream; +import java.lang.annotation.Repeatable; +import java.nio.ByteBuffer; + +import static org.testng.Assert.*; + +public class CountMinSketchTest { + @Test + public void createNewCountMinSketchTest() throws Exception { + assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 5, 1, 123)); + assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 4, 268435456, 123)); + + final byte numHashes = 3; + final int numBuckets = 5; + final long seed = 1234567; + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); + + assertEquals(c.getNumHashes_(), numHashes); + assertEquals(c.getNumBuckets_(), numBuckets); + assertEquals(c.getSeed_(), seed); + assertTrue(c.isEmpty()); + } + + @Test + public void parameterSuggestionsTest() { + // Bucket suggestions + assertThrows("Relative error must be at least 0.", SketchesException.class, () -> CountMinSketch.suggestNumBuckets(-1.0)); + assertEquals(CountMinSketch.suggestNumBuckets(0.2), 14); + assertEquals(CountMinSketch.suggestNumBuckets(0.1), 28); + assertEquals(CountMinSketch.suggestNumBuckets(0.05), 55); + assertEquals(CountMinSketch.suggestNumBuckets(0.01), 272); + + // Check that getRelativeError() acts inversely to suggestNumBuckets(): a sketch built with the suggested bucket count stays within the requested relative error + final byte numHashes = 3; + final long seed = 1234567; + assertTrue(new CountMinSketch(numHashes, 14, seed).getRelativeError() <= 0.2); + assertTrue(new CountMinSketch(numHashes, 28, seed).getRelativeError() <= 0.1); + assertTrue(new CountMinSketch(numHashes, 55, seed).getRelativeError() <= 0.05); + assertTrue(new CountMinSketch(numHashes, 272, seed).getRelativeError() <= 0.01); + + // Hash suggestions + assertThrows("Confidence must be between 0 and 1.0 (inclusive).", 
SketchesException.class, () -> CountMinSketch.suggestNumHashes(10)); + assertThrows("Confidence must be between 0 and 1.0 (inclusive).", SketchesException.class, () -> CountMinSketch.suggestNumHashes(-1.0)); + assertEquals(CountMinSketch.suggestNumHashes(0.682689492), 2); + assertEquals(CountMinSketch.suggestNumHashes(0.954499736), 4); + assertEquals(CountMinSketch.suggestNumHashes(0.997300204), 6); + } + + @Test + public void countMinSketchOneUpdateTest() { + final byte numHashes = 3; + final int numBuckets = 5; + final long seed = 1234567; + long insertedWeights = 0; + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); + final String x = "x"; + + assertTrue(c.isEmpty()); + assertEquals(c.getEstimate(x), 0); + c.update(x, 1); + assertFalse(c.isEmpty()); + assertEquals(c.getEstimate(x), 1); + insertedWeights++; + + final long w = 9; + insertedWeights += w; + c.update(x, w); + assertEquals(c.getEstimate(x), insertedWeights); + + final double w1 = 10.0; + insertedWeights += (long) w1; + c.update(x, (long) w1); + assertEquals(c.getEstimate(x), insertedWeights); + assertEquals(c.getTotalWeight_(), insertedWeights); + assertTrue(c.getEstimate(x) <= c.getUpperBound(x)); + assertTrue(c.getEstimate(x) >= c.getLowerBound(x)); + } + + @Test + public void frequencyCancellationTest() { + CountMinSketch c = new CountMinSketch((byte) 1, 5, 123456); + c.update("x", 1); + c.update("y", -1); + assertEquals(c.getTotalWeight_(), 2); + assertEquals(c.getEstimate("x"), 1); + assertEquals(c.getEstimate("y"), -1); + } + + @Test + public void frequencyEstimates() { + final int numItems = 10; + long[] data = new long[numItems]; + long[] frequencies = new long[numItems]; + + for (int i = 0; i < numItems; i++) { + data[i] = i; + frequencies[i] = (long) 1 << (numItems - i); + } + + final double relativeError = 0.1; + final double confidence = 0.99; + final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError); + final byte numHashes = 
CountMinSketch.suggestNumHashes(confidence); + + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, 1234567); + for (int i = 0; i < numItems; i++) { + final long value = data[i]; + final long freq = frequencies[i]; + c.update(value, freq); + } + + for (final long i : data) { + final long est = c.getEstimate(i); + final long upp = c.getUpperBound(i); + final long low = c.getLowerBound(i); + assertTrue(est <= upp); + assertTrue(est >= low); + } + } + + @Test + public void mergeFailTest() { + final double relativeError = 0.25; + final double confidence = 0.9; + final long seed = 1234567; + final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError); + final byte numHashes = CountMinSketch.suggestNumHashes(confidence); + CountMinSketch s = new CountMinSketch(numHashes, numBuckets, seed); + + assertThrows("Cannot merge a sketch with itself.", SketchesException.class, () -> s.merge(s)); + + CountMinSketch s1 = new CountMinSketch((byte) (numHashes + 1), numBuckets, seed); + CountMinSketch s2 = new CountMinSketch(numHashes, numBuckets + 1, seed); + CountMinSketch s3 = new CountMinSketch(numHashes, numBuckets, seed + 1); + + CountMinSketch[] sketches = {s1, s2, s3}; + for (final CountMinSketch sk : sketches) { + assertThrows("Incompatible sketch configuration.", SketchesException.class, () -> s.merge(sk)); + } + } + + @Test + public void mergeTest() { + final double relativeError = 0.25; + final double confidence = 0.9; + final long seed = 123456; + final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError); + final byte numHashes = CountMinSketch.suggestNumHashes(confidence); + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); + + final byte sHashes = c.getNumHashes_(); + final int sBuckets = c.getNumBuckets_(); + final long sSeed = c.getSeed_(); + CountMinSketch s = new CountMinSketch(sHashes, sBuckets, sSeed); + + c.merge(s); + assertEquals(c.getTotalWeight_(), 0); + + final long[] data = {2, 3, 5, 7}; + for (final long 
d : data) { + c.update(d, 1); + s.update(d, 1); + } + c.merge(s); + + assertEquals(c.getTotalWeight_(), 2 * s.getTotalWeight_()); + + for (final long d : data) { + assertTrue(c.getEstimate(d) <= c.getUpperBound(d)); + assertTrue(s.getEstimate(d) <= 2); + } + } + + @Test + public void serializeDeserializeEmptyTest() { + final byte numHashes = 3; + final int numBuckets = 32; + final long seed = 123456; + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); + + ByteArrayOutputStream buf = new ByteArrayOutputStream(); + c.serialize(buf); + + byte[] b = buf.toByteArray(); + assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1)); + + CountMinSketch d = CountMinSketch.deserialize(b, seed); + assertEquals(d.getNumHashes_(), c.getNumHashes_()); + assertEquals(d.getNumBuckets_(), c.getNumBuckets_()); + assertEquals(d.getSeed_(), c.getSeed_()); + final long zero = 0; + assertEquals(d.getEstimate(zero), c.getEstimate(zero)); + assertEquals(d.getTotalWeight_(), c.getTotalWeight_()); + } + + @Test + public void serializeDeserializeTest() { + final byte numHashes = 5; + final int numBuckets = 64; + final long seed = 1234456; + CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed); + for (long i = 0; i < 10; i++) { + c.update(i, 10*i*i); + } + + ByteArrayOutputStream buf = new ByteArrayOutputStream(); + c.serialize(buf); + + assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1)); + CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed); + + assertEquals(d.getNumHashes_(), c.getNumHashes_()); + assertEquals(d.getNumBuckets_(), c.getNumBuckets_()); + assertEquals(d.getSeed_(), c.getSeed_()); + assertEquals(d.getTotalWeight_(), c.getTotalWeight_()); + + for (long i = 0; i < 10; i++) { + assertEquals(d.getEstimate(i), c.getEstimate(i)); + } + } +}