From b3966a295400861f6f7b771c7e9326b2094c3474 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Tue, 12 Sep 2023 18:20:31 +0800 Subject: [PATCH 1/7] feat: v5 algorithm mst --- packages/graph/src/index.ts | 1 + packages/graph/src/mst.ts | 109 ++++++++++++++++++++++ packages/graph/src/structs/binary-heap.ts | 90 ++++++++++++++++++ packages/graph/src/structs/union-find.ts | 45 +++++++++ packages/graph/src/types.ts | 6 ++ packages/graph/src/utils.ts | 15 ++- 6 files changed, 263 insertions(+), 3 deletions(-) create mode 100644 packages/graph/src/mst.ts create mode 100644 packages/graph/src/structs/binary-heap.ts create mode 100644 packages/graph/src/structs/union-find.ts diff --git a/packages/graph/src/index.ts b/packages/graph/src/index.ts index ba47401..81898ea 100644 --- a/packages/graph/src/index.ts +++ b/packages/graph/src/index.ts @@ -9,3 +9,4 @@ export * from './dfs'; export * from './cosine-similarity'; export * from './nodes-cosine-similarity'; export * from './gaddi'; +export * from './mst' \ No newline at end of file diff --git a/packages/graph/src/mst.ts b/packages/graph/src/mst.ts new file mode 100644 index 0000000..109e3ef --- /dev/null +++ b/packages/graph/src/mst.ts @@ -0,0 +1,109 @@ +import UnionFind from './structs/union-find'; +import MinBinaryHeap from './structs/binary-heap'; +import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types'; +import { clone } from '@antv/util'; +import { getEdgesByNodeId } from './utils'; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +const primMST: IMSTAlgorithm = (graph, weightProps?) => { + const selectedEdges: IEdge[] = []; + const nodes = graph.getAllNodes() + const edges = graph.getAllEdges(); + if (nodes.length === 0) { + return selectedEdges; + } + // From the first node + const currNode = nodes[0]; + const visited = new Set(); + visited.add(currNode); + + // Using binary heap to maintain the weight of edges from other nodes that have joined the node + const compareWeight = (a: IEdge, b: IEdge) => { + if (weightProps) { + a.data + return (a.data[weightProps] as number) - (b.data[weightProps] as number); + } + return 0; + }; + const edgeQueue = new MinBinaryHeap(compareWeight); + getEdgesByNodeId(currNode.id, edges).forEach((edge) => { + edgeQueue.insert(edge); + }); + while (!edgeQueue.isEmpty()) { + // Select the node with the least edge weight between the added node and the added node + const currEdge: IEdge = edgeQueue.delMin(); + const source = currEdge.source; + const target = currEdge.target; + if (visited.has(source) && visited.has(target)) continue; + selectedEdges.push(currEdge); + if (!visited.has(source)) { + visited.add(source); + getEdgesByNodeId(source, edges).forEach((edge) => { + edgeQueue.insert(edge); + }); + } + if (!visited.has(target)) { + visited.add(target); + getEdgesByNodeId(target, edges).forEach((edge) => { + edgeQueue.insert(edge); + }); + } + } + return selectedEdges; +}; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => { + const selectedEdges: IEdge[] = []; + const nodes = graph.getAllNodes() + const edges = graph.getAllEdges(); + if (nodes.length === 0) { + return selectedEdges; + } + // 若指定weight,则将所有的边按权值从小到大排序 + const weightEdges = clone(edges); + if (weightProps) { + weightEdges.sort((a: IEdge, b: IEdge) => { + return (a.data[weightProps] as number) - (b.data[weightProps] as number); + }); + } + const disjointSet = new UnionFind(nodes.map((n) => n.id)); + + // 从权值最小的边开始,如果这条边连接的两个节点于图G中不在同一个连通分量中,则添加这条边 + // 直到遍历完所有点或边 + while (weightEdges.length > 0) { + const curEdge = weightEdges.shift(); + const source = curEdge.source; + const target = curEdge.target; + if (!disjointSet.connected(source, target)) { + selectedEdges.push(curEdge); + disjointSet.union(source, target); + } + } + return selectedEdges; +}; + +/** +Calculates the Minimum Spanning Tree (MST) of a graph using either Prim's or Kruskal's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. +@param graph - The graph for which the MST needs to be calculated. +@param weightProps - Optional. The property name in the edge data object that represents the weight of the edge.If provided, the algorithm will consider the weight of edges based on this property.If not provided, the algorithm will assume all edges have a weight of 0. +@param algo - Optional. The algorithm to use for calculating the MST. Can be either 'prim' for Prim's algorithm, 'kruskal' for Kruskal's algorithm, or undefined to use the default algorithm (Kruskal's algorithm). +@returns An array of selected edges that form the Minimum Spanning Tree (MST) of the graph. +*/ +export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: 'prim' | 'kruskal' | undefined): IEdge[] => { + const algos: IMSTAlgorithmOpt = { + 'prim': primMST, + 'kruskal': kruskalMST, + }; + return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps); +} diff --git a/packages/graph/src/structs/binary-heap.ts b/packages/graph/src/structs/binary-heap.ts new file mode 100644 index 0000000..bf3eb70 --- /dev/null +++ b/packages/graph/src/structs/binary-heap.ts @@ -0,0 +1,90 @@ +const defaultCompare = (a, b) => { + return a - b; +}; + +export default class MinBinaryHeap { + list: any[]; + + compareFn: (a: any, b: any) => number; + + constructor(compareFn = defaultCompare) { + this.compareFn = compareFn; + this.list = []; + } + + getLeft(index) { + return 2 * index + 1; + } + + getRight(index) { + return 2 * index + 2; + } + + getParent(index) { + if (index === 0) { + return null; + } + return Math.floor((index - 1) / 2); + } + + isEmpty() { + return this.list.length <= 0; + } + + top() { + return this.isEmpty() ? undefined : this.list[0]; + } + + delMin() { + const top = this.top(); + const bottom = this.list.pop(); + if (this.list.length > 0) { + this.list[0] = bottom; + this.moveDown(0); + } + return top; + } + + insert(value) { + if (value !== null) { + this.list.push(value); + const index = this.list.length - 1; + this.moveUp(index); + return true; + } + return false; + } + + moveUp(index) { + let parent = this.getParent(index); + while (index && index > 0 && this.compareFn(this.list[parent], this.list[index]) > 0) { + // swap + const tmp = this.list[parent]; + this.list[parent] = this.list[index]; + this.list[index] = tmp; + // [this.list[index], this.list[parent]] = [this.list[parent], this.list[index]] + index = parent; + parent = this.getParent(index); + } + } + + moveDown(index) { + let element = index; + const left = this.getLeft(index); + const right = this.getRight(index); + const size = this.list.length; + if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) { + element = left; + } else if ( + right !== null && + right < size && + this.compareFn(this.list[element], this.list[right]) > 0 + ) { + element = right; + } + if (index !== element) { + [this.list[index], this.list[element]] = [this.list[element], this.list[index]]; + this.moveDown(element); + } + } +} diff --git a/packages/graph/src/structs/union-find.ts b/packages/graph/src/structs/union-find.ts new file mode 100644 index 0000000..940ed6d --- /dev/null +++ b/packages/graph/src/structs/union-find.ts @@ -0,0 +1,45 @@ +/** + * 并查集 Disjoint set to support quick union + */ +export default class UnionFind { + count: number; + + parent: {}; + + constructor(items: (number | string)[]) { + this.count = items.length; + this.parent = {}; + for (const i of items) { + this.parent[i] = i; + } + } + + // find the root of the item + find(item) { + while (this.parent[item] !== item) { + item = this.parent[item]; + } + return item; + } + + union(a, b) { + const rootA = this.find(a); + const rootB = this.find(b); + + if (rootA === rootB) return; + + // make the element with smaller root the parent + if (rootA < rootB) { + if (this.parent[b] !== b) this.union(this.parent[b], a); + this.parent[b] = this.parent[a]; + } else { + if (this.parent[a] !== a) this.union(this.parent[a], b); + this.parent[a] = this.parent[b]; + } + } + + // whether a and b are connected, i.e. a and b have the same root + connected(a, b) { + return this.find(a) === this.find(b); + } +} diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index 5315abd..3e2457b 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -56,3 +56,9 @@ export type GraphData = { export type INode = Node; export type IEdge = Edge; + +export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[] +export interface IMSTAlgorithmOpt { + 'prim': IMSTAlgorithm, + 'kruskal': IMSTAlgorithm, +} \ No newline at end of file diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 78d5103..18a0b58 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -1,5 +1,5 @@ -import { Node, PlainObject } from "@antv/graphlib"; -import { KeyValueMap, NodeData } from "./types"; +import { Edge, Node, PlainObject } from "@antv/graphlib"; +import { KeyValueMap, NodeData, NodeID } from "./types"; import { uniq } from "@antv/util"; export const getAllProperties = (nodes: Node[]) => { @@ -78,4 +78,13 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol oneHotCode[index] = code; }); return oneHotCode; -}; \ No newline at end of file +}; + +/** + * 获取指定节点的边,包括出边和入边 + * @param nodeId 节点 ID + * @param edges 图中的所有边数据 + */ +export const getEdgesByNodeId = (nodeId: NodeID, edges: Edge<{ [key: string]: any }>[]): Edge<{ [key: string]: any }>[] => { + return edges.filter(edge => edge.source === nodeId || edge.target === nodeId) +} From 2a6db0cb326e1209677308fb5ad2fd7cafbb44b1 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Tue, 12 Sep 2023 18:20:50 +0800 Subject: [PATCH 2/7] test: mst unit test --- __tests__/unit/mst.spec.ts | 113 +++++++++++++++++++++++++++++++++++++ package.json | 2 +- 2 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 __tests__/unit/mst.spec.ts diff --git a/__tests__/unit/mst.spec.ts b/__tests__/unit/mst.spec.ts new file mode 100644 index 0000000..492cd1f --- /dev/null +++ b/__tests__/unit/mst.spec.ts @@ -0,0 +1,113 @@ +import { minimumSpanningTree } from "../../packages/graph/src"; +import { Graph } from "@antv/graphlib"; + +const data = { + nodes: [ + { + id: 'A', + data: {}, + }, + { + id: 'B', + data: {}, + }, + { + id: 'C', + data: {}, + }, + { + id: 'D', + data: {}, + }, + { + id: 'E', + data: {}, + }, + { + id: 'F', + data: {}, + }, + { + id: 'G', + data: {}, + }, + ], + edges: [ + { + id: 'edge1', + source: 'A', + target: 'B', + data: { + weight: 1, + } + }, + { + id: 'edge2', + source: 'B', + target: 'C', + data: { + weight: 1, + } + }, + { + id: 'edge3', + source: 'A', + target: 'C', + data: { + weight: 2, + } + }, + { + id: 'edge4', + source: 'D', + target: 'A', + data: { + weight: 3, + } + }, + { + id: 'edge5', + source: 'D', + target: 'E', + data: { + weight: 4, + } + }, + { + id: 'edge6', + source: 'E', + target: 'F', + data: { + weight: 2, + } + }, + { + id: 'edge7', + source: 'F', + target: 'D', + data: { + weight: 3, + } + }, + ], +}; +const graph = new Graph(data); +describe('minimumSpanningTree', () => { + it('test kruskal algorithm', () => { + let result = minimumSpanningTree(graph, 'weight'); + let totalWeight = 0; + for (let edge of result) { + totalWeight += edge.data.weight; + } + expect(totalWeight).toEqual(10); + }); + + it('test prim algorithm', () => { + let result = minimumSpanningTree(graph, 'weight', 'prim'); + let totalWeight = 0; + for (let edge of result) { + totalWeight += edge.data.weight; + } + expect(totalWeight).toEqual(10); + }); +}); diff --git a/package.json b/package.json index 8a6e34a..f3e823a 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "build:ci": "pnpm -r run build:ci", "prepare": "husky install", "test": "jest", - "test_one": "jest ./__tests__/unit/nodes-cosine-similarity.spec.ts", + "test_one": "jest ./__tests__/unit/mst.spec.ts", "coverage": "jest --coverage", "build:site": "vite build", "deploy": "gh-pages -d site/dist", From b50a385c40530f989bdf7bf01b9b86edfbd46cbb Mon Sep 17 00:00:00 2001 From: zqqcee Date: Wed, 13 Sep 2023 10:41:46 +0800 Subject: [PATCH 3/7] fix: fix lint --- packages/graph/src/cosine-similarity.ts | 2 +- packages/graph/src/index.ts | 2 +- packages/graph/src/mst.ts | 8 ++-- packages/graph/src/nodes-cosine-similarity.ts | 6 +-- packages/graph/src/structs/binary-heap.ts | 38 +++++++++---------- packages/graph/src/structs/union-find.ts | 13 ++++--- packages/graph/src/types.ts | 6 +-- packages/graph/src/utils.ts | 4 +- 8 files changed, 40 insertions(+), 39 deletions(-) diff --git a/packages/graph/src/cosine-similarity.ts b/packages/graph/src/cosine-similarity.ts index 554623b..9f5b5f7 100644 --- a/packages/graph/src/cosine-similarity.ts +++ b/packages/graph/src/cosine-similarity.ts @@ -24,4 +24,4 @@ export const cosineSimilarity = ( // Calculate the cosine similarity between the item vector and the target element vector const cosineSimilarity = norm2Product ? dot / norm2Product : 0; return cosineSimilarity; -} +}; diff --git a/packages/graph/src/index.ts b/packages/graph/src/index.ts index 81898ea..6413f1b 100644 --- a/packages/graph/src/index.ts +++ b/packages/graph/src/index.ts @@ -9,4 +9,4 @@ export * from './dfs'; export * from './cosine-similarity'; export * from './nodes-cosine-similarity'; export * from './gaddi'; -export * from './mst' \ No newline at end of file +export * from './mst'; \ No newline at end of file diff --git a/packages/graph/src/mst.ts b/packages/graph/src/mst.ts index 109e3ef..4424be0 100644 --- a/packages/graph/src/mst.ts +++ b/packages/graph/src/mst.ts @@ -12,7 +12,7 @@ Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm */ const primMST: IMSTAlgorithm = (graph, weightProps?) => { const selectedEdges: IEdge[] = []; - const nodes = graph.getAllNodes() + const nodes = graph.getAllNodes(); const edges = graph.getAllEdges(); if (nodes.length === 0) { return selectedEdges; @@ -25,7 +25,7 @@ const primMST: IMSTAlgorithm = (graph, weightProps?) => { // Using binary heap to maintain the weight of edges from other nodes that have joined the node const compareWeight = (a: IEdge, b: IEdge) => { if (weightProps) { - a.data + a.data; return (a.data[weightProps] as number) - (b.data[weightProps] as number); } return 0; @@ -65,7 +65,7 @@ Calculates the Minimum Spanning Tree (MST) of a graph using the Kruskal's algori */ const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => { const selectedEdges: IEdge[] = []; - const nodes = graph.getAllNodes() + const nodes = graph.getAllNodes(); const edges = graph.getAllEdges(); if (nodes.length === 0) { return selectedEdges; @@ -106,4 +106,4 @@ export const minimumSpanningTree = (graph: Graph, weightProps?: string, algo?: ' 'kruskal': kruskalMST, }; return (algo && algos[algo](graph, weightProps)) || kruskalMST(graph, weightProps); -} +}; diff --git a/packages/graph/src/nodes-cosine-similarity.ts b/packages/graph/src/nodes-cosine-similarity.ts index 9ef5201..9687326 100644 --- a/packages/graph/src/nodes-cosine-similarity.ts +++ b/packages/graph/src/nodes-cosine-similarity.ts @@ -21,8 +21,8 @@ export const nodesCosineSimilarity = ( allCosineSimilarity: number[], similarNodes: NodeSimilarity[], } => { - const similarNodes = clone(nodes.filter(node => node.id !== seedNode.id)); - const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id); + const similarNodes = clone(nodes.filter((node) => node.id !== seedNode.id)); + const seedNodeIndex = nodes.findIndex((node) => node.id === seedNode.id); // Collection of all node properties const properties = getAllProperties(nodes); // One-hot feature vectors for all node properties @@ -40,4 +40,4 @@ export const nodesCosineSimilarity = ( // Sort the returned nodes according to cosine similarity similarNodes.sort((a: NodeSimilarity, b: NodeSimilarity) => b.data.cosineSimilarity - a.data.cosineSimilarity); return { allCosineSimilarity, similarNodes }; -} +}; diff --git a/packages/graph/src/structs/binary-heap.ts b/packages/graph/src/structs/binary-heap.ts index bf3eb70..efb2730 100644 --- a/packages/graph/src/structs/binary-heap.ts +++ b/packages/graph/src/structs/binary-heap.ts @@ -1,26 +1,25 @@ -const defaultCompare = (a, b) => { - return a - b; -}; export default class MinBinaryHeap { - list: any[]; + list: number[]; compareFn: (a: any, b: any) => number; - constructor(compareFn = defaultCompare) { - this.compareFn = compareFn; + constructor(compareFn: (a: any, b: any) => number) { + this.compareFn = compareFn || ((a: number, b: number) => { + return a - b; + }); this.list = []; } - getLeft(index) { + getLeft(index: number) { return 2 * index + 1; } - getRight(index) { + getRight(index: number) { return 2 * index + 2; } - getParent(index) { + getParent(index: number) { if (index === 0) { return null; } @@ -45,7 +44,7 @@ export default class MinBinaryHeap { return top; } - insert(value) { + insert(value: number) { if (value !== null) { this.list.push(value); const index = this.list.length - 1; @@ -55,20 +54,21 @@ export default class MinBinaryHeap { return false; } - moveUp(index) { - let parent = this.getParent(index); - while (index && index > 0 && this.compareFn(this.list[parent], this.list[index]) > 0) { + moveUp(index: number) { + let i = index; + let parent = this.getParent(i); + while (i && i > 0 && this.compareFn(this.list[i], this.list[i]) > 0) { // swap const tmp = this.list[parent]; - this.list[parent] = this.list[index]; - this.list[index] = tmp; - // [this.list[index], this.list[parent]] = [this.list[parent], this.list[index]] - index = parent; - parent = this.getParent(index); + this.list[parent] = this.list[i]; + this.list[i] = tmp; + // [this.list[i], this.list[parent]] = [this.list[parent], this.list[i]] + i = parent; + parent = this.getParent(i); } } - moveDown(index) { + moveDown(index: number) { let element = index; const left = this.getLeft(index); const right = this.getRight(index); diff --git a/packages/graph/src/structs/union-find.ts b/packages/graph/src/structs/union-find.ts index 940ed6d..6314bb7 100644 --- a/packages/graph/src/structs/union-find.ts +++ b/packages/graph/src/structs/union-find.ts @@ -4,7 +4,7 @@ export default class UnionFind { count: number; - parent: {}; + parent: { [key: number | string]: number | string }; constructor(items: (number | string)[]) { this.count = items.length; @@ -15,14 +15,15 @@ export default class UnionFind { } // find the root of the item - find(item) { + find(item: (number | string)) { + let resItem = item; while (this.parent[item] !== item) { - item = this.parent[item]; + resItem = this.parent[item]; } - return item; + return resItem; } - union(a, b) { + union(a: (number | string), b: (number | string)) { const rootA = this.find(a); const rootB = this.find(b); @@ -39,7 +40,7 @@ export default class UnionFind { } // whether a and b are connected, i.e. a and b have the same root - connected(a, b) { + connected(a: (number | string), b: (number | string)) { return this.find(a) === this.find(b); } } diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index 3e2457b..9c842a3 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -57,8 +57,8 @@ export type GraphData = { export type INode = Node; export type IEdge = Edge; -export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[] +export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[]; export interface IMSTAlgorithmOpt { - 'prim': IMSTAlgorithm, - 'kruskal': IMSTAlgorithm, + 'prim': IMSTAlgorithm; + 'kruskal': IMSTAlgorithm; } \ No newline at end of file diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 18a0b58..3aba87e 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -86,5 +86,5 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol * @param edges 图中的所有边数据 */ export const getEdgesByNodeId = (nodeId: NodeID, edges: Edge<{ [key: string]: any }>[]): Edge<{ [key: string]: any }>[] => { - return edges.filter(edge => edge.source === nodeId || edge.target === nodeId) -} + return edges.filter((edge) => edge.source === nodeId || edge.target === nodeId); +}; From 2afef3c99d57f457436d69a41491a89c78ff3533 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Wed, 13 Sep 2023 11:23:13 +0800 Subject: [PATCH 4/7] fix: fix lint --- packages/graph/src/structs/binary-heap.ts | 3 +-- packages/graph/src/structs/union-find.ts | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/graph/src/structs/binary-heap.ts b/packages/graph/src/structs/binary-heap.ts index efb2730..22f682f 100644 --- a/packages/graph/src/structs/binary-heap.ts +++ b/packages/graph/src/structs/binary-heap.ts @@ -57,12 +57,11 @@ export default class MinBinaryHeap { moveUp(index: number) { let i = index; let parent = this.getParent(i); - while (i && i > 0 && this.compareFn(this.list[i], this.list[i]) > 0) { + while (i && i > 0 && this.compareFn(this.list[parent], this.list[i]) > 0) { // swap const tmp = this.list[parent]; this.list[parent] = this.list[i]; this.list[i] = tmp; - // [this.list[i], this.list[parent]] = [this.list[parent], this.list[i]] i = parent; parent = this.getParent(i); } diff --git a/packages/graph/src/structs/union-find.ts b/packages/graph/src/structs/union-find.ts index 6314bb7..d28b97d 100644 --- a/packages/graph/src/structs/union-find.ts +++ b/packages/graph/src/structs/union-find.ts @@ -17,8 +17,8 @@ export default class UnionFind { // find the root of the item find(item: (number | string)) { let resItem = item; - while (this.parent[item] !== item) { - resItem = this.parent[item]; + while (this.parent[resItem] !== resItem) { + resItem = this.parent[resItem]; } return resItem; } From 3cef3aa5327ae3e095f14617d2616d66c4189286 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Wed, 13 Sep 2023 22:26:37 +0800 Subject: [PATCH 5/7] fix: graph get related edge api --- packages/graph/src/mst.ts | 17 +++++++---------- packages/graph/src/structs/binary-heap.ts | 14 ++++++-------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/packages/graph/src/mst.ts b/packages/graph/src/mst.ts index 4424be0..17eae0b 100644 --- a/packages/graph/src/mst.ts +++ b/packages/graph/src/mst.ts @@ -2,7 +2,6 @@ import UnionFind from './structs/union-find'; import MinBinaryHeap from './structs/binary-heap'; import { Graph, IEdge, IMSTAlgorithm, IMSTAlgorithmOpt } from './types'; import { clone } from '@antv/util'; -import { getEdgesByNodeId } from './utils'; /** Calculates the Minimum Spanning Tree (MST) of a graph using the Prim's algorithm.The MST is a subset of edges that forms a tree connecting all nodes with the minimum possible total edge weight. @@ -21,7 +20,6 @@ const primMST: IMSTAlgorithm = (graph, weightProps?) => { const currNode = nodes[0]; const visited = new Set(); visited.add(currNode); - // Using binary heap to maintain the weight of edges from other nodes that have joined the node const compareWeight = (a: IEdge, b: IEdge) => { if (weightProps) { @@ -30,8 +28,9 @@ const primMST: IMSTAlgorithm = (graph, weightProps?) => { } return 0; }; - const edgeQueue = new MinBinaryHeap(compareWeight); - getEdgesByNodeId(currNode.id, edges).forEach((edge) => { + const edgeQueue = new MinBinaryHeap(compareWeight); + + graph.getRelatedEdges(currNode.id, 'both').forEach((edge) => { edgeQueue.insert(edge); }); while (!edgeQueue.isEmpty()) { @@ -43,13 +42,13 @@ const primMST: IMSTAlgorithm = (graph, weightProps?) => { selectedEdges.push(currEdge); if (!visited.has(source)) { visited.add(source); - getEdgesByNodeId(source, edges).forEach((edge) => { + graph.getRelatedEdges(source, 'both').forEach((edge) => { edgeQueue.insert(edge); }); } if (!visited.has(target)) { visited.add(target); - getEdgesByNodeId(target, edges).forEach((edge) => { + graph.getRelatedEdges(target, 'both').forEach((edge) => { edgeQueue.insert(edge); }); } @@ -70,7 +69,7 @@ const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => { if (nodes.length === 0) { return selectedEdges; } - // 若指定weight,则将所有的边按权值从小到大排序 + // If you specify weight, all edges are sorted by weight from smallest to largest const weightEdges = clone(edges); if (weightProps) { weightEdges.sort((a: IEdge, b: IEdge) => { @@ -78,9 +77,7 @@ const kruskalMST: IMSTAlgorithm = (graph, weightProps?) => { }); } const disjointSet = new UnionFind(nodes.map((n) => n.id)); - - // 从权值最小的边开始,如果这条边连接的两个节点于图G中不在同一个连通分量中,则添加这条边 - // 直到遍历完所有点或边 + // Starting with the edge with the least weight, if the two nodes connected by this edge are not in the same connected component in graph G, the edge is added. while (weightEdges.length > 0) { const curEdge = weightEdges.shift(); const source = curEdge.source; diff --git a/packages/graph/src/structs/binary-heap.ts b/packages/graph/src/structs/binary-heap.ts index 22f682f..4637552 100644 --- a/packages/graph/src/structs/binary-heap.ts +++ b/packages/graph/src/structs/binary-heap.ts @@ -1,13 +1,11 @@ -export default class MinBinaryHeap { - list: number[]; +export default class MinBinaryHeap { + list: T[]; - compareFn: (a: any, b: any) => number; + compareFn: (a?: T, b?: T) => number; - constructor(compareFn: (a: any, b: any) => number) { - this.compareFn = compareFn || ((a: number, b: number) => { - return a - b; - }); + constructor(compareFn: (a: T, b: T) => number) { + this.compareFn = compareFn || (() => 0); this.list = []; } @@ -44,7 +42,7 @@ export default class MinBinaryHeap { return top; } - insert(value: number) { + insert(value: T) { if (value !== null) { this.list.push(value); const index = this.list.length - 1; From 4d6991198e05fa105e6a01d2d3fedb61bcc8b0f6 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Wed, 13 Sep 2023 22:30:57 +0800 Subject: [PATCH 6/7] chore: translate chinese annotation --- packages/graph/src/structs/union-find.ts | 6 ++--- packages/graph/src/types.ts | 2 +- packages/graph/src/utils.ts | 28 ++++++++---------------- 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/packages/graph/src/structs/union-find.ts b/packages/graph/src/structs/union-find.ts index d28b97d..2758b5e 100644 --- a/packages/graph/src/structs/union-find.ts +++ b/packages/graph/src/structs/union-find.ts @@ -1,5 +1,5 @@ /** - * 并查集 Disjoint set to support quick union + * Disjoint set to support quick union */ export default class UnionFind { count: number; @@ -26,9 +26,7 @@ export default class UnionFind { union(a: (number | string), b: (number | string)) { const rootA = this.find(a); const rootB = this.find(b); - if (rootA === rootB) return; - // make the element with smaller root the parent if (rootA < rootB) { if (this.parent[b] !== b) this.union(this.parent[b], a); @@ -39,7 +37,7 @@ export default class UnionFind { } } - // whether a and b are connected, i.e. a and b have the same root + // Determine that A and B are connected connected(a: (number | string), b: (number | string)) { return this.find(a) === this.find(b); } diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index 9c842a3..d4aeeee 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -1,6 +1,6 @@ import { Edge, Graph as IGraph, Node, PlainObject } from '@antv/graphlib'; -// 数据集中属性/特征值分布的map +// Map of attribute / eigenvalue distribution in dataset export interface KeyValueMap { [key: string]: any[]; } diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 3aba87e..b041be3 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -12,17 +12,17 @@ export const getAllProperties = (nodes: Node[]) => { export const getAllKeyValueMap = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => { let keys: string[] = []; - // 指定了参与计算的keys时,使用指定的keys + // Use the specified keys when the keys participating in the calculation is specified if (involvedKeys?.length) { keys = involvedKeys; } else { - // 未指定抽取的keys时,提取数据中所有的key + // When the extracted keys is not specified, all key in the data is extracted dataList.forEach((data) => { keys = keys.concat(Object.keys(data)); }); keys = uniq(keys); } - // 获取所有值非空的key的value数组 + // Get the value array of all key with non-null values const allKeyValueMap: KeyValueMap = {}; keys.forEach((key) => { const value: unknown[] = []; @@ -40,19 +40,17 @@ export const getAllKeyValueMap = (dataList: PlainObject[], involvedKeys?: string }; export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => { - // 获取数据中所有的属性/特征及其对应的值 + // Get all attributes / features in the data and their corresponding values const allKeyValueMap = getAllKeyValueMap(dataList, involvedKeys, uninvolvedKeys); const oneHotCode: unknown[][] = []; if (!Object.keys(allKeyValueMap).length) { return oneHotCode; } - - // 获取所有的属性/特征值 + // Get all attribute / feature values const allValue = Object.values(allKeyValueMap); - // 是否所有属性/特征的值都是数值型 + // Whether the values of all attributes / features are numerical const isAllNumber = allValue.every((value) => value.every((item) => (typeof (item) === 'number'))); - - // 对数据进行one-hot编码 + // One-hot encode the data dataList.forEach((data, index) => { let code: unknown[] = []; Object.keys(allKeyValueMap).forEach((key) => { @@ -60,11 +58,11 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol const allKeyValue = allKeyValueMap[key]; const valueIndex = allKeyValue.findIndex((value) => keyValue === value); const subCode = []; - // 如果属性/特征所有的值都能转成数值型,不满足分箱,则直接用值(todo: 为了收敛更快,需做归一化处理) + // If all the values of the attribute / feature can be converted to numerical type and do not satisfy the box division, then use the value directly (todo: normalization is needed for faster convergence) if (isAllNumber) { subCode.push(keyValue); } else { - // 进行one-hot编码 + // Encode one-hot for (let i = 0; i < allKeyValue.length; i++) { if (i === valueIndex) { subCode.push(1); @@ -80,11 +78,3 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol return oneHotCode; }; -/** - * 获取指定节点的边,包括出边和入边 - * @param nodeId 节点 ID - * @param edges 图中的所有边数据 - */ -export const getEdgesByNodeId = (nodeId: NodeID, edges: Edge<{ [key: string]: any }>[]): Edge<{ [key: string]: any }>[] => { - return edges.filter((edge) => edge.source === nodeId || edge.target === nodeId); -}; From 1f2cd5c47139460db774a398ba1e2ecb2509b0ec Mon Sep 17 00:00:00 2001 From: zqqcee Date: Wed, 13 Sep 2023 22:33:09 +0800 Subject: [PATCH 7/7] chore: remove the unnecessary var --- packages/graph/src/utils.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index b041be3..4a884fb 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -1,5 +1,5 @@ -import { Edge, Node, PlainObject } from "@antv/graphlib"; -import { KeyValueMap, NodeData, NodeID } from "./types"; +import { Node, PlainObject } from "@antv/graphlib"; +import { KeyValueMap, NodeData } from "./types"; import { uniq } from "@antv/util"; export const getAllProperties = (nodes: Node[]) => {