diff --git a/src/main/java/org/apache/datasketches/count/CountMinSketch.java b/src/main/java/org/apache/datasketches/count/CountMinSketch.java
new file mode 100644
index 000000000..36bea38cf
--- /dev/null
+++ b/src/main/java/org/apache/datasketches/count/CountMinSketch.java
@@ -0,0 +1,467 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.count;
+
+import org.apache.datasketches.common.Family;
+import org.apache.datasketches.common.SketchesArgumentException;
+import org.apache.datasketches.common.SketchesException;
+import org.apache.datasketches.hash.MurmurHash3;
+import org.apache.datasketches.tuple.Util;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Random;
+
+/**
+ * A Count-Min sketch: a fixed-size table of counters, one row per hash function, that
+ * answers approximate frequency queries for items added through the update methods.
+ *
+ * <p>Items are hashed as byte arrays; {@code long} items use their 8-byte big-endian
+ * encoding and {@code String} items their UTF-8 encoding. Null or empty items are ignored
+ * by updates and estimate to 0. The class performs no synchronization of its own.
+ */
+public class CountMinSketch {
+  /** Serial version written by serialize and required by deserialize. */
+  private static final byte SERIAL_VERSION = 1;
+
+  /** Exclusive upper limit on the total number of counters (numHashes * numBuckets). */
+  private static final long MAX_TOTAL_BUCKETS = 1L << 30;
+
+  private final byte numHashes_;
+  private final int numBuckets_;
+  private final long seed_;
+  private final long[] hashSeeds_;
+  private final long[] sketchArray_;
+  private long totalWeight_;
+
+  /** Bits of the flags byte in the serialized preamble. */
+  private enum Flag {
+    IS_EMPTY(1);
+
+    // The mask is part of the serialized format, so it is spelled out instead of derived from ordinal().
+    private final int mask_;
+
+    Flag(final int mask) {
+      mask_ = mask;
+    }
+
+    int mask() {
+      return mask_;
+    }
+  }
+
+  /**
+   * Creates a CountMin sketch with given number of hash functions and buckets,
+   * and a user-specified seed.
+   *
+   * @param numHashes The number of hash functions to apply to items; must be at least 1
+   * @param numBuckets Array size for each of the hashing function; must be at least 3
+   * @param seed The base hash seed
+   * @throws SketchesArgumentException if numHashes is less than 1, numBuckets is less than 3,
+   *     or numHashes * numBuckets is 2^30 or more
+   */
+  CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
+    // All validation happens before any array is allocated, so a rejected configuration
+    // never triggers a huge allocation or a NegativeArraySizeException.
+    if (numHashes < 1) {
+      throw new SketchesArgumentException("At least 1 hash function is required, got " + numHashes + ".");
+    }
+
+    if (numBuckets < 3) {
+      throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
+    }
+
+    // This check is to ensure later compatibility with a Java implementation whose maximum size can only
+    // be 2^31-1.  We check only against 2^30 for simplicity. The product is computed as a long because
+    // the int product can wrap around and slip past the limit.
+    final long totalBuckets = (long) numHashes * numBuckets;
+    if (totalBuckets >= MAX_TOTAL_BUCKETS) {
+      throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n"
+          + "Try reducing either the number of buckets or the number of hash functions.");
+    }
+
+    numHashes_ = numHashes;
+    numBuckets_ = numBuckets;
+    seed_ = seed;
+    hashSeeds_ = new long[numHashes];
+    sketchArray_ = new long[(int) totalBuckets];
+    totalWeight_ = 0;
+
+    // The per-row hash seeds are derived deterministically from the base seed.
+    final Random rand = new Random(seed);
+    for (int i = 0; i < numHashes; i++) {
+      hashSeeds_[i] = rand.nextLong();
+    }
+  }
+
+  /**
+   * Maps an item to one counter index per hash function.
+   *
+   * @param item The item bytes to hash
+   * @return Indices into the flat counter array; entry i lies in row i of the table
+   */
+  private int[] getHashes(final byte[] item) {
+    final int[] updateLocations = new int[numHashes_];
+
+    for (int i = 0; i < numHashes_; i++) {
+      final long[] hash = MurmurHash3.hash(item, hashSeeds_[i]);
+      // floorMod yields a bucket in [0, numBuckets) even when the hash word is negative.
+      final int bucket = (int) Math.floorMod(hash[0], (long) numBuckets_);
+      // Cannot overflow: the constructor keeps numHashes * numBuckets below 2^30.
+      updateLocations[i] = i * numBuckets_ + bucket;
+    }
+
+    return updateLocations;
+  }
+
+  /**
+   * Checks if the CountMinSketch has processed any items, that is, if its total weight is 0.
+   * @return True if the sketch is empty, otherwise false.
+   */
+  public boolean isEmpty() {
+    return totalWeight_ == 0;
+  }
+
+  /**
+   * Returns the number of hash functions used in this sketch.
+   * @return The number of hash functions.
+   */
+  public byte getNumHashes_() {
+    return numHashes_;
+  }
+
+  /**
+   * Returns the number of buckets per hash function.
+   * @return The number of buckets.
+   */
+  public int getNumBuckets_() {
+    return numBuckets_;
+  }
+
+  /**
+   * Returns the hash seed used by this sketch.
+   * @return The seed value.
+   */
+  public long getSeed_() {
+    return seed_;
+  }
+
+  /**
+   * Returns the total weight of all items inserted into the sketch.
+   * @return The total weight, the sum of the absolute values of all update weights.
+   */
+  public long getTotalWeight_() {
+    return totalWeight_;
+  }
+
+  /**
+   * Returns the relative error of the sketch.
+   * @return The relative error, e divided by the number of buckets.
+   */
+  public double getRelativeError() {
+    return Math.exp(1.0) / (double) numBuckets_;
+  }
+
+  /**
+   * Suggests an appropriate number of hash functions to use for a given confidence level.
+   * @param confidence The desired confidence level between 0 and 1.
+   * @return Suggested number of hash functions, always between 1 and 127.
+   * @throws SketchesException if confidence is NaN or outside the range 0 to 1.
+   */
+  public static byte suggestNumHashes(final double confidence) {
+    // Written as a negated range test so that NaN is rejected as well.
+    if (!(confidence >= 0 && confidence <= 1)) {
+      throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
+    }
+    final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
+    // Clamp to what the constructor accepts: at least 1 hash function, at most Byte.MAX_VALUE.
+    return (byte) Math.max(1, Math.min(value, 127));
+  }
+
+  /**
+   * Suggests an appropriate number of buckets per hash function for a given relative error.
+   * @param relativeError The desired relative error.
+   * @return Suggested number of buckets; a relative error of 0 yields Integer.MAX_VALUE.
+   * @throws SketchesException if relativeError is NaN or negative.
+   */
+  public static int suggestNumBuckets(final double relativeError) {
+    // Written as a negated test so that NaN is rejected as well.
+    if (!(relativeError >= 0.)) {
+      throw new SketchesException("Relative error must be at least 0.");
+    }
+    return (int) Math.ceil(Math.exp(1.0) / relativeError);
+  }
+
+  /** Encodes a long item as the 8 big-endian bytes under which it is hashed. */
+  private static byte[] longToBytes(final long item) {
+    return ByteBuffer.allocate(Long.BYTES).putLong(item).array();
+  }
+
+  /**
+   * Updates the sketch with the provided item and weight.
+   * @param item The item to update.
+   * @param weight The weight of the item.
+   */
+  public void update(final long item, final long weight) {
+    update(longToBytes(item), weight);
+  }
+
+  /**
+   * Updates the sketch with the provided item and weight. A null or empty item is ignored.
+   * @param item The item to update.
+   * @param weight The weight of the item.
+   */
+  public void update(final String item, final long weight) {
+    if (item == null || item.isEmpty()) {
+      return;
+    }
+    update(item.getBytes(StandardCharsets.UTF_8), weight);
+  }
+
+  /**
+   * Updates the sketch with the provided item and weight. A null or empty item is ignored.
+   * @param item The item to update.
+   * @param weight The weight of the item; it is added to one counter per hash function.
+   */
+  public void update(final byte[] item, final long weight) {
+    if (item == null || item.length == 0) {
+      return;
+    }
+
+    // The total weight accumulates magnitudes, so negative updates still increase it.
+    totalWeight_ += Math.abs(weight);
+    for (final int location : getHashes(item)) {
+      sketchArray_[location] += weight;
+    }
+  }
+
+  /**
+   * Returns the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Estimated frequency.
+   */
+  public long getEstimate(final long item) {
+    return getEstimate(longToBytes(item));
+  }
+
+  /**
+   * Returns the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Estimated frequency, or 0 for a null or empty item.
+   */
+  public long getEstimate(final String item) {
+    if (item == null || item.isEmpty()) {
+      return 0;
+    }
+    return getEstimate(item.getBytes(StandardCharsets.UTF_8));
+  }
+
+  /**
+   * Returns the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Estimated frequency, or 0 for a null or empty item.
+   */
+  public long getEstimate(final byte[] item) {
+    if (item == null || item.length == 0) {
+      return 0;
+    }
+
+    // The estimate is the smallest of the item's counters, one per hash function.
+    long res = Long.MAX_VALUE;
+    for (final int location : getHashes(item)) {
+      res = Math.min(res, sketchArray_[location]);
+    }
+    return res;
+  }
+
+  /**
+   * Returns the upper bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Upper bound of estimated frequency.
+   */
+  public long getUpperBound(final long item) {
+    return getUpperBound(longToBytes(item));
+  }
+
+  /**
+   * Returns the upper bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Upper bound of estimated frequency, or 0 for a null or empty item.
+   */
+  public long getUpperBound(final String item) {
+    if (item == null || item.isEmpty()) {
+      return 0;
+    }
+    return getUpperBound(item.getBytes(StandardCharsets.UTF_8));
+  }
+
+  /**
+   * Returns the upper bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return The estimate plus relative error times total weight, or 0 for a null or empty item.
+   */
+  public long getUpperBound(final byte[] item) {
+    if (item == null || item.length == 0) {
+      return 0;
+    }
+    return getEstimate(item) + (long) (getRelativeError() * getTotalWeight_());
+  }
+
+  /**
+   * Returns the lower bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Lower bound of estimated frequency.
+   */
+  public long getLowerBound(final long item) {
+    return getLowerBound(longToBytes(item));
+  }
+
+  /**
+   * Returns the lower bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Lower bound of estimated frequency, or 0 for a null or empty item.
+   */
+  public long getLowerBound(final String item) {
+    if (item == null || item.isEmpty()) {
+      return 0;
+    }
+    return getLowerBound(item.getBytes(StandardCharsets.UTF_8));
+  }
+
+  /**
+   * Returns the lower bound of the estimated frequency for the given item.
+   * @param item The item to estimate.
+   * @return Lower bound of estimated frequency, which is the estimate itself.
+   */
+  public long getLowerBound(final byte[] item) {
+    return getEstimate(item);
+  }
+
+  /**
+   * Merges another CountMinSketch into this one. The sketches must have the same configuration.
+   * @param other The other sketch to merge.
+   * @throws SketchesException if other is this sketch, or differs in buckets, hashes or seed.
+   */
+  public void merge(final CountMinSketch other) {
+    if (this == other) {
+      throw new SketchesException("Cannot merge a sketch with itself");
+    }
+
+    final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_()
+        && getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
+
+    if (!acceptableConfig) {
+      throw new SketchesException("Incompatible sketch configuration.");
+    }
+
+    for (int i = 0; i < sketchArray_.length; i++) {
+      sketchArray_[i] += other.sketchArray_[i];
+    }
+
+    totalWeight_ += other.getTotalWeight_();
+  }
+
+  /**
+   * Serializes the sketch into the provided ByteArrayOutputStream.
+   * All multi-byte values are written big-endian, the default ByteBuffer byte order.
+   * @param buf The ByteArrayOutputStream to append to.
+   */
+  public void serialize(final ByteArrayOutputStream buf) {
+    final boolean empty = isEmpty();
+    final ByteBuffer preamble = ByteBuffer.allocate(2 * Long.BYTES);
+
+    // Long 0
+    preamble.put((byte) Family.COUNTMIN.getMinPreLongs());
+    preamble.put(SERIAL_VERSION);
+    preamble.put((byte) Family.COUNTMIN.getID());
+    preamble.put((byte) (empty ? Flag.IS_EMPTY.mask() : 0));
+    preamble.putInt(0); // unused
+
+    // Long 1
+    preamble.putInt(numBuckets_);
+    preamble.put(numHashes_);
+    preamble.putShort(Util.computeSeedHash(seed_));
+    preamble.put((byte) 0); // unused
+    buf.write(preamble.array(), 0, preamble.capacity());
+
+    // An empty sketch is fully described by its preamble.
+    if (empty) {
+      return;
+    }
+
+    // One scratch buffer is reused for every long instead of allocating a new one per counter.
+    final ByteBuffer scratch = ByteBuffer.allocate(Long.BYTES);
+    buf.write(scratch.putLong(0, totalWeight_).array(), 0, Long.BYTES);
+    for (final long w : sketchArray_) {
+      buf.write(scratch.putLong(0, w).array(), 0, Long.BYTES);
+    }
+  }
+
+  /**
+   * Deserializes a CountMinSketch from the provided byte array.
+   * @param b The byte array containing the serialized sketch.
+   * @param seed The seed used during serialization.
+   * @return The deserialized CountMinSketch.
+   * @throws SketchesArgumentException if the family ID, serial version or seed hash does not match,
+   *     or if the stored configuration is rejected by the constructor.
+   */
+  public static CountMinSketch deserialize(final byte[] b, final long seed) {
+    // wrap() reads the caller's array in place; nothing is written to it.
+    final ByteBuffer buf = ByteBuffer.wrap(b);
+
+    // Long 0
+    buf.get(); // preamble longs; not needed to decode this layout
+    final byte serialVersion = buf.get();
+    final byte familyId = buf.get();
+    final byte flagsByte = buf.get();
+    buf.getInt(); // unused
+
+    // Long 1
+    final int numBuckets = buf.getInt();
+    final byte numHashes = buf.get();
+    final short seedHash = buf.getShort();
+    buf.get(); // unused
+
+    if (familyId != (byte) Family.COUNTMIN.getID()) {
+      throw new SketchesArgumentException("Not a CountMin sketch image, family ID: " + familyId);
+    }
+
+    if (serialVersion != SERIAL_VERSION) {
+      throw new SketchesArgumentException("Unsupported serial version: " + serialVersion);
+    }
+
+    final short expectedSeedHash = Util.computeSeedHash(seed);
+    if (seedHash != expectedSeedHash) {
+      throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", " + expectedSeedHash);
+    }
+
+    final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
+    if ((flagsByte & Flag.IS_EMPTY.mask()) != 0) {
+      return cms;
+    }
+
+    cms.totalWeight_ = buf.getLong();
+    for (int i = 0; i < cms.sketchArray_.length; i++) {
+      cms.sketchArray_[i] = buf.getLong();
+    }
+
+    return cms;
+  }
+}
diff --git a/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java
new file mode 100644
index 000000000..cbd2fde79
--- /dev/null
+++ b/src/test/java/org/apache/datasketches/count/CountMinSketchTest.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.datasketches.count;
+
+import org.apache.datasketches.common.SketchesArgumentException;
+import org.apache.datasketches.common.SketchesException;
+import org.testng.annotations.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.lang.annotation.Repeatable;
+import java.nio.ByteBuffer;
+
+import static org.testng.Assert.*;
+
+/** Unit tests for {@link CountMinSketch}. */
+public class CountMinSketchTest {
+  @Test
+  public void createNewCountMinSketchTest() throws Exception {
+    // Rejected configurations: too few buckets, then too many counters in total.
+    assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 5, 1, 123));
+    assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 4, 268435456, 123));
+
+    final byte hashes = 3;
+    final int buckets = 5;
+    final long seed = 1234567;
+    final CountMinSketch sketch = new CountMinSketch(hashes, buckets, seed);
+
+    assertTrue(sketch.isEmpty());
+    assertEquals(sketch.getSeed_(), seed);
+    assertEquals(sketch.getNumBuckets_(), buckets);
+    assertEquals(sketch.getNumHashes_(), hashes);
+  }
+
+  @Test
+  public void parameterSuggestionsTest() {
+    // Bucket suggestions
+    assertThrows("Relative error must be at least 0.", SketchesException.class,
+        () -> CountMinSketch.suggestNumBuckets(-1.0));
+    final double[] errors = {0.2, 0.1, 0.05, 0.01};
+    final int[] expectedBuckets = {14, 28, 55, 272};
+    for (int i = 0; i < errors.length; i++) {
+      assertEquals(CountMinSketch.suggestNumBuckets(errors[i]), expectedBuckets[i]);
+    }
+
+    // Check that the relative error of a sketch acts inversely to suggestNumBuckets
+    final byte hashes = 3;
+    final long seed = 1234567;
+    for (int i = 0; i < errors.length; i++) {
+      final CountMinSketch sketch = new CountMinSketch(hashes, expectedBuckets[i], seed);
+      assertTrue(sketch.getRelativeError() <= errors[i]);
+    }
+
+    // Hash suggestions
+    final String rangeMessage = "Confidence must be between 0 and 1.0 (inclusive).";
+    assertThrows(rangeMessage, SketchesException.class, () -> CountMinSketch.suggestNumHashes(10));
+    assertThrows(rangeMessage, SketchesException.class, () -> CountMinSketch.suggestNumHashes(-1.0));
+    assertEquals(CountMinSketch.suggestNumHashes(0.682689492), 2);
+    assertEquals(CountMinSketch.suggestNumHashes(0.954499736), 4);
+    assertEquals(CountMinSketch.suggestNumHashes(0.997300204), 6);
+  }
+
+  @Test
+  public void countMinSketchOneUpdateTest() {
+    final byte hashes = 3;
+    final int buckets = 5;
+    final long seed = 1234567;
+    final CountMinSketch sketch = new CountMinSketch(hashes, buckets, seed);
+    final String key = "x";
+    long expected = 0;
+
+    assertTrue(sketch.isEmpty());
+    assertEquals(sketch.getEstimate(key), 0);
+
+    sketch.update(key, 1);
+    expected++;
+    assertFalse(sketch.isEmpty());
+    assertEquals(sketch.getEstimate(key), 1);
+
+    final long weight = 9;
+    sketch.update(key, weight);
+    expected += weight;
+    assertEquals(sketch.getEstimate(key), expected);
+
+    // A weight held as a double is cast to long before it is passed to update.
+    final double doubleWeight = 10.0;
+    sketch.update(key, (long) doubleWeight);
+    expected += (long) doubleWeight;
+    assertEquals(sketch.getEstimate(key), expected);
+    assertEquals(sketch.getTotalWeight_(), expected);
+    assertTrue(sketch.getEstimate(key) <= sketch.getUpperBound(key));
+    assertTrue(sketch.getEstimate(key) >= sketch.getLowerBound(key));
+  }
+
+  @Test
+  public void frequencyCancellationTest() {
+    final CountMinSketch sketch = new CountMinSketch((byte) 1, 5, 123456);
+    sketch.update("x", 1);
+    sketch.update("y", -1);
+    // The total weight adds up absolute values, so +1 and -1 give 2.
+    assertEquals(sketch.getTotalWeight_(), 2);
+    assertEquals(sketch.getEstimate("x"), 1);
+    assertEquals(sketch.getEstimate("y"), -1);
+  }
+
+  @Test
+  public void frequencyEstimates() {
+    final int numItems = 10;
+    final long[] items = new long[numItems];
+    final long[] weights = new long[numItems];
+    // Item i is given the weight 2^(numItems - i).
+    for (int i = 0; i < numItems; i++) {
+      items[i] = i;
+      weights[i] = 1L << (numItems - i);
+    }
+
+    final int buckets = CountMinSketch.suggestNumBuckets(0.1);
+    final byte hashes = CountMinSketch.suggestNumHashes(0.99);
+    final CountMinSketch sketch = new CountMinSketch(hashes, buckets, 1234567);
+    for (int i = 0; i < numItems; i++) {
+      sketch.update(items[i], weights[i]);
+    }
+
+    for (final long item : items) {
+      final long estimate = sketch.getEstimate(item);
+      assertTrue(estimate <= sketch.getUpperBound(item));
+      assertTrue(estimate >= sketch.getLowerBound(item));
+    }
+  }
+
+  @Test
+  public void mergeFailTest() {
+    final long seed = 1234567;
+    final int buckets = CountMinSketch.suggestNumBuckets(0.25);
+    final byte hashes = CountMinSketch.suggestNumHashes(0.9);
+    final CountMinSketch base = new CountMinSketch(hashes, buckets, seed);
+
+    assertThrows("Cannot merge a sketch with itself.", SketchesException.class, () -> base.merge(base));
+
+    // Each candidate differs from the base sketch in exactly one configuration parameter.
+    final CountMinSketch[] mismatched = {
+      new CountMinSketch((byte) (hashes + 1), buckets, seed),
+      new CountMinSketch(hashes, buckets + 1, seed),
+      new CountMinSketch(hashes, buckets, seed + 1)
+    };
+    for (final CountMinSketch other : mismatched) {
+      assertThrows("Incompatible sketch configuration.", SketchesException.class, () -> base.merge(other));
+    }
+  }
+
+  @Test
+  public void mergeTest() {
+    final long seed = 123456;
+    final int buckets = CountMinSketch.suggestNumBuckets(0.25);
+    final byte hashes = CountMinSketch.suggestNumHashes(0.9);
+    final CountMinSketch target = new CountMinSketch(hashes, buckets, seed);
+    // The second sketch takes its configuration from the getters of the first.
+    final CountMinSketch source =
+        new CountMinSketch(target.getNumHashes_(), target.getNumBuckets_(), target.getSeed_());
+
+    // Both sketches are still empty here, so the merge leaves the total weight at 0.
+    target.merge(source);
+    assertEquals(target.getTotalWeight_(), 0);
+
+    final long[] data = {2, 3, 5, 7};
+    for (final long d : data) {
+      target.update(d, 1);
+      source.update(d, 1);
+    }
+    target.merge(source);
+
+    assertEquals(target.getTotalWeight_(), 2 * source.getTotalWeight_());
+
+    for (final long d : data) {
+      assertTrue(target.getEstimate(d) <= target.getUpperBound(d));
+      assertTrue(source.getEstimate(d) <= 2);
+    }
+  }
+
+  @Test
+  public void serializeDeserializeEmptyTest() {
+    final byte hashes = 3;
+    final int buckets = 32;
+    final long seed = 123456;
+    final CountMinSketch original = new CountMinSketch(hashes, buckets, seed);
+
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    original.serialize(out);
+    final byte[] bytes = out.toByteArray();
+
+    // Deserializing with a different seed must fail.
+    assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(bytes, seed - 1));
+
+    final CountMinSketch restored = CountMinSketch.deserialize(bytes, seed);
+    assertEquals(restored.getNumHashes_(), original.getNumHashes_());
+    assertEquals(restored.getNumBuckets_(), original.getNumBuckets_());
+    assertEquals(restored.getSeed_(), original.getSeed_());
+    assertEquals(restored.getTotalWeight_(), original.getTotalWeight_());
+    final long zero = 0;
+    assertEquals(restored.getEstimate(zero), original.getEstimate(zero));
+  }
+
+  @Test
+  public void serializeDeserializeTest() {
+    final byte hashes = 5;
+    final int buckets = 64;
+    final long seed = 1234456;
+    final CountMinSketch original = new CountMinSketch(hashes, buckets, seed);
+    for (long i = 0; i < 10; i++) {
+      original.update(i, 10 * i * i);
+    }
+
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    original.serialize(out);
+
+    // Deserializing with a different seed must fail.
+    assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(out.toByteArray(), seed - 1));
+    final CountMinSketch restored = CountMinSketch.deserialize(out.toByteArray(), seed);
+
+    assertEquals(restored.getNumHashes_(), original.getNumHashes_());
+    assertEquals(restored.getNumBuckets_(), original.getNumBuckets_());
+    assertEquals(restored.getSeed_(), original.getSeed_());
+    assertEquals(restored.getTotalWeight_(), original.getTotalWeight_());
+
+    for (long i = 0; i < 10; i++) {
+      assertEquals(restored.getEstimate(i), original.getEstimate(i));
+    }
+  }
+}