Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions __tests__/unit/mst.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import { minimumSpanningTree } from "../../packages/graph/src";
import { Graph } from "@antv/graphlib";

const data = {
nodes: [
{
id: 'A',
data: {},
},
{
id: 'B',
data: {},
},
{
id: 'C',
data: {},
},
{
id: 'D',
data: {},
},
{
id: 'E',
data: {},
},
{
id: 'F',
data: {},
},
{
id: 'G',
data: {},
},
],
edges: [
{
id: 'edge1',
source: 'A',
target: 'B',
data: {
weight: 1,
}
},
{
id: 'edge2',
source: 'B',
target: 'C',
data: {
weight: 1,
}
},
{
id: 'edge3',
source: 'A',
target: 'C',
data: {
weight: 2,
}
},
{
id: 'edge4',
source: 'D',
target: 'A',
data: {
weight: 3,
}
},
{
id: 'edge5',
source: 'D',
target: 'E',
data: {
weight: 4,
}
},
{
id: 'edge6',
source: 'E',
target: 'F',
data: {
weight: 2,
}
},
{
id: 'edge7',
source: 'F',
target: 'D',
data: {
weight: 3,
}
},
],
};
const graph = new Graph(data);
describe('minimumSpanningTree', () => {
it('test kruskal algorithm', () => {
let result = minimumSpanningTree(graph, 'weight');
let totalWeight = 0;
for (let edge of result) {
totalWeight += edge.data.weight;
}
expect(totalWeight).toEqual(10);
});

it('test prim algorithm', () => {
let result = minimumSpanningTree(graph, 'weight', 'prim');
let totalWeight = 0;
for (let edge of result) {
totalWeight += edge.data.weight;
}
expect(totalWeight).toEqual(10);
});
});
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"build:ci": "pnpm -r run build:ci",
"prepare": "husky install",
"test": "jest",
"test_one": "jest ./__tests__/unit/nodes-cosine-similarity.spec.ts",
"test_one": "jest ./__tests__/unit/mst.spec.ts",
"coverage": "jest --coverage",
"build:site": "vite build",
"deploy": "gh-pages -d site/dist",
Expand Down
2 changes: 1 addition & 1 deletion packages/graph/src/cosine-similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ export const cosineSimilarity = (
// Calculate the cosine similarity between the item vector and the target element vector
const cosineSimilarity = norm2Product ? dot / norm2Product : 0;
return cosineSimilarity;
}
};
1 change: 1 addition & 0 deletions packages/graph/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ export * from './dfs';
export * from './cosine-similarity';
export * from './nodes-cosine-similarity';
export * from './gaddi';
export * from './mst';
106 changes: 106 additions & 0 deletions packages/graph/src/mst.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import UnionFind from './structs/union-find';
import MinBinaryHeap from './structs/binary-heap';
import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types';
import { clone } from '@antv/util';

/**
Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
const primMST: IMSTAlgorithm = (graph, weightProps?) => {
const selectedEdges: IEdge[] = [];
const nodes = graph.getAllNodes();
const edges = graph.getAllEdges();
if (nodes.length === 0) {
return selectedEdges;
}
// From the first node
const currNode = nodes[0];
const visited = new Set();
visited.add(currNode);
// Using binary heap to maintain the weight of edges from other nodes that have joined the node
const compareWeight = (a: IEdge, b: IEdge) => {
if (weightProps) {
a.data;
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
}
return 0;
};
const edgeQueue = new MinBinaryHeap<IEdge>(compareWeight);

graph.getRelatedEdges(currNode.id, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
while (!edgeQueue.isEmpty()) {
// Select the node with the least edge weight between the added node and the added node
const currEdge: IEdge = edgeQueue.delMin();
const source = currEdge.source;
const target = currEdge.target;
if (visited.has(source) && visited.has(target)) continue;
selectedEdges.push(currEdge);
if (!visited.has(source)) {
visited.add(source);
graph.getRelatedEdges(source, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
}
if (!visited.has(target)) {
visited.add(target);
graph.getRelatedEdges(target, 'both').forEach((edge) => {
edgeQueue.insert(edge);
});
}
}
return selectedEdges;
};

/**
Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => {
const selectedEdges: IEdge[] = [];
const nodes = graph.getAllNodes();
const edges = graph.getAllEdges();
if (nodes.length === 0) {
return selectedEdges;
}
// If you specify weight, all edges are sorted by weight from smallest to largest
const weightEdges = clone(edges);
if (weightProps) {
weightEdges.sort((a: IEdge, b: IEdge) => {
return (a.data[weightProps] as number) - (b.data[weightProps] as number);
});
}
const disjointSet = new UnionFind(nodes.map((n) => n.id));
// Starting with the edge with the least weight, if the two nodes connected by this edge are not in the same connected component in graph G, the edge is added.
while (weightEdges.length > 0) {
const curEdge = weightEdges.shift();
const source = curEdge.source;
const target = curEdge.target;
if (!disjointSet.connected(source, target)) {
selectedEdges.push(curEdge);
disjointSet.union(source, target);
}
}
return selectedEdges;
};

/**
Calculates the Minimum Spanning Tree (MST) of a graph using either Prim's or Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight.
@param graph - The graph for which the MST needs to be calculated.
@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0.
@param algo - Optional. The algorithm to use for calculating the MST. Can be either 'prim' for Prim's algorithm, 'kruskal' for Kruskal's algorithm, or undefined to use the default algorithm (Kruskal's algorithm).
@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph.
*/
export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: 'prim' | 'kruskal' | undefined): IEdge[] => {
const algos: IMSTAlgorithmOpt = {
'prim': primMST,
'kruskal': kruskalMST,
};
return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps);
};
6 changes: 3 additions & 3 deletions packages/graph/src/nodes-cosine-similarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ export const nodesCosineSimilarity = (
allCosineSimilarity: number[],
similarNodes: NodeSimilarity[],
} => {
const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id));
const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
const similarNodes = clone(nodes.filter((node) => node.id !== seedNode.id));
const seedNodeIndex = nodes.findIndex((node) => node.id === seedNode.id);
// Collection of all node properties
const properties = getAllProperties(nodes);
// One-hot feature vectors for all node properties
Expand All @@ -40,4 +40,4 @@ export const nodesCosineSimilarity = (
// Sort the returned nodes according to cosine similarity
similarNodes.sort((a: NodeSimilarity, b: NodeSimilarity) => b.data.cosineSimilarity - a.data.cosineSimilarity);
return { allCosineSimilarity, similarNodes };
}
};
87 changes: 87 additions & 0 deletions packages/graph/src/structs/binary-heap.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@

export default class MinBinaryHeap<T> {
list: T[];

compareFn: (a?: T, b?: T) => number;

constructor(compareFn: (a: T, b: T) => number) {
this.compareFn = compareFn || (() => 0);
this.list = [];
}

getLeft(index: number) {
return 2 * index + 1;
}

getRight(index: number) {
return 2 * index + 2;
}

getParent(index: number) {
if (index === 0) {
return null;
}
return Math.floor((index - 1) / 2);
}

isEmpty() {
return this.list.length <= 0;
}

top() {
return this.isEmpty() ? undefined : this.list[0];
}

delMin() {
const top = this.top();
const bottom = this.list.pop();
if (this.list.length > 0) {
this.list[0] = bottom;
this.moveDown(0);
}
return top;
}

insert(value: T) {
if (value !== null) {
this.list.push(value);
const index = this.list.length - 1;
this.moveUp(index);
return true;
}
return false;
}

moveUp(index: number) {
let i = index;
let parent = this.getParent(i);
while (i && i > 0 && this.compareFn(this.list[parent], this.list[i]) > 0) {
// swap
const tmp = this.list[parent];
this.list[parent] = this.list[i];
this.list[i] = tmp;
i = parent;
parent = this.getParent(i);
}
}

moveDown(index: number) {
let element = index;
const left = this.getLeft(index);
const right = this.getRight(index);
const size = this.list.length;
if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
element = left;
} else if (
right !== null &&
right < size &&
this.compareFn(this.list[element], this.list[right]) > 0
) {
element = right;
}
if (index !== element) {
[this.list[index], this.list[element]] = [this.list[element], this.list[index]];
this.moveDown(element);
}
}
}
44 changes: 44 additions & 0 deletions packages/graph/src/structs/union-find.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* Disjoint set to support quick union
*/
export default class UnionFind {
count: number;

parent: { [key: number | string]: number | string };

constructor(items: (number | string)[]) {
this.count = items.length;
this.parent = {};
for (const i of items) {
this.parent[i] = i;
}
}

// find the root of the item
find(item: (number | string)) {
let resItem = item;
while (this.parent[resItem] !== resItem) {
resItem = this.parent[resItem];
}
return resItem;
}

union(a: (number | string), b: (number | string)) {
const rootA = this.find(a);
const rootB = this.find(b);
if (rootA === rootB) return;
// make the element with smaller root the parent
if (rootA < rootB) {
if (this.parent[b] !== b) this.union(this.parent[b], a);
this.parent[b] = this.parent[a];
} else {
if (this.parent[a] !== a) this.union(this.parent[a], b);
this.parent[a] = this.parent[b];
}
}

// Determine that A and B are connected
connected(a: (number | string), b: (number | string)) {
return this.find(a) === this.find(b);
}
}
Loading