diff --git a/src/main/java/tech/sourced/gemini/WeightedMinHash.java b/src/main/java/tech/sourced/gemini/WeightedMinHash.java
index d6e494f6..7137e84b 100644
--- a/src/main/java/tech/sourced/gemini/WeightedMinHash.java
+++ b/src/main/java/tech/sourced/gemini/WeightedMinHash.java
@@ -7,6 +7,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Serializable;
+
 import static java.lang.Math.floor;
 import static java.lang.Math.log;
 
@@ -15,7 +17,7 @@
  * https://github.com/ekzhu/datasketch/blob/master/datasketch/weighted_minhash.py
  * https://github.com/src-d/go-license-detector/blob/master/licensedb/internal/wmh/wmh.go
  */
-public class WeightedMinHash {
+public class WeightedMinHash implements Serializable {
     private static final Logger log = LoggerFactory.getLogger(WeightedMinHash.class);
 
     protected int dim;
diff --git a/src/main/scala/tech/sourced/gemini/Hash.scala b/src/main/scala/tech/sourced/gemini/Hash.scala
index 6f2bf070..ea63aa20 100644
--- a/src/main/scala/tech/sourced/gemini/Hash.scala
+++ b/src/main/scala/tech/sourced/gemini/Hash.scala
@@ -196,6 +196,8 @@ class Hash(session: SparkSession,
   ): Dataset[RDDHash] = {
     log.warn("hashing features")
 
+    val wmh = makeBroadcastedWmh(docFreq.tokens.size, sampleSize)
+
     val tf = features.rdd
       .map { case Row(feature: String, doc: String, weight: Long) => (RDDFeatureKey(feature, doc), weight) }
       .reduceByKey(_ + _)
@@ -204,9 +206,8 @@
       .map(row => (row._1.doc, Feature(row._1.token, row._2)))
       .groupByKey(session.sparkContext.defaultParallelism)
       .mapPartitions { partIter =>
-        val wmh = FeaturesHash.initWmh(docFreq.tokens.size, sampleSize) // ~1.6 Gb (for 1 PGA bucket)
         partIter.map { case (doc, features) =>
-          RDDHash(doc, wmh.hash(FeaturesHash.toBagOfFeatures(features.iterator, docFreq)))
+          RDDHash(doc, wmh.value.hash(FeaturesHash.toBagOfFeatures(features.iterator, docFreq)))
         }
       }
     tfIdf.toDS()
@@ -222,17 +223,33 @@
   ): Dataset[RDDHash] = {
     log.warn("hashing features")
 
+    val wmh = makeBroadcastedWmh(docFreq.tokens.size, sampleSize)
     val tf = features.groupBy("feature", "doc").sum("weight").alias("weight")
 
     val tfIdf = tf
       .map { case Row(token: String, doc: String, weight: Long) => (doc, Feature(token, weight)) }
       .groupByKey { case (doc, _) => doc }
       .mapGroups { (doc, features) =>
-        val wmh = FeaturesHash.initWmh(docFreq.tokens.size, sampleSize) // ~1.6 Gb RAM (for 1 PGA bucket)
-        RDDHash(doc, wmh.hash(FeaturesHash.toBagOfFeatures(features.map(_._2), docFreq)))
+        RDDHash(doc, wmh.value.hash(FeaturesHash.toBagOfFeatures(features.map(_._2), docFreq)))
       }
     tfIdf
   }
 
+  /**
+    * Creates a WeightedMinHash instance and broadcasts it.
+    *
+    * The instance is created only once and kept on each node, because it is
+    * relatively large: 2 * number of features * sampleSize (160 or 256, depending on the hashing mode) * 8 bytes.
+    * According to tests, ~1.6 GB for one PGA bucket (the exact size depends on the bucket).
+    *
+    * @param tokens number of features
+    * @param sampleSize depends on the hashing mode and threshold
+    * @return broadcast handle to the shared WeightedMinHash instance
+    */
+  def makeBroadcastedWmh(tokens: Int, sampleSize: Int): Broadcast[WeightedMinHash] = {
+    val wmh = FeaturesHash.initWmh(tokens, sampleSize)
+    session.sparkContext.broadcast(wmh)
+  }
+
   protected def saveDocFreqToDB(docFreq: OrderedDocFreq, keyspace: String, tables: Tables): Unit = {
     log.warn(s"save document frequencies to DB")
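
Note on the pattern: the diff replaces per-task construction of `WeightedMinHash` (previously re-created inside every `mapPartitions` partition and every `mapGroups` group, at ~1.6 GB per allocation) with a single instance built once on the driver and shared via a Spark broadcast variable, which ships one read-only copy to each executor. This is also why `WeightedMinHash` now `implements Serializable`: broadcast values must be serializable. Below is a minimal, self-contained sketch of the same pattern; `HeavyModel` and `BroadcastSketch` are hypothetical stand-ins, not part of this codebase:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical stand-in for a large, expensive-to-build object such as
// WeightedMinHash. It must be Serializable so Spark can broadcast it.
class HeavyModel(val dim: Int) extends Serializable {
  def hash(x: Int): Int = x % dim
}

object BroadcastSketch {
  def main(args: Array[String]): Unit = {
    val session = SparkSession.builder()
      .master("local[*]")
      .appName("broadcast-sketch")
      .getOrCreate()

    // Build the heavy object exactly once, on the driver ...
    val model = new HeavyModel(1024)
    // ... and broadcast it: Spark ships one copy per executor JVM,
    // instead of one copy per task (or per group, as in the old code).
    val bcModel = session.sparkContext.broadcast(model)

    val hashed = session.sparkContext
      .parallelize(1 to 10)
      .map(x => bcModel.value.hash(x)) // tasks read the shared copy via .value

    println(hashed.collect().mkString(", "))
    session.stop()
  }
}
```

One detail worth preserving from the diff: `.value` is called inside the closure, on the executors. Resolving `bcModel.value` on the driver and capturing the result would pull the full object back into each serialized task closure and defeat the purpose of the broadcast.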