Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions lib/Clustering/HDBSCAN.php
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ class HDBSCAN {

/**
* @param Labeled $dataset
* @param array $oldCoreDistances
* @param int $minClusterSize
* @param int $sampleSize
* @param float $maxEdgeLength
* @param array $oldCoreDistances
* @param Distance $kernel
* @param bool $useTrueMst // (Build true or approximate minimum spanning tree)
* @throws \Rubix\ML\Exceptions\InvalidArgumentException
*/
public function __construct(Labeled $dataset, int $minClusterSize = 5, int $sampleSize = 5, array $oldCoreDistances = [], ?Distance $kernel = null, bool $useTrueMst = true) {
Expand Down Expand Up @@ -128,9 +128,11 @@ public function params(): array {
/**
* Form clusters and make predictions from the dataset (hard clustering).
*
* @param float $minClusterSeparation
* @param float $maxEdgeLength
* @return list<MstClusterer>
*/
public function predict(): array {
public function predict(float $minClusterSeparation = 0.0, float $maxEdgeLength=0.5): array {
// Boruvka algorithm for MST generation
$edges = $this->mstSolver->getMst();

Expand All @@ -142,8 +144,7 @@ public function predict(): array {
}
unset($edge);

// TODO: Min cluster separation/edge length of MstClusterer to the caller of this class
$mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, 0.0);
$mstClusterer = new MstClusterer($edges, null, $this->minClusterSize, null, $minClusterSeparation, $maxEdgeLength);
$flatClusters = $mstClusterer->processCluster();

return $flatClusters;
Expand Down
34 changes: 25 additions & 9 deletions lib/Clustering/MstClusterer.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

namespace OCA\Recognize\Clustering;

// TODO: core edges are not always stored properly (if two halves of the remaining clusters are both pruned at the same time)
// TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for soft clustering.
// TODO: store vertex lambda length (relative to cluster lambda length) for all vertices for improved soft clustering (see https://hdbscan.readthedocs.io/en/latest/soft_clustering.html)
class MstClusterer {
private array $edges;

Expand All @@ -22,8 +21,9 @@ class MstClusterer {
private bool $isRoot;
private array $mapVerticesToEdges;
private float $minClusterSeparation;
private float $maxEdgeLength;

public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1) {
public function __construct(array $edges, ?array $mapVerticesToEdges, int $minimumClusterSize, ?float $startingLambda = null, float $minClusterSeparation = 0.1, float $maxEdgeLength = 0.5) {
//Ascending sort of edges while perserving original keys.
$this->edges = $edges;

Expand Down Expand Up @@ -62,18 +62,25 @@ public function __construct(array $edges, ?array $mapVerticesToEdges, int $minim
$this->coreEdges = [];

$this->clusterWeight = 0.0;


$this->maxEdgeLength = $maxEdgeLength;

$this->minClusterSeparation = $minClusterSeparation;
}

public function processCluster(): array {
$currentLambda = $lastLambda = $this->startingLambda;

$edgeLength = INF;
while (true) {
$edgeCount = count($this->remainingEdges);

if ($edgeCount < ($this->minimumClusterSize - 1)) {
if ($edgeLength > $this->maxEdgeLength) {
// This cluster is too sparse and probably just noise
return [];
}


foreach ($this->coreEdges as &$edge) {
$edge['finalLambda'] = $currentLambda;
}
Expand Down Expand Up @@ -101,7 +108,11 @@ public function processCluster(): array {
unset($this->mapVerticesToEdges[$vertexConnectedFrom][$currentLongestEdgeKey]);
unset($this->mapVerticesToEdges[$vertexConnectedTo][$currentLongestEdgeKey]);

if ($edgeLength > 0.0) {
if ($edgeLength > $this->maxEdgeLength) {
// Prevent formation of clusters with edges longer than the maximum edge length
// This is done by forcing the weight of the current cluster to zero
$lastLambda = $currentLambda = 1 / $edgeLength;
} else if ($edgeLength > 0.0) {
$currentLambda = 1 / $edgeLength;
}

Expand Down Expand Up @@ -133,8 +144,8 @@ public function processCluster(): array {
// of clusters that weigh the most (i.e. have most (excess of) mass). Always discard the root cluster.


$childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation);
$childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation);
$childCluster1 = new MstClusterer($childClusterEdges1, $childClusterVerticesToEdges1, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation, $this->maxEdgeLength);
$childCluster2 = new MstClusterer($childClusterEdges2, $childClusterVerticesToEdges2, $this->minimumClusterSize, $currentLambda, $this->minClusterSeparation, $this->maxEdgeLength);

// Resolve all chosen child clusters recursively
$childClusters = array_merge($childCluster1->processCluster(), $childCluster2->processCluster());
Expand All @@ -145,7 +156,7 @@ public function processCluster(): array {
$this->coreEdges = array_merge($this->coreEdges, $childCluster->getCoreEdges());
}

if (($childrenWeight > $this->clusterWeight) || $this->isRoot) {
if (($childrenWeight >= $this->clusterWeight) || $this->isRoot) {
return $childClusters;
} else {
foreach (array_keys($this->remainingEdges) as $edgeKey) {
Expand All @@ -155,6 +166,11 @@ public function processCluster(): array {

return [$this];
}

if ($edgeLength > $this->maxEdgeLength) {
// Any pruned vertices were too far away to be part of the cluster
$this->edges = $this->remainingEdges;
}
}
}

Expand Down
4 changes: 3 additions & 1 deletion lib/Service/FaceClusterAnalyzer.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class FaceClusterAnalyzer {
public const MIN_SAMPLE_SIZE = 4; // Conservative value: 10
public const MIN_CLUSTER_SIZE = 5; // Conservative value: 10
public const MIN_DETECTION_SIZE = 0.03;
public const MIN_CLUSTER_SEPARATION = 0.0;
public const MAX_CLUSTER_EDGE_LENGTH = 0.5;
public const DIMENSIONS = 128;
public const SAMPLE_SIZE_EXISTING_CLUSTERS = 42;

Expand Down Expand Up @@ -74,7 +76,7 @@ public function calculateClusters(string $userId, int $batchSize = 0): void {
$hdbscan = new HDBSCAN($dataset, self::MIN_CLUSTER_SIZE, self::MIN_SAMPLE_SIZE);

$numberOfClusteredDetections = 0;
$clusters = $hdbscan->predict();
$clusters = $hdbscan->predict(self::MIN_CLUSTER_SEPARATION, self::MAX_CLUSTER_EDGE_LENGTH);

foreach ($clusters as $flatCluster) {
$detectionKeys = array_keys($flatCluster->getClusterVertices());
Expand Down