Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
delete(n.neighbors, worst.Key)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.Key)
worst.replenish(m)
worst.replenish(m, dist)
}

type searchCandidate[K cmp.Ordered] struct {
Expand Down Expand Up @@ -148,7 +148,7 @@ func (n *layerNode[K]) search(
return result.Slice()
}

func (n *layerNode[K]) replenish(m int) {
func (n *layerNode[K]) replenish(m int, dist DistanceFunc) {
if len(n.neighbors) >= m {
return
}
Expand All @@ -165,7 +165,7 @@ func (n *layerNode[K]) replenish(m int) {
if candidate == n {
continue
}
n.addNeighbor(candidate, m, CosineDistance)
n.addNeighbor(candidate, m, dist)
if len(n.neighbors) >= m {
return
}
Expand All @@ -175,13 +175,13 @@ func (n *layerNode[K]) replenish(m int) {

// isolate removes the node from the graph by deleting every connection
// between it and its neighbors.
func (n *layerNode[K]) isolate(m int) {
func (n *layerNode[K]) isolate(m int, dist DistanceFunc) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.Key)
}

for _, neighbor := range n.neighbors {
neighbor.replenish(m)
neighbor.replenish(m, dist)
}
}

Expand Down Expand Up @@ -501,7 +501,7 @@ func (h *Graph[K]) Delete(key K) bool {
if len(layer.nodes) == 0 {
deleteLayer[i] = struct{}{}
}
node.isolate(h.M)
node.isolate(h.M, h.Distance)
deleted = true
}

Expand Down
23 changes: 22 additions & 1 deletion graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ func TestGraph_AddSearch(t *testing.T) {
)

require.Len(t, nearest, 4)
// The two closest are 64 and 65 (distance 0.5 each).
// The next two are 63 and 66 (distance 1.5 each).
require.EqualValues(
t,
[]Node[int]{
{64, Vector{64}},
{65, Vector{65}},
{62, Vector{62}},
{63, Vector{63}},
{66, Vector{66}},
},
nearest,
)
Expand Down Expand Up @@ -259,3 +261,22 @@ func TestGraph_RemoveAllNodes(t *testing.T) {
g.Add(MakeNode(1, vec))
}
}

// TestGraph_DeleteReplenishUsesGraphDistance is a regression test for
// replenish() hardcoding CosineDistance instead of using the graph's own
// distance function when repairing neighbor lists after a deletion.
func TestGraph_DeleteReplenishUsesGraphDistance(t *testing.T) {
	// replenish() previously hardcoded CosineDistance. After deleting a
	// node from a EuclideanDistance graph, replenish must use the correct
	// distance function or the topology becomes corrupted.
	g := newTestGraph[int]() // uses EuclideanDistance
	for i := 0; i < 20; i++ {
		g.Add(Node[int]{Key: i, Value: Vector{float32(i)}})
	}

	// Delete a node in the middle to trigger replenish. The result is
	// asserted so the test cannot pass vacuously: if the key were not
	// found, nothing would be isolated and replenish would never run.
	require.True(t, g.Delete(10), "Delete should report that key 10 was removed")

	// Search should still find the correct nearest neighbor.
	results := g.Search(Vector{9.5}, 1)
	require.Len(t, results, 1)
	// Must be 9 or 11: with 10 gone, 9 (distance 0.5) is the closest, and
	// 11 is tolerated because the search is approximate.
	require.Contains(t, []int{9, 11}, results[0].Key)
}
18 changes: 16 additions & 2 deletions heap/heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ func (h *Heap[T]) Pop() T {
return heap.Pop(&h.inner).(T)
}

// PopLast removes and returns the maximum element from the heap.
func (h *Heap[T]) PopLast() T {
return h.Remove(h.Len() - 1)
return h.Remove(h.maxIndex())
}

// Remove removes and returns the element at index i from the heap.
Expand All @@ -85,9 +86,22 @@ func (h *Heap[T]) Min() T {
return h.inner.data[0]
}

// maxIndex reports the position of the largest element in the backing
// slice. Because the heap is min-ordered, every non-leaf element is no
// larger than its children, so only the leaves (positions n/2 through n-1)
// need to be examined. When several leaves compare equal, the one at the
// lowest position wins.
func (h *Heap[T]) maxIndex() int {
	size := h.inner.Len()
	// Start at the first leaf and walk the remaining leaves.
	maxAt := size / 2
	for idx := maxAt + 1; idx < size; idx++ {
		if !h.inner.data[maxAt].Less(h.inner.data[idx]) {
			continue
		}
		maxAt = idx
	}
	return maxAt
}

// Max returns the maximum element in the heap.
func (h *Heap[T]) Max() T {
return h.inner.data[h.inner.Len()-1]
return h.inner.data[h.maxIndex()]
}

func (h *Heap[T]) Slice() []T {
Expand Down
17 changes: 17 additions & 0 deletions heap/heap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ func TestHeap(t *testing.T) {
t.Errorf("Heap did not return sorted elements: %+v", inOrder)
}
}

// TestHeap_MaxAndPopLast verifies that Max and PopLast operate on the
// largest element of the heap, and that Min is unaffected, using values
// pushed in an unsorted order.
func TestHeap_MaxAndPopLast(t *testing.T) {
	var h Heap[Int]
	for _, v := range []Int{5, 1, 9, 3, 7, 2, 8, 4, 6} {
		h.Push(v)
	}

	// Both extremes must be visible without mutating the heap.
	require.Equal(t, Int(9), h.Max(), "Max should return the largest element")
	require.Equal(t, Int(1), h.Min(), "Min should return the smallest element")

	// PopLast takes out the maximum, leaving the next largest as the new Max.
	require.Equal(t, Int(9), h.PopLast())
	require.Equal(t, Int(8), h.Max(), "Max should be 8 after removing 9")
	require.Equal(t, 8, h.Len())
}
Loading