diff --git a/lib/Clustering/HDBSCAN.php b/lib/Clustering/HDBSCAN.php index 37bd7030f..f6748033e 100644 --- a/lib/Clustering/HDBSCAN.php +++ b/lib/Clustering/HDBSCAN.php @@ -66,11 +66,11 @@ class HDBSCAN { /** * @param Labeled $dataset - * @param array $oldCoreDistances * @param int $minClusterSize * @param int $sampleSize - * @param float $maxEdgeLength + * @param array $oldCoreDistances * @param Distance $kernel + * @param bool $useTrueMst // (Build true or approximate minimum spanning tree) * @throws \Rubix\ML\Exceptions\InvalidArgumentException */ public function __construct(Labeled $dataset, int $minClusterSize = 5, int $sampleSize = 5, array $oldCoreDistances = [], ?Distance $kernel = null, bool $useTrueMst = true) { @@ -128,9 +128,11 @@ public function params(): array { /** * Form clusters and make predictions from the dataset (hard clustering). * + * @param float $minClusterSeparation + * @param float $maxEdgeLength * @return list */ - public function predict(): array { + public function predict(float $minClusterSeparation = 0.0, float $maxEdgeLength=0.5): array { // Boruvka algorithm for MST generation $edges = $this->mstSolver->getMst(); @@ -142,8 +144,7 @@ public function predict(): array { } unset($edge); - // TODO: Min cluster separation/edge length of MstClusterer to the caller of this class - $mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, 0.0); + $mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, $minClusterSeparation, $maxEdgeLength); $flatClusters = $mstClusterer->processCluster(); return $flatClusters; diff --git a/lib/Clustering/MstClusterer.php b/lib/Clustering/MstClusterer.php index 2f77b5ab6..74010dd2a 100644 --- a/lib/Clustering/MstClusterer.php +++ b/lib/Clustering/MstClusterer.php @@ -6,8 +6,7 @@ namespace OCA\Recognize\Clustering; -// TODO: core edges are not always stored properly (if two halves of the remaining clusters are both pruned at the same time) -// TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for soft clustering. +// TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for improved soft clustering (see https://hdbscan.readthedocs.io/en/latest/soft_clustering.html) class MstClusterer { private array $edges; @@ -22,8 +21,9 @@ class MstClusterer { private bool $isRoot; private array $mapVerticesToEdges; private float $minClusterSeparation; + private float $maxEdgeLength; - public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1) { + public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1, float $maxEdgeLength = 0.5) { //Ascending sort of edges while perserving original keys. $this->edges = $edges; @@ -62,18 +62,25 @@ public function __construct(array $edges, ?array $mapVerticesToEdges, int $minim $this->coreEdges = []; $this->clusterWeight = 0.0; - + + $this->maxEdgeLength = $maxEdgeLength; $this->minClusterSeparation = $minClusterSeparation; } public function processCluster(): array { $currentLambda = $lastLambda = $this->startingLambda; - + $edgeLength = INF; while (true) { $edgeCount = count($this->remainingEdges); if ($edgeCount < ($this->minimumClusterSize - 1)) { + if ($edgeLength > $this->maxEdgeLength) { + // This cluster is too sparse and probably just noise + return []; + } + + foreach ($this->coreEdges as &$edge) { $edge['finalLambda'] = $currentLambda; } @@ -101,7 +108,11 @@ public function processCluster(): array { unset($this->mapVerticesToEdges[$vertexConnectedFrom][$currentLongestEdgeKey]); unset($this->mapVerticesToEdges[$vertexConnectedTo][$currentLongestEdgeKey]); - if ($edgeLength > 0.0) { + if ($edgeLength > $this->maxEdgeLength) { + // Prevent formation of clusters with edges longer than the maximum edge length + // This is done by forcing the weight of the current cluster to zero + $lastLambda = $currentLambda = 1 / $edgeLength; + } else if ($edgeLength > 0.0) { $currentLambda = 1 / $edgeLength; } @@ -133,8 +144,8 @@ public function processCluster(): array { // of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster. - $childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); - $childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation); + $childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation, $this->maxEdgeLength); + $childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation, $this->maxEdgeLength); // Resolve all chosen child clusters recursively $childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster()); @@ -145,7 +156,7 @@ public function processCluster(): array { $this->coreEdges = array_merge($this->coreEdges, $childCluster->getCoreEdges()); } - if (($childrenWeight > $this->clusterWeight) || $this->isRoot) { + if (($childrenWeight >= $this->clusterWeight) || $this->isRoot) { return $childClusters; } else { foreach (array_keys($this->remainingEdges) as $edgeKey) { @@ -155,6 +166,11 @@ public function processCluster(): array { return [$this]; } + + if ($edgeLength > $this->maxEdgeLength) { + // Any pruned vertices were too far away to be part of the cluster + $this->edges = $this->remainingEdges; + } } } diff --git a/lib/Service/FaceClusterAnalyzer.php b/lib/Service/FaceClusterAnalyzer.php index f93d7435e..607c5c8f7 100644 --- a/lib/Service/FaceClusterAnalyzer.php +++ b/lib/Service/FaceClusterAnalyzer.php @@ -19,6 +19,8 @@ class FaceClusterAnalyzer { public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10 public const MIN_CLUSTER_SIZE = 5; // Conservative value: 10 public const MIN_DETECTION_SIZE = 0.03; + public const MIN_CLUSTER_SEPARATION = 0.0; + public const MAX_CLUSTER_EDGE_LENGTH = 0.5; public const DIMENSIONS = 128; public const SAMPLE_SIZE_EXISTING_CLUSTERS = 42; @@ -74,7 +76,7 @@ public function calculateClusters(string $userId, int $batchSize = 0): void { $hdbscan = new HDBSCAN($dataset, self::MIN_CLUSTER_SIZE, self::MIN_SAMPLE_SIZE); $numberOfClusteredDetections = 0; - $clusters = $hdbscan->predict(); + $clusters = $hdbscan->predict(self::MIN_CLUSTER_SEPARATION, self::MAX_CLUSTER_EDGE_LENGTH); foreach ($clusters as $flatCluster) { $detectionKeys = array_keys($flatCluster->getClusterVertices());