diff --git a/__tests__/unit/mst.spec.ts b/__tests__/unit/mst.spec.ts new file mode 100644 index 0000000..492cd1f --- /dev/null +++ b/__tests__/unit/mst.spec.ts @@ -0,0 +1,113 @@ +import { minimumSpanningTree } from "../../packages/graph/src"; +import { Graph } from "@antv/graphlib"; + +const data = { + nodes: [ + { + id: 'A', + data: {}, + }, + { + id: 'B', + data: {}, + }, + { + id: 'C', + data: {}, + }, + { + id: 'D', + data: {}, + }, + { + id: 'E', + data: {}, + }, + { + id: 'F', + data: {}, + }, + { + id: 'G', + data: {}, + }, + ], + edges: [ + { + id: 'edge1', + source: 'A', + target: 'B', + data: { + weight: 1, + } + }, + { + id: 'edge2', + source: 'B', + target: 'C', + data: { + weight: 1, + } + }, + { + id: 'edge3', + source: 'A', + target: 'C', + data: { + weight: 2, + } + }, + { + id: 'edge4', + source: 'D', + target: 'A', + data: { + weight: 3, + } + }, + { + id: 'edge5', + source: 'D', + target: 'E', + data: { + weight: 4, + } + }, + { + id: 'edge6', + source: 'E', + target: 'F', + data: { + weight: 2, + } + }, + { + id: 'edge7', + source: 'F', + target: 'D', + data: { + weight: 3, + } + }, + ], +}; +const graph = new Graph(data); +describe('minimumSpanningTree', () => { + it('test kruskal algorithm', () => { + let result = minimumSpanningTree(graph, 'weight'); + let totalWeight = 0; + for (let edge of result) { + totalWeight += edge.data.weight; + } + expect(totalWeight).toEqual(10); + }); + + it('test prim algorithm', () => { + let result = minimumSpanningTree(graph, 'weight', 'prim'); + let totalWeight = 0; + for (let edge of result) { + totalWeight += edge.data.weight; + } + expect(totalWeight).toEqual(10); + }); +}); diff --git a/package.json b/package.json index 8a6e34a..f3e823a 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "build:ci": "pnpm -r run build:ci", "prepare": "husky install", "test": "jest", - "test_one": "jest ./__tests__/unit/nodes-cosine-similarity.spec.ts", + "test_one": "jest ./__tests__/unit/mst.spec.ts", "coverage": "jest --coverage", "build:site": "vite build", "deploy": "gh-pages -d site/dist", diff --git a/packages/graph/src/cosine-similarity.ts b/packages/graph/src/cosine-similarity.ts index 554623b..9f5b5f7 100644 --- a/packages/graph/src/cosine-similarity.ts +++ b/packages/graph/src/cosine-similarity.ts @@ -24,4 +24,4 @@ export const cosineSimilarity = ( // Calculate the cosine similarity between the item vector and the target element vector const cosineSimilarity = norm2Product ? dot / norm2Product : 0; return cosineSimilarity; -} +}; diff --git a/packages/graph/src/index.ts b/packages/graph/src/index.ts index ba47401..6413f1b 100644 --- a/packages/graph/src/index.ts +++ b/packages/graph/src/index.ts @@ -9,3 +9,4 @@ export * from './dfs'; export * from './cosine-similarity'; export * from './nodes-cosine-similarity'; export * from './gaddi'; +export * from './mst'; \ No newline at end of file diff --git a/packages/graph/src/mst.ts b/packages/graph/src/mst.ts new file mode 100644 index 0000000..17eae0b --- /dev/null +++ b/packages/graph/src/mst.ts @@ -0,0 +1,106 @@ +import UnionFind from './structs/union-find'; +import MinBinaryHeap from './structs/binary-heap'; +import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types'; +import { clone } from '@antv/util'; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +const primMST: IMSTAlgorithm = (graph, weightProps?) => { + const selectedEdges: IEdge[] = []; + const nodes = graph.getAllNodes(); + const edges = graph.getAllEdges(); + if (nodes.length === 0) { + return selectedEdges; + } + // From the first node + const currNode = nodes[0]; + const visited = new Set(); + visited.add(currNode); + // Using binary heap to maintain the weight of edges from other nodes that have joined the node + const compareWeight = (a: IEdge, b: IEdge) => { + if (weightProps) { + a.data; + return (a.data[weightProps] as number) - (b.data[weightProps] as number); + } + return 0; + }; + const edgeQueue = new MinBinaryHeap(compareWeight); + + graph.getRelatedEdges(currNode.id, 'both').forEach((edge) => { + edgeQueue.insert(edge); + }); + while (!edgeQueue.isEmpty()) { + // Select the node with the least edge weight between the added node and the added node + const currEdge: IEdge = edgeQueue.delMin(); + const source = currEdge.source; + const target = currEdge.target; + if (visited.has(source) && visited.has(target)) continue; + selectedEdges.push(currEdge); + if (!visited.has(source)) { + visited.add(source); + graph.getRelatedEdges(source, 'both').forEach((edge) => { + edgeQueue.insert(edge); + }); + } + if (!visited.has(target)) { + visited.add(target); + graph.getRelatedEdges(target, 'both').forEach((edge) => { + edgeQueue.insert(edge); + }); + } + } + return selectedEdges; +}; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => { + const selectedEdges: IEdge[] = []; + const nodes = graph.getAllNodes(); + const edges = graph.getAllEdges(); + if (nodes.length === 0) { + return selectedEdges; + } + // If you specify weight, all edges are sorted by weight from smallest to largest + const weightEdges = clone(edges); + if (weightProps) { + weightEdges.sort((a: IEdge, b: IEdge) => { + return (a.data[weightProps] as number) - (b.data[weightProps] as number); + }); + } + const disjointSet = new UnionFind(nodes.map((n) => n.id)); + // Starting with the edge with the least weight, if the two nodes connected by this edge are not in the same connected component in graph G, the edge is added. + while (weightEdges.length > 0) { + const curEdge = weightEdges.shift(); + const source = curEdge.source; + const target = curEdge.target; + if (!disjointSet.connected(source, target)) { + selectedEdges.push(curEdge); + disjointSet.union(source, target); + } + } + return selectedEdges; +}; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using either Prim's or Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@param algo - Optional. The algorithm to use for calculating the MST. Can be either 'prim' for Prim's algorithm, 'kruskal' for Kruskal's algorithm, or undefined to use the default algorithm (Kruskal's algorithm). +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: 'prim' | 'kruskal' | undefined): IEdge[] => { + const algos: IMSTAlgorithmOpt = { + 'prim': primMST, + 'kruskal': kruskalMST, + }; + return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps); +}; diff --git a/packages/graph/src/nodes-cosine-similarity.ts b/packages/graph/src/nodes-cosine-similarity.ts index 9ef5201..9687326 100644 --- a/packages/graph/src/nodes-cosine-similarity.ts +++ b/packages/graph/src/nodes-cosine-similarity.ts @@ -21,8 +21,8 @@ export const nodesCosineSimilarity = ( allCosineSimilarity: number[], similarNodes: NodeSimilarity[], } => { - const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id)); - const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id); + const similarNodes = clone(nodes.filter((node) => node.id !== seedNode.id)); + const seedNodeIndex = nodes.findIndex((node) => node.id === seedNode.id); // Collection of all node properties const properties = getAllProperties(nodes); // One-hot feature vectors for all node properties @@ -40,4 +40,4 @@ export const nodesCosineSimilarity = ( // Sort the returned nodes according to cosine similarity similarNodes.sort((a: NodeSimilarity, b: NodeSimilarity) => b.data.cosineSimilarity - a.data.cosineSimilarity); return { allCosineSimilarity, similarNodes }; -} +}; diff --git a/packages/graph/src/structs/binary-heap.ts b/packages/graph/src/structs/binary-heap.ts new file mode 100644 index 0000000..4637552 --- /dev/null +++ b/packages/graph/src/structs/binary-heap.ts @@ -0,0 +1,87 @@ + +export default class MinBinaryHeap { + list: T[]; + + compareFn: (a?: T, b?: T) => number; + + constructor(compareFn: (a: T, b: T) => number) { + this.compareFn = compareFn || (() => 0); + this.list = []; + } + + getLeft(index: number) { + return 2 * index + 1; + } + + getRight(index: number) { + return 2 * index + 2; + } + + getParent(index: number) { + if (index === 0) { + return null; + } + return Math.floor((index - 1) / 2); + } + + isEmpty() { + return this.list.length <= 0; + } + + top() { + return this.isEmpty() ? undefined : this.list[0]; + } + + delMin() { + const top = this.top(); + const bottom = this.list.pop(); + if (this.list.length > 0) { + this.list[0] = bottom; + this.moveDown(0); + } + return top; + } + + insert(value: T) { + if (value !== null) { + this.list.push(value); + const index = this.list.length - 1; + this.moveUp(index); + return true; + } + return false; + } + + moveUp(index: number) { + let i = index; + let parent = this.getParent(i); + while (i && i > 0 && this.compareFn(this.list[parent], this.list[i]) > 0) { + // swap + const tmp = this.list[parent]; + this.list[parent] = this.list[i]; + this.list[i] = tmp; + i = parent; + parent = this.getParent(i); + } + } + + moveDown(index: number) { + let element = index; + const left = this.getLeft(index); + const right = this.getRight(index); + const size = this.list.length; + if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) { + element = left; + } else if ( + right !== null && + right < size && + this.compareFn(this.list[element], this.list[right]) > 0 + ) { + element = right; + } + if (index !== element) { + [this.list[index], this.list[element]] = [this.list[element], this.list[index]]; + this.moveDown(element); + } + } +} diff --git a/packages/graph/src/structs/union-find.ts b/packages/graph/src/structs/union-find.ts new file mode 100644 index 0000000..2758b5e --- /dev/null +++ b/packages/graph/src/structs/union-find.ts @@ -0,0 +1,44 @@ +/** + * Disjoint set to support quick union + */ +export default class UnionFind { + count: number; + + parent: { [key: number | string]: number | string }; + + constructor(items: (number | string)[]) { + this.count = items.length; + this.parent = {}; + for (const i of items) { + this.parent[i] = i; + } + } + + // find the root of the item + find(item: (number | string)) { + let resItem = item; + while (this.parent[resItem] !== resItem) { + resItem = this.parent[resItem]; + } + return resItem; + } + + union(a: (number | string), b: (number | string)) { + const rootA = this.find(a); + const rootB = this.find(b); + if (rootA === rootB) return; + // make the element with smaller root the parent + if (rootA < rootB) { + if (this.parent[b] !== b) this.union(this.parent[b], a); + this.parent[b] = this.parent[a]; + } else { + if (this.parent[a] !== a) this.union(this.parent[a], b); + this.parent[a] = this.parent[b]; + } + } + + // Determine that A and B are connected + connected(a: (number | string), b: (number | string)) { + return this.find(a) === this.find(b); + } +} diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index 5315abd..d4aeeee 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -1,6 +1,6 @@ import { Edge, Graph as IGraph, Node, PlainObject } from '@antv/graphlib'; -// 数据集中属性/特征值分布的map +// Map of attribute / eigenvalue distribution in dataset export interface KeyValueMap { [key: string]: any[]; } @@ -56,3 +56,9 @@ export type GraphData = { export type INode = Node; export type IEdge = Edge; + +export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[]; +export interface IMSTAlgorithmOpt { + 'prim': IMSTAlgorithm; + 'kruskal': IMSTAlgorithm; +} \ No newline at end of file diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 78d5103..4a884fb 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -12,17 +12,17 @@ export const getAllProperties = (nodes: Node[]) => { export const getAllKeyValueMap = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => { let keys: string[] = []; - // 指定了参与计算的keys时,使用指定的keys + // Use the specified keys when the keys participating in the calculation is specified if (involvedKeys?.length) { keys = involvedKeys; } else { - // 未指定抽取的keys时,提取数据中所有的key + // When the extracted keys is not specified, all key in the data is extracted dataList.forEach((data) => { keys = keys.concat(Object.keys(data)); }); keys = uniq(keys); } - // 获取所有值非空的key的value数组 + // Get the value array of all key with non-null values const allKeyValueMap: KeyValueMap = {}; keys.forEach((key) => { const value: unknown[] = []; @@ -40,19 +40,17 @@ export const getAllKeyValueMap = (dataList: PlainObject[], involvedKeys?: string }; export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => { - // 获取数据中所有的属性/特征及其对应的值 + // Get all attributes / features in the data and their corresponding values const allKeyValueMap = getAllKeyValueMap(dataList, involvedKeys, uninvolvedKeys); const oneHotCode: unknown[][] = []; if (!Object.keys(allKeyValueMap).length) { return oneHotCode; } - - // 获取所有的属性/特征值 + // Get all attribute / feature values const allValue = Object.values(allKeyValueMap); - // 是否所有属性/特征的值都是数值型 + // Whether the values of all attributes / features are numerical const isAllNumber = allValue.every((value) => value.every((item) => (typeof (item) === 'number'))); - - // 对数据进行one-hot编码 + // One-hot encode the data dataList.forEach((data, index) => { let code: unknown[] = []; Object.keys(allKeyValueMap).forEach((key) => { @@ -60,11 +58,11 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol const allKeyValue = allKeyValueMap[key]; const valueIndex = allKeyValue.findIndex((value) => keyValue === value); const subCode = []; - // 如果属性/特征所有的值都能转成数值型,不满足分箱,则直接用值(todo: 为了收敛更快,需做归一化处理) + // If all the values of the attribute / feature can be converted to numerical type and do not satisfy the box division, then use the value directly (todo: normalization is needed for faster convergence) if (isAllNumber) { subCode.push(keyValue); } else { - // 进行one-hot编码 + // Encode one-hot for (let i = 0; i < allKeyValue.length; i++) { if (i === valueIndex) { subCode.push(1); @@ -78,4 +76,5 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol oneHotCode[index] = code; }); return oneHotCode; -}; \ No newline at end of file +}; +