From af12b55b62d9bfa12563c033c2f911e1024ffcd3 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 9 Mar 2023 22:18:34 +0800 Subject: [PATCH 01/19] support --- .../spark/sql/DataFrameStatFunctions.scala | 98 ++++++++++++++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 27 +++++ .../CheckConnectJvmClientCompatibility.scala | 1 - .../connect/planner/SparkConnectPlanner.scala | 6 ++ 4 files changed, 130 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 0d4372b8738ee..32504570fcdd7 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} +import java.io.ByteArrayInputStream import scala.collection.JavaConverters._ @@ -25,7 +26,7 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. @@ -584,6 +585,101 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } CountMinSketch.readFrom(ds.head()) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.4.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.4.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. + * @since 3.4.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. 
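+   *
+   *   A minimal usage sketch (the DataFrame `df`, its long-typed `id` column, and the sizes are
+   *   illustrative assumptions mirroring `DataFrameStatSuite`; `col` is `functions.col`):
+   *   {{{
+   *     val filter = df.stat.bloomFilter(col("id"), 1000L, 64L * 5)
+   *     // No false negatives: every inserted value reports true; false positives are possible.
+   *     filter.mightContain(42L)
+   *   }}}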
+ * @since 3.4.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) + } + + private def buildBloomFilter( + col: Column, + expectedNumItems: Long, + numBits: Long, + fpp: Double): BloomFilter = { + + def optimalNumOfBits(n: Long, p: Double): Long = + (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong + + val nBits = if (fpp.isNaN) { + numBits + } else { + if (fpp <= 0d || fpp >= 1d) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)") + } + optimalNumOfBits(expectedNumItems, fpp) + } + + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive") + } + if (nBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive") + } + + val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(nBits)) + val ds = sparkSession.newDataset(BinaryEncoder) { builder => + builder.getProjectBuilder + .setInput(root) + .addExpressions(agg.expr) + } + BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) + } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index aea31005f3bd6..ab2fda1b74f60 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -176,4 +176,31 @@ class DataFrameStatSuite extends RemoteSparkSession { assert(sketch.relativeError() === 0.001) assert(sketch.confidence() === 0.99 +- 5e-3) } + + // This test only verifies some basic requirements, more correctness tests can be found in + // `BloomFilterSuite` in project spark-sketch. 
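+  // Both sizing overloads are exercised below: (expectedNumItems, fpp) lets the server derive
+  // the number of bits, while (expectedNumItems, numBits) fixes the bit size directly.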
+ test("Bloom filter") { + val df = spark.range(1000) + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(0.until(1000).forall(filter1.mightContain)) + val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter2.bitSize() == 64 * 5) + assert(0.until(1000).forall(filter2.mightContain)) + + val message1 = intercept[IllegalArgumentException] { + df.stat.bloomFilter("id", -1000, 100) + }.getMessage + assert(message1.contains("Expected insertions must be positive")) + + val message2 = intercept[IllegalArgumentException] { + df.stat.bloomFilter("id", 1000, -100) + }.getMessage + assert(message2.contains("Number of bits must be positive")) + + val message3 = intercept[IllegalArgumentException] { + df.stat.bloomFilter("id", 1000, -1.0) + }.getMessage + assert(message3.contains("False positive probability must be within range (0.0, 1.0)")) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index ae6c6c86fec26..7178def513555 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -138,7 +138,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.this"), // DataFrameStatFunctions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.this"), // Dataset diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index cd4da39d62fe9..bcc0cd743117f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIden import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical @@ -1073,6 +1074,11 @@ class SparkConnectPlanner(val session: SparkSession) { } Some(Lead(children.head, children(1), children(2), ignoreNulls)) + case "bloom_filter_agg" if fun.getArgumentsCount == 3 => + val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) + Some(Alias(new BloomFilterAggregate(children.head, children(1), children(2)) + .toAggregateExpression(), "bloomFilter")()) + case "window" if 2 <= 
fun.getArgumentsCount && fun.getArgumentsCount <= 4 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) val timeCol = children.head From 6f2e7363b703c87eeb79fe8d944f4de728add025 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 9 Mar 2023 22:21:31 +0800 Subject: [PATCH 02/19] format --- .../spark/sql/connect/planner/SparkConnectPlanner.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index bcc0cd743117f..e6cdca3231ca9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1076,8 +1076,11 @@ class SparkConnectPlanner(val session: SparkSession) { case "bloom_filter_agg" if fun.getArgumentsCount == 3 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - Some(Alias(new BloomFilterAggregate(children.head, children(1), children(2)) - .toAggregateExpression(), "bloomFilter")()) + Some( + Alias( + new BloomFilterAggregate(children.head, children(1), children(2)) + .toAggregateExpression(), + "bloomFilter")()) case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) From 92da5893967d535d326340c5885129a60cae9cf0 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 00:05:22 +0800 Subject: [PATCH 03/19] remove Alias --- .../spark/sql/connect/planner/SparkConnectPlanner.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e6cdca3231ca9..01f4959aa0e58 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1077,10 +1077,8 @@ class SparkConnectPlanner(val session: SparkSession) { case "bloom_filter_agg" if fun.getArgumentsCount == 3 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) Some( - Alias( - new BloomFilterAggregate(children.head, children(1), children(2)) - .toAggregateExpression(), - "bloomFilter")()) + new BloomFilterAggregate(children.head, children(1), children(2)) + .toAggregateExpression()) case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) From 71fba34016a1960b71aa32ed8ce8d9c53f7d4bf4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 14:18:40 +0800 Subject: [PATCH 04/19] support Int/Byte/Short --- .../spark/sql/DataFrameStatFunctions.scala | 17 ++++++- .../apache/spark/sql/DataFrameStatSuite.scala | 44 ++++++++++++++++--- .../client/util/IntegrationTestUtils.scala | 2 +- .../connect/planner/SparkConnectPlanner.scala | 20 ++++++++- 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala 
index 32504570fcdd7..cf469114dd146 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -672,7 +672,22 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo throw new IllegalArgumentException("Number of bits must be positive") } - val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(nBits)) + val dataType = sparkSession + .newDataFrame { builder => + builder.getProjectBuilder + .setInput(root) + .addExpressions(col.expr) + } + .schema + .head + .dataType + + val agg = Column.fn( + "bloom_filter_agg", + col, + lit(expectedNumItems), + lit(nBits), + lit(dataType.catalogString)) val ds = sparkSession.newDataset(BinaryEncoder) { builder => builder.getProjectBuilder .setInput(root) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index ab2fda1b74f60..08e5f27205d26 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -177,17 +177,49 @@ class DataFrameStatSuite extends RemoteSparkSession { assert(sketch.confidence() === 0.99 +- 5e-3) } - // This test only verifies some basic requirements, more correctness tests can be found in - // `BloomFilterSuite` in project spark-sketch. - test("Bloom filter") { - val df = spark.range(1000) + test("Bloom filter -- Long Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toLong) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Int Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Short Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toShort) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + test("Bloom filter -- Byte Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toByte) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + + private def checkBloomFilter(data: Seq[Any], df: DataFrame) = { val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) - assert(0.until(1000).forall(filter1.mightContain)) + assert(data.forall(filter1.mightContain)) val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5) assert(filter2.bitSize() == 64 * 5) - assert(0.until(1000).forall(filter2.mightContain)) + assert(data.forall(filter2.mightContain)) + } + test("Bloom filter test invalid inputs") { + val df = spark.emptyDataFrame val message1 = intercept[IllegalArgumentException] { df.stat.bloomFilter("id", -1000, 100) }.getMessage diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index f27ea614a7eb8..6e042c8c288ac 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala 
@@ -41,7 +41,7 @@ object IntegrationTestUtils { } sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) } - private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean + private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "true").toBoolean // Log server start stop debug info into console // scalastyle:off println diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 01f4959aa0e58..1ca836be88dde 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1074,10 +1074,26 @@ class SparkConnectPlanner(val session: SparkSession) { } Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "bloom_filter_agg" if fun.getArgumentsCount == 3 => + case "bloom_filter_agg" if fun.getArgumentsCount == 4 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) + val dt = { + val ddl = children.last match { + case StringLiteral(s) => s + case other => + throw InvalidPlanInput(s"col dataType should be a literal string, but got $other") + } + DataType.fromDDL(ddl) + } + val first = dt match { + case IntegerType | ShortType | ByteType => Cast(children.head, LongType) + case LongType => children.head + case other => + throw InvalidPlanInput( + s"Bloom filter only supports integral types, " + + s"and does not support type $other.") + } Some( - new BloomFilterAggregate(children.head, children(1), children(2)) + new BloomFilterAggregate(first, children(1), children(2)) .toAggregateExpression()) case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => From ab78bae5e5c406713fa395273181122373c819cb Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 15:53:28 +0800 Subject: [PATCH 05/19] support string --- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +++++++ .../connect/planner/SparkConnectPlanner.scala | 2 +- .../aggregate/BloomFilterAggregate.scala | 24 +++++++++++++++++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 08e5f27205d26..4aef1107ab775 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -209,6 +209,14 @@ class DataFrameStatSuite extends RemoteSparkSession { checkBloomFilter(data, df) } + test("Bloom filter -- String Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toString) + val df = data.toDF("id") + checkBloomFilter(data, df) + } + private def checkBloomFilter(data: Seq[Any], df: DataFrame) = { val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 1ca836be88dde..73fb836d80deb 100644 --- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1086,7 +1086,7 @@ class SparkConnectPlanner(val session: SparkSession) { } val first = dt match { case IntegerType | ShortType | ByteType => Cast(children.head, LongType) - case LongType => children.head + case LongType | StringType => children.head case other => throw InvalidPlanInput( s"Bloom filter only supports integral types, " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 980785e764cdb..f61e696dd199f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.trees.TernaryLike import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.BloomFilter /** @@ -78,7 +79,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType, LongType, LongType) => + case (LongType, LongType, LongType) | (StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -150,6 +151,11 @@ case class BloomFilterAggregate( Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + private lazy val updater: BloomFilterUpdater = first.dataType match { + case LongType => LongUpdater + case StringType => BinaryUpdater + } + override def first: Expression = child override def second: Expression = estimatedNumItemsExpression @@ -174,7 +180,7 @@ case class BloomFilterAggregate( if (value == null) { return buffer } - buffer.putLong(value.asInstanceOf[Long]) + updater.update(buffer, value) buffer } @@ -224,3 +230,17 @@ object BloomFilterAggregate { bloomFilter } } + +private trait BloomFilterUpdater { + def update(bf: BloomFilter, v: Any) +} + +private object LongUpdater extends BloomFilterUpdater { + override def update(bf: BloomFilter, v: Any): Unit = + bf.putLong(v.asInstanceOf[Long]) +} + +private object BinaryUpdater extends BloomFilterUpdater { + override def update(bf: BloomFilter, v: Any): Unit = + bf.putBinary(v.asInstanceOf[UTF8String].getBytes) +} From ebcec0bd41940ba9f5380adcdbad5beb63edd08d Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 15:54:36 +0800 Subject: [PATCH 06/19] support string --- .../expressions/aggregate/BloomFilterAggregate.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index f61e696dd199f..03d8136f11682 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -232,15 +232,15 @@ object BloomFilterAggregate { } private trait BloomFilterUpdater { - def update(bf: BloomFilter, v: Any) + def update(bf: BloomFilter, v: Any): Boolean } private object LongUpdater extends BloomFilterUpdater { - override def update(bf: BloomFilter, v: Any): Unit = + override def update(bf: BloomFilter, v: Any): Boolean = bf.putLong(v.asInstanceOf[Long]) } private object BinaryUpdater extends BloomFilterUpdater { - override def update(bf: BloomFilter, v: Any): Unit = + override def update(bf: BloomFilter, v: Any): Boolean = bf.putBinary(v.asInstanceOf[UTF8String].getBytes) } From 6ee5528cc7422a42b8fb23f115d66e16ebe02915 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 16:01:55 +0800 Subject: [PATCH 07/19] change to check numBits --- .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 2 +- .../spark/sql/connect/client/util/IntegrationTestUtils.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index cf469114dd146..9af7cd01de95b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -655,7 +655,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo def optimalNumOfBits(n: Long, p: Double): Long = (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong - val nBits = if (fpp.isNaN) { + val nBits = if (numBits > 0L) { numBits } else { if (fpp <= 0d || fpp >= 1d) { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index 6e042c8c288ac..f27ea614a7eb8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -41,7 +41,7 @@ object IntegrationTestUtils { } sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) } - private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "true").toBoolean + private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean // Log server start stop debug info into console // scalastyle:off println From dfffd53524eaff690f4782cee072a75927521db8 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 17:51:16 +0800 Subject: [PATCH 08/19] tmp fix --- .../apache/spark/util/sketch/BloomFilter.java | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 25 ++------- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +-- .../connect/planner/SparkConnectPlanner.scala | 53 ++++++++++++++++--- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index 5c01841e5015a..ab6bf8d96ac10 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -201,7 +201,7 
@@ private static int optimalNumOfHashFunctions(long n, long m) { * @param n expected insertions (must be positive) * @param p false positive rate (must be 0 < p < 1) */ - private static long optimalNumOfBits(long n, double p) { + public static long optimalNumOfBits(long n, double p) { return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 9af7cd01de95b..02f7dfba95086 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -652,26 +652,6 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo numBits: Long, fpp: Double): BloomFilter = { - def optimalNumOfBits(n: Long, p: Double): Long = - (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong - - val nBits = if (numBits > 0L) { - numBits - } else { - if (fpp <= 0d || fpp >= 1d) { - throw new IllegalArgumentException( - "False positive probability must be within range (0.0, 1.0)") - } - optimalNumOfBits(expectedNumItems, fpp) - } - - if (expectedNumItems <= 0) { - throw new IllegalArgumentException("Expected insertions must be positive") - } - if (nBits <= 0) { - throw new IllegalArgumentException("Number of bits must be positive") - } - val dataType = sparkSession .newDataFrame { builder => builder.getProjectBuilder @@ -685,9 +665,10 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo val agg = Column.fn( "bloom_filter_agg", col, + lit(dataType.catalogString), lit(expectedNumItems), - lit(nBits), - lit(dataType.catalogString)) + lit(numBits), + lit(fpp)) val ds = sparkSession.newDataset(BinaryEncoder) { builder => builder.getProjectBuilder .setInput(root) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 4aef1107ab775..91e849e2866e4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -227,18 +227,18 @@ class DataFrameStatSuite extends RemoteSparkSession { } test("Bloom filter test invalid inputs") { - val df = spark.emptyDataFrame - val message1 = intercept[IllegalArgumentException] { + val df = spark.range(1000).toDF("id") + val message1 = intercept[StatusRuntimeException] { df.stat.bloomFilter("id", -1000, 100) }.getMessage assert(message1.contains("Expected insertions must be positive")) - val message2 = intercept[IllegalArgumentException] { + val message2 = intercept[StatusRuntimeException] { df.stat.bloomFilter("id", 1000, -100) }.getMessage assert(message2.contains("Number of bits must be positive")) - val message3 = intercept[IllegalArgumentException] { + val message3 = intercept[StatusRuntimeException] { df.stat.bloomFilter("id", 1000, -1.0) }.getMessage assert(message3.contains("False positive probability must be within range (0.0, 1.0)")) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 73fb836d80deb..5e894638dbe3c 100644 --- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1074,17 +1074,21 @@ class SparkConnectPlanner(val session: SparkSession) { } Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "bloom_filter_agg" if fun.getArgumentsCount == 4 => + case "bloom_filter_agg" if fun.getArgumentsCount == 5 => + // [col, catalogString: String, expectedNumItems: Long, numBits: Long, fpp: Double] + def optimalNumOfBits(n: Long, p: Double): Long = + (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong + val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) val dt = { - val ddl = children.last match { + val ddl = children(1) match { case StringLiteral(s) => s case other => throw InvalidPlanInput(s"col dataType should be a literal string, but got $other") } DataType.fromDDL(ddl) } - val first = dt match { + val col = dt match { case IntegerType | ShortType | ByteType => Cast(children.head, LongType) case LongType | StringType => children.head case other => @@ -1092,9 +1096,46 @@ class SparkConnectPlanner(val session: SparkSession) { s"Bloom filter only supports integral types, " + s"and does not support type $other.") } - Some( - new BloomFilterAggregate(first, children(1), children(2)) - .toAggregateExpression()) + + val expectedNumItems = children(2) + val n = expectedNumItems match { + case Literal(l: Long, LongType) if l > 0L => l + case _ => + throw InvalidPlanInput("Expected insertions must be positive long literal.") + } + + val numBits = children(3) + numBits match { + case Literal(l: Long, LongType) => + if (l > 0L) { + Some( + new BloomFilterAggregate(col, expectedNumItems, numBits) + .toAggregateExpression()) + } else { + val fpp = children(4) + val p = fpp match { + case DoubleLiteral(d) => d + case _ => + throw InvalidPlanInput("False positive must be double literal.") + } + // if `p.isNaN` and numBits less than zero, throw `numBits` related exception first. 
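+              // (`p` is NaN only when the client called a `numBits` overload, which sends
+              // `Double.NaN` as a sentinel, so the caller's actual mistake here is the
+              // non-positive `numBits`.)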
+ if (p.isNaN) { + throw InvalidPlanInput("Number of bits must be positive") + } + if (p <= 0d || p >= 1d) { + throw InvalidPlanInput( + "False positive probability must be within range (0.0, 1.0)") + } + val opNumOfBits = optimalNumOfBits(n, p) + if (opNumOfBits < 0L) { + throw InvalidPlanInput("Number of bits must be positive") + } + Some( + new BloomFilterAggregate(col, expectedNumItems, Literal(opNumOfBits, LongType)) + .toAggregateExpression()) + } + case _ => throw InvalidPlanInput("Number of bits must be long literal.") + } case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) From 8d0e51754e158e218adb77a979a53e8662552f74 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 18:02:36 +0800 Subject: [PATCH 09/19] refactor --- .../connect/planner/SparkConnectPlanner.scala | 91 +++++++++++-------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5e894638dbe3c..b8926d13ba1a1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1076,9 +1076,6 @@ class SparkConnectPlanner(val session: SparkSession) { case "bloom_filter_agg" if fun.getArgumentsCount == 5 => // [col, catalogString: String, expectedNumItems: Long, numBits: Long, fpp: Double] - def optimalNumOfBits(n: Long, p: Double): Long = - (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong - val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) val dt = { val ddl = children(1) match { @@ -1097,44 +1094,62 @@ class SparkConnectPlanner(val session: SparkSession) { s"and does not support type $other.") } - val expectedNumItems = children(2) - val n = expectedNumItems match { - case Literal(l: Long, LongType) if l > 0L => l + val fpp = children(4) match { + case DoubleLiteral(d) => d case _ => - throw InvalidPlanInput("Expected insertions must be positive long literal.") + throw InvalidPlanInput("False positive must be double literal.") } - val numBits = children(3) - numBits match { - case Literal(l: Long, LongType) => - if (l > 0L) { - Some( - new BloomFilterAggregate(col, expectedNumItems, numBits) - .toAggregateExpression()) - } else { - val fpp = children(4) - val p = fpp match { - case DoubleLiteral(d) => d - case _ => - throw InvalidPlanInput("False positive must be double literal.") - } - // if `p.isNaN` and numBits less than zero, throw `numBits` related exception first. 
- if (p.isNaN) { - throw InvalidPlanInput("Number of bits must be positive") - } - if (p <= 0d || p >= 1d) { - throw InvalidPlanInput( - "False positive probability must be within range (0.0, 1.0)") - } - val opNumOfBits = optimalNumOfBits(n, p) - if (opNumOfBits < 0L) { - throw InvalidPlanInput("Number of bits must be positive") - } - Some( - new BloomFilterAggregate(col, expectedNumItems, Literal(opNumOfBits, LongType)) - .toAggregateExpression()) - } - case _ => throw InvalidPlanInput("Number of bits must be long literal.") + if (fpp.isNaN) { // use expectedNumItems and numBits + val expectedNumItemsExpr = children(2) + val expectedNumItems = expectedNumItemsExpr match { + case Literal(l: Long, LongType) => l + case _ => + throw InvalidPlanInput("Expected insertions must be long literal.") + } + if (expectedNumItems <= 0L) { + throw InvalidPlanInput("Expected insertions must be positive.") + } + + val numBitsExpr = children(3) + val numBits = numBitsExpr match { + case Literal(l: Long, LongType) => l + case _ => + throw InvalidPlanInput("Number of bits must be long literal.") + } + if (numBits <= 0L) { + throw InvalidPlanInput("Number of bits must be positive.") + } + + Some( + new BloomFilterAggregate(col, expectedNumItemsExpr, numBitsExpr) + .toAggregateExpression()) + + } else { // use expectedNumItems and fpp + def optimalNumOfBits(n: Long, p: Double): Long = + (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong + + val expectedNumItemsExpr = children(2) + val expectedNumItems = expectedNumItemsExpr match { + case Literal(l: Long, LongType) => l + case _ => + throw InvalidPlanInput("Expected insertions must be long literal.") + } + if (expectedNumItems <= 0L) { + throw InvalidPlanInput("Expected insertions must be positive.") + } + + if (fpp <= 0d || fpp >= 1d) { + throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)") + } + + val numBits = optimalNumOfBits(expectedNumItems, fpp) + if (numBits <= 0L) { + throw InvalidPlanInput("Number of bits must be positive") + } + Some( + new BloomFilterAggregate(col, expectedNumItemsExpr, Literal(numBits, LongType)) + .toAggregateExpression()) } case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 => From d193de02f352f3890e61da0c9c3f8077f776e791 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 10 Mar 2023 18:10:57 +0800 Subject: [PATCH 10/19] move check to server side --- .../apache/spark/util/sketch/BloomFilter.java | 2 +- .../connect/planner/SparkConnectPlanner.scala | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index ab6bf8d96ac10..5c01841e5015a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -201,7 +201,7 @@ private static int optimalNumOfHashFunctions(long n, long m) { * @param n expected insertions (must be positive) * @param p false positive rate (must be 0 < p < 1) */ - public static long optimalNumOfBits(long n, double p) { + private static long optimalNumOfBits(long n, double p) { return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 
b8926d13ba1a1..01eccb96b3b97 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1100,7 +1100,9 @@ class SparkConnectPlanner(val session: SparkSession) { throw InvalidPlanInput("False positive must be double literal.") } - if (fpp.isNaN) { // use expectedNumItems and numBits + if (fpp.isNaN) { + // Use expectedNumItems and numBits when `fpp.isNaN` if true. + // Check expectedNumItems > 0L val expectedNumItemsExpr = children(2) val expectedNumItems = expectedNumItemsExpr match { case Literal(l: Long, LongType) => l @@ -1110,7 +1112,7 @@ class SparkConnectPlanner(val session: SparkSession) { if (expectedNumItems <= 0L) { throw InvalidPlanInput("Expected insertions must be positive.") } - + // Check numBits > 0L val numBitsExpr = children(3) val numBits = numBitsExpr match { case Literal(l: Long, LongType) => l @@ -1120,15 +1122,17 @@ class SparkConnectPlanner(val session: SparkSession) { if (numBits <= 0L) { throw InvalidPlanInput("Number of bits must be positive.") } - + // Create BloomFilterAggregate with expectedNumItemsExpr and numBitsExpr. Some( new BloomFilterAggregate(col, expectedNumItemsExpr, numBitsExpr) .toAggregateExpression()) - } else { // use expectedNumItems and fpp + } else { def optimalNumOfBits(n: Long, p: Double): Long = (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong + // Use expectedNumItems and fpp when `fpp.isNaN` if false. + // Check expectedNumItems > 0L val expectedNumItemsExpr = children(2) val expectedNumItems = expectedNumItemsExpr match { case Literal(l: Long, LongType) => l @@ -1139,14 +1143,17 @@ class SparkConnectPlanner(val session: SparkSession) { throw InvalidPlanInput("Expected insertions must be positive.") } + // Check fpp in (0.0, 1.0). if (fpp <= 0d || fpp >= 1d) { throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)") } - + // Calculate numBits through expectedNumItems and fpp, + // refer to `BloomFilter.optimalNumOfBits(long, double)`. val numBits = optimalNumOfBits(expectedNumItems, fpp) if (numBits <= 0L) { throw InvalidPlanInput("Number of bits must be positive") } + // Create BloomFilterAggregate with expectedNumItemsExpr and new numBits. 
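+        // (The computed bit count is wrapped in `Literal(numBits, LongType)` so the aggregate
+        // still receives the foldable `LongType` argument its input type checks expect.)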
Some( new BloomFilterAggregate(col, expectedNumItemsExpr, Literal(numBits, LongType)) .toAggregateExpression()) From dc629962aeb816f0d12aea924d8b5375d4911201 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 14 Mar 2023 13:40:08 +0800 Subject: [PATCH 11/19] try remove lazy --- .../catalyst/expressions/aggregate/BloomFilterAggregate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 03d8136f11682..b70b6a64bf033 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -151,7 +151,7 @@ case class BloomFilterAggregate( Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) - private lazy val updater: BloomFilterUpdater = first.dataType match { + private val updater: BloomFilterUpdater = first.dataType match { case LongType => LongUpdater case StringType => BinaryUpdater } From df56a053ca72712722ff3592ad96b23f9b867f65 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 14 Mar 2023 16:53:05 +0800 Subject: [PATCH 12/19] make lazy --- .../catalyst/expressions/aggregate/BloomFilterAggregate.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index b70b6a64bf033..6e197ba0d2501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -151,7 +151,8 @@ case class BloomFilterAggregate( Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, SQLConf.get.getConf(RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) - private val updater: BloomFilterUpdater = first.dataType match { + // Mark as lazy so that `updater` is not evaluated during tree transformation. 
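+  // (`first.dataType` may still be unresolved while rules transform the plan, so resolving
+  // the updater eagerly at construction time could fail.)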
+ private lazy val updater: BloomFilterUpdater = first.dataType match { case LongType => LongUpdater case StringType => BinaryUpdater } From f6474c88178b0b5eb5b43bb8c43baa9d34345a39 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 21 Mar 2023 17:45:17 +0800 Subject: [PATCH 13/19] add Serializable --- .../catalyst/expressions/aggregate/BloomFilterAggregate.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 6e197ba0d2501..500a61ddc4371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -236,12 +236,12 @@ private trait BloomFilterUpdater { def update(bf: BloomFilter, v: Any): Boolean } -private object LongUpdater extends BloomFilterUpdater { +private object LongUpdater extends BloomFilterUpdater with Serializable { override def update(bf: BloomFilter, v: Any): Boolean = bf.putLong(v.asInstanceOf[Long]) } -private object BinaryUpdater extends BloomFilterUpdater { +private object BinaryUpdater extends BloomFilterUpdater with Serializable { override def update(bf: BloomFilter, v: Any): Boolean = bf.putBinary(v.asInstanceOf[UTF8String].getBytes) } From ad1d2057a4fa6a225f6d290ae1065efec8fe404e Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 23 Mar 2023 13:04:52 +0800 Subject: [PATCH 14/19] change since from 3.4.0 to 3.5.0 --- .../org/apache/spark/sql/DataFrameStatFunctions.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 02f7dfba95086..b7b772af82375 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -595,7 +595,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo * expected number of items which will be put into the filter. * @param fpp * expected false positive probability of the filter. - * @since 3.4.0 + * @since 3.5.0 */ def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp) @@ -610,7 +610,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo * expected number of items which will be put into the filter. * @param fpp * expected false positive probability of the filter. - * @since 3.4.0 + * @since 3.5.0 */ def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { buildBloomFilter(col, expectedNumItems, -1L, fpp) @@ -625,7 +625,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo * expected number of items which will be put into the filter. * @param numBits * expected number of bits of the filter. 
- * @since 3.4.0 + * @since 3.5.0 */ def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN) @@ -640,7 +640,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo * expected number of items which will be put into the filter. * @param numBits * expected number of bits of the filter. - * @since 3.4.0 + * @since 3.5.0 */ def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) From 176ce09b16e57ebc48c565eccb1d85331c30e93c Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 7 Apr 2023 17:11:13 +0800 Subject: [PATCH 15/19] merge case --- .../catalyst/expressions/aggregate/BloomFilterAggregate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 500a61ddc4371..d239659d00961 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -79,7 +79,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType, LongType, LongType) | (StringType, LongType, LongType) => + case (LongType | StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", From ea2234cfbae735997a8dc4dc80fba546b600b4e1 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 7 Apr 2023 17:57:36 +0800 Subject: [PATCH 16/19] refactor to remove pass dataType --- .../spark/sql/DataFrameStatFunctions.scala | 18 +--------- .../apache/spark/sql/DataFrameStatSuite.scala | 10 ++++++ .../connect/planner/SparkConnectPlanner.scala | 35 ++++++------------- .../aggregate/BloomFilterAggregate.scala | 20 ++++++++++- 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index b7b772af82375..1b8e3ddf5a5e3 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -652,23 +652,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo numBits: Long, fpp: Double): BloomFilter = { - val dataType = sparkSession - .newDataFrame { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(col.expr) - } - .schema - .head - .dataType - - val agg = Column.fn( - "bloom_filter_agg", - col, - lit(dataType.catalogString), - lit(expectedNumItems), - lit(numBits), - lit(fpp)) + val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits), lit(fpp)) val ds = sparkSession.newDataset(BinaryEncoder) { builder => builder.getProjectBuilder .setInput(root) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 91e849e2866e4..77a44747851b8 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -226,6 +226,16 @@ class DataFrameStatSuite extends RemoteSparkSession { assert(data.forall(filter2.mightContain)) } + test("Bloom filter -- Wrong dataType Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toDouble) + val message = intercept[StatusRuntimeException] { + data.toDF("id").stat.bloomFilter("id", 1000, 0.03) + }.getMessage + assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE")) + } + test("Bloom filter test invalid inputs") { val df = spark.range(1000).toDF("id") val message1 = intercept[StatusRuntimeException] { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 536cf182d0b78..a7c264729f093 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1155,27 +1155,11 @@ class SparkConnectPlanner(val session: SparkSession) { } Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "bloom_filter_agg" if fun.getArgumentsCount == 5 => - // [col, catalogString: String, expectedNumItems: Long, numBits: Long, fpp: Double] + case "bloom_filter_agg" if fun.getArgumentsCount == 4 => + // [col, expectedNumItems: Long, numBits: Long, fpp: Double] val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - val dt = { - val ddl = children(1) match { - case StringLiteral(s) => s - case other => - throw InvalidPlanInput(s"col dataType should be a literal string, but got $other") - } - DataType.fromDDL(ddl) - } - val col = dt match { - case IntegerType | ShortType | ByteType => Cast(children.head, LongType) - case LongType | StringType => children.head - case other => - throw InvalidPlanInput( - s"Bloom filter only supports integral types, " + - s"and does not support type $other.") - } - val fpp = children(4) match { + val fpp = children(3) match { case DoubleLiteral(d) => d case _ => throw InvalidPlanInput("False positive must be double literal.") @@ -1184,7 +1168,7 @@ class SparkConnectPlanner(val session: SparkSession) { if (fpp.isNaN) { // Use expectedNumItems and numBits when `fpp.isNaN` if true. // Check expectedNumItems > 0L - val expectedNumItemsExpr = children(2) + val expectedNumItemsExpr = children(1) val expectedNumItems = expectedNumItemsExpr match { case Literal(l: Long, LongType) => l case _ => @@ -1194,7 +1178,7 @@ class SparkConnectPlanner(val session: SparkSession) { throw InvalidPlanInput("Expected insertions must be positive.") } // Check numBits > 0L - val numBitsExpr = children(3) + val numBitsExpr = children(2) val numBits = numBitsExpr match { case Literal(l: Long, LongType) => l case _ => @@ -1205,7 +1189,7 @@ class SparkConnectPlanner(val session: SparkSession) { } // Create BloomFilterAggregate with expectedNumItemsExpr and numBitsExpr. 
Some( - new BloomFilterAggregate(col, expectedNumItemsExpr, numBitsExpr) + new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr) .toAggregateExpression()) } else { @@ -1214,7 +1198,7 @@ class SparkConnectPlanner(val session: SparkSession) { // Use expectedNumItems and fpp when `fpp.isNaN` if false. // Check expectedNumItems > 0L - val expectedNumItemsExpr = children(2) + val expectedNumItemsExpr = children(1) val expectedNumItems = expectedNumItemsExpr match { case Literal(l: Long, LongType) => l case _ => @@ -1236,7 +1220,10 @@ class SparkConnectPlanner(val session: SparkSession) { } // Create BloomFilterAggregate with expectedNumItemsExpr and new numBits. Some( - new BloomFilterAggregate(col, expectedNumItemsExpr, Literal(numBits, LongType)) + new BloomFilterAggregate( + children.head, + expectedNumItemsExpr, + Literal(numBits, LongType)) .toAggregateExpression()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index d239659d00961..2325b2b055f96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -79,7 +79,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType | StringType, LongType, LongType) => + case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -154,6 +154,9 @@ case class BloomFilterAggregate( // Mark as lazy so that `updater` is not evaluated during tree transformation. 
private lazy val updater: BloomFilterUpdater = first.dataType match { case LongType => LongUpdater + case IntegerType => IntUpdater + case ShortType => ShortUpdater + case ByteType => ByteUpdater case StringType => BinaryUpdater } @@ -241,6 +244,21 @@ private object LongUpdater extends BloomFilterUpdater with Serializable { bf.putLong(v.asInstanceOf[Long]) } +private object IntUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Int]) +} + +private object ShortUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Short]) +} + +private object ByteUpdater extends BloomFilterUpdater with Serializable { + override def update(bf: BloomFilter, v: Any): Boolean = + bf.putLong(v.asInstanceOf[Byte]) +} + private object BinaryUpdater extends BloomFilterUpdater with Serializable { override def update(bf: BloomFilter, v: Any): Boolean = bf.putBinary(v.asInstanceOf[UTF8String].getBytes) From 5bd616954aaef21b88b161e20e0f49bc067cae0c Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 7 Apr 2023 18:42:14 +0800 Subject: [PATCH 17/19] chanage to always pass 3 parameters --- .../spark/sql/DataFrameStatFunctions.scala | 7 +- .../connect/planner/SparkConnectPlanner.scala | 104 +++++++----------- 2 files changed, 44 insertions(+), 67 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 1b8e3ddf5a5e3..f12e7d8fe9425 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -652,7 +652,12 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo numBits: Long, fpp: Double): BloomFilter = { - val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits), lit(fpp)) + val agg = if (!fpp.isNaN) { + Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(fpp)) + } else { + Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) + } + val ds = sparkSession.newDataset(BinaryEncoder) { builder => builder.getProjectBuilder .setInput(root) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index a7c264729f093..1f4ceaf2e5e17 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1155,77 +1155,49 @@ class SparkConnectPlanner(val session: SparkSession) { } Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "bloom_filter_agg" if fun.getArgumentsCount == 4 => - // [col, expectedNumItems: Long, numBits: Long, fpp: Double] - val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - - val fpp = children(3) match { - case DoubleLiteral(d) => d + case "bloom_filter_agg" if fun.getArgumentsCount == 3 => + // [col, expectedNumItems: Long, numBits: Long] or + // [col, expectedNumItems: Long, fpp: Double] + val children = fun.getArgumentsList.asScala.map(transformExpression) + + // Check expectedNumItems > 
+        val expectedNumItemsExpr = children(1)
+        val expectedNumItems = expectedNumItemsExpr match {
+          case Literal(l: Long, LongType) => l
           case _ =>
-            throw InvalidPlanInput("False positive must be double literal.")
+            throw InvalidPlanInput("Expected insertions must be long literal.")
+        }
+        if (expectedNumItems <= 0L) {
+          throw InvalidPlanInput("Expected insertions must be positive.")
         }
-        if (fpp.isNaN) {
-          // Use expectedNumItems and numBits when `fpp.isNaN` if true.
-          // Check expectedNumItems > 0L
-          val expectedNumItemsExpr = children(1)
-          val expectedNumItems = expectedNumItemsExpr match {
-            case Literal(l: Long, LongType) => l
-            case _ =>
-              throw InvalidPlanInput("Expected insertions must be long literal.")
-          }
-          if (expectedNumItems <= 0L) {
-            throw InvalidPlanInput("Expected insertions must be positive.")
-          }
-          // Check numBits > 0L
-          val numBitsExpr = children(2)
-          val numBits = numBitsExpr match {
-            case Literal(l: Long, LongType) => l
-            case _ =>
-              throw InvalidPlanInput("Number of bits must be long literal.")
-          }
-          if (numBits <= 0L) {
-            throw InvalidPlanInput("Number of bits must be positive.")
-          }
-          // Create BloomFilterAggregate with expectedNumItemsExpr and numBitsExpr.
-          Some(
-            new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
              .toAggregateExpression())
+        val numberBitsOrFpp = children(2)
-        } else {
-          def optimalNumOfBits(n: Long, p: Double): Long =
-            (-n * Math.log(p) / (Math.log(2) * Math.log(2))).toLong
-
-          // Use expectedNumItems and fpp when `fpp.isNaN` if false.
-          // Check expectedNumItems > 0L
-          val expectedNumItemsExpr = children(1)
-          val expectedNumItems = expectedNumItemsExpr match {
-            case Literal(l: Long, LongType) => l
-            case _ =>
-              throw InvalidPlanInput("Expected insertions must be long literal.")
-          }
-          if (expectedNumItems <= 0L) {
-            throw InvalidPlanInput("Expected insertions must be positive.")
-          }
-
-          // Check fpp in (0.0, 1.0).
-          if (fpp <= 0d || fpp >= 1d) {
-            throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
-          }
-          // Calculate numBits through expectedNumItems and fpp,
-          // refer to `BloomFilter.optimalNumOfBits(long, double)`.
-          val numBits = optimalNumOfBits(expectedNumItems, fpp)
-          if (numBits <= 0L) {
-            throw InvalidPlanInput("Number of bits must be positive")
-          }
-          // Create BloomFilterAggregate with expectedNumItemsExpr and new numBits.
-          Some(
-            new BloomFilterAggregate(
-              children.head,
-              expectedNumItemsExpr,
-              Literal(numBits, LongType))
-              .toAggregateExpression())
+        val numBitsExpr = numberBitsOrFpp match {
+          case Literal(numBits: Long, LongType) =>
+            // Check numBits > 0L
+            if (numBits <= 0L) {
+              throw InvalidPlanInput("Number of bits must be positive.")
+            }
+            numberBitsOrFpp
+          case DoubleLiteral(fpp) =>
+            // Check fpp in (0.0, 1.0).
+            if (fpp <= 0d || fpp >= 1d) {
+              throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
+            }
+            // Calculate numBits through expectedNumItems and fpp,
+            // refer to `BloomFilter.optimalNumOfBits(long, double)`.
+            val numBits = (-expectedNumItems * Math.log(fpp) / (Math.log(2) * Math.log(2))).toLong
+            if (numBits <= 0L) {
+              throw InvalidPlanInput("Number of bits must be positive.")
+            }
+            Literal(numBits, LongType)
+          case _ =>
+            throw InvalidPlanInput("The 3rd parameter must be a double or long literal.")
         }
+        Some(
+          new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
            .toAggregateExpression())
 
       case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 4 =>
         val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)

From 43d6305fad43b3f6c9573f67085872da10a96c27 Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Fri, 7 Apr 2023 18:52:30 +0800
Subject: [PATCH 18/19] add NaN check in Planner

---
 .../spark/sql/connect/planner/SparkConnectPlanner.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1f4ceaf2e5e17..2a083a994b952 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1181,8 +1181,8 @@ class SparkConnectPlanner(val session: SparkSession) {
           }
           numberBitsOrFpp
         case DoubleLiteral(fpp) =>
-          // Check fpp in (0.0, 1.0).
-          if (fpp <= 0d || fpp >= 1d) {
+          // Check fpp not NaN and in (0.0, 1.0).
+          if (fpp.isNaN || fpp <= 0d || fpp >= 1d) {
             throw InvalidPlanInput("False positive probability must be within range (0.0, 1.0)")
           }
           // Calculate numBits through expectedNumItems and fpp,

From dbbb61a1ee106e154b50848deaa30b4de5c8114b Mon Sep 17 00:00:00 2001
From: yangjie01
Date: Fri, 7 Apr 2023 19:06:56 +0800
Subject: [PATCH 19/19] bridge optimalNumOfBits

---
 .../apache/spark/util/sketch/BloomFilter.java |  2 +-
 .../connect/planner/SparkConnectPlanner.scala |  3 ++-
 .../spark/util/sketch/BloomFilterHelper.scala | 26 +++++++++++++++++++
 3 files changed, 29 insertions(+), 2 deletions(-)
 create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala

diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 5c01841e5015a..e05f1ac3a50ae 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -201,7 +201,7 @@ private static int optimalNumOfHashFunctions(long n, long m) {
    * @param n expected insertions (must be positive)
    * @param p false positive rate (must be 0 < p < 1)
    */
-  private static long optimalNumOfBits(long n, double p) {
+  static long optimalNumOfBits(long n, double p) {
     return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
   }

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 2a083a994b952..146f45c3d9463 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -59,6 +59,7 @@ import org.apache.spark.sql.internal.CatalogImpl
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.Utils
+import org.apache.spark.util.sketch.BloomFilterHelper
 
 final case class InvalidCommandInput(
     private val message: String = "",
@@ -1187,7 +1188,7 @@ class SparkConnectPlanner(val session: SparkSession) {
           }
           // Calculate numBits through expectedNumItems and fpp,
           // refer to `BloomFilter.optimalNumOfBits(long, double)`.
-          val numBits = (-expectedNumItems * Math.log(fpp) / (Math.log(2) * Math.log(2))).toLong
+          val numBits = BloomFilterHelper.optimalNumOfBits(expectedNumItems, fpp)
           if (numBits <= 0L) {
             throw InvalidPlanInput("Number of bits must be positive.")
           }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala
new file mode 100644
index 0000000000000..bbb0ee3c2f1a4
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/util/sketch/BloomFilterHelper.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+/**
+ * `BloomFilterHelper` is used to bridge helper methods in `BloomFilter`.
+ */
+private[spark] object BloomFilterHelper {
+  def optimalNumOfBits(expectedNumItems: Long, fpp: Double): Long =
+    BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+}
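
The patches above only touch library and planner code, so a usage sketch may help. The following is illustrative only and not part of the series: the remote URL, data, and sizing values are placeholders, assuming a running Spark Connect server.

    // Illustrative usage of the two bloomFilter sizing variants added here.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.util.sketch.BloomFilter

    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
    val df = spark.range(1000).toDF("id")

    // Size by expected item count and target false positive probability.
    val byFpp: BloomFilter = df.stat.bloomFilter("id", expectedNumItems = 1000L, fpp = 0.03)

    // Size by expected item count and an explicit bit budget.
    val byBits: BloomFilter = df.stat.bloomFilter("id", expectedNumItems = 1000L, numBits = 8192L)

    // Inserted values always hit; absent values may rarely false-positive.
    assert(byFpp.mightContain(500L))

Either call shape is sent as the same three-argument bloom_filter_agg expression, which is what the planner change above dispatches on.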
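
The planner's handling of the third argument reduces to: a long literal is taken as numBits directly, while a double literal is converted with the optimalNumOfBits formula. Below is a stripped-down model of that decision, runnable outside Spark; the function name is made up, and the validation messages simply mirror the ones in the patch.

    // Toy model of the planner's third-argument handling in PATCH 17.
    def resolveNumBits(expectedNumItems: Long, third: Any): Long = third match {
      case numBits: Long =>
        require(numBits > 0L, "Number of bits must be positive.")
        numBits
      case fpp: Double =>
        require(fpp > 0d && fpp < 1d,
          "False positive probability must be within range (0.0, 1.0)")
        // Same formula as BloomFilter.optimalNumOfBits(long, double).
        (-expectedNumItems * math.log(fpp) / (math.log(2) * math.log(2))).toLong
      case _ =>
        throw new IllegalArgumentException("The 3rd parameter must be a double or long literal.")
    }

    // 1,000,000 expected items at fpp = 0.03 comes out to roughly 7.3 million bits.
    assert(resolveNumBits(1000000L, 0.03) > 0L)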
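
One subtlety worth noting from the BloomFilterAggregate change: IntUpdater, ShortUpdater, and ByteUpdater all widen their input and call putLong, so the same numeric value hashes identically whatever the column's integral width. A minimal demonstration against the sketch library's public API, with illustrative values:

    // Widening Int to Long before insertion, as IntUpdater does, means a
    // Long probe for the same numeric value still hits.
    import org.apache.spark.util.sketch.BloomFilter

    val bf = BloomFilter.create(1000L)
    val v: Int = 42
    bf.putLong(v.toLong)             // mirrors IntUpdater's bf.putLong(v.asInstanceOf[Int])
    assert(bf.mightContainLong(42L)) // the widened probe finds it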