diff --git a/__tests__/unit/louvain.spec.ts b/__tests__/unit/louvain.spec.ts index 9f9185f..639db91 100644 --- a/__tests__/unit/louvain.spec.ts +++ b/__tests__/unit/louvain.spec.ts @@ -1,38 +1,68 @@ -import { Graph } from "@antv/graphlib"; -import { louvain, iLouvain } from "../../packages/graph/src"; -import * as propertiesGraphData from "../data/cluster-origin-properties-data.json"; +import { Graph } from '@antv/graphlib'; +import { louvain, iLouvain } from '../../packages/graph/src'; +import * as propertiesGraphData from '../data/cluster-origin-properties-data.json'; describe('Louvain', () => { it('simple louvain', () => { const graph = new Graph({ nodes: [ - { id: '0', data: {} }, { id: '1', data: {} }, { id: '2', data: {} }, { id: '3', data: {} }, { id: '4', data: {} }, - { id: '5', data: {} }, { id: '6', data: {} }, { id: '7', data: {} }, { id: '8', data: {} }, { id: '9', data: {} }, - { id: '10', data: {} }, { id: '11', data: {} }, { id: '12', data: {} }, { id: '13', data: {} }, { id: '14', data: {} }, + { id: '0', data: {} }, + { id: '1', data: {} }, + { id: '2', data: {} }, + { id: '3', data: {} }, + { id: '4', data: {} }, + { id: '5', data: {} }, + { id: '6', data: {} }, + { id: '7', data: {} }, + { id: '8', data: {} }, + { id: '9', data: {} }, + { id: '10', data: {} }, + { id: '11', data: {} }, + { id: '12', data: {} }, + { id: '13', data: {} }, + { id: '14', data: {} }, ], edges: [ - { id: 'e1', source: '0', target: '1', data: {} }, { id: 'e2', source: '0', target: '2', data: {} }, { id: 'e3', source: '0', target: '3', data: {} }, { id: 'e4', source: '0', target: '4', data: {} }, - { id: 'e5', source: '1', target: '2', data: {} }, { id: 'e6', source: '1', target: '3', data: {} }, { id: 'e7', source: '1', target: '4', data: {} }, - { id: 'e8', source: '2', target: '3', data: {} }, { id: 'e9', source: '2', target: '4', data: {} }, + { id: 'e1', source: '0', target: '1', data: {} }, + { id: 'e2', source: '0', target: '2', data: {} }, + { id: 'e3', source: '0', target: '3', data: {} }, + { id: 'e4', source: '0', target: '4', data: {} }, + { id: 'e5', source: '1', target: '2', data: {} }, + { id: 'e6', source: '1', target: '3', data: {} }, + { id: 'e7', source: '1', target: '4', data: {} }, + { id: 'e8', source: '2', target: '3', data: {} }, + { id: 'e9', source: '2', target: '4', data: {} }, { id: 'e10', source: '3', target: '4', data: {} }, { id: 'e11', source: '0', target: '0', data: {} }, { id: 'e12', source: '0', target: '0', data: {} }, { id: 'e13', source: '0', target: '0', data: {} }, - - { id: 'e14', source: '5', target: '6', data: {weight: 5} }, { id: 'e15', source: '5', target: '7', data: {} }, { id: 'e16', source: '5', target: '8', data: {} }, { id: 'e17', source: '5', target: '9', data: {} }, - { id: 'e18', source: '6', target: '7', data: {} }, { id: 'e19', source: '6', target: '8', data: {} }, { id: 'e20', source: '6', target: '9', data: {} }, - { id: 'e21', source: '7', target: '8', data: {} }, { id: 'e22', source: '7', target: '9', data: {} }, - { id: 'e23',source: '8', target: '9', data: {} }, - - { id: 'e24',source: '10', target: '11', data: {} }, { id: 'e25',source: '10', target: '12', data: {} }, { id: 'e26',source: '10', target: '13', data: {} }, { id: 'e27',source: '10', target: '14', data: {} }, - { id: 'e28',source: '11', target: '12', data: {} }, { id: 'e29',source: '11', target: '13', data: {} }, { id: 'e30',source: '11', target: '14', data: {} }, - { id: 'e31',source: '12', target: '13', data: {} }, { id: 'e32',source: '12', target: '14', data: {} }, - { id: 'e33',source: '13', target: '14', data: { weight: 5 } }, - - { id: 'e34',source: '0', target: '5', data: {}}, - { id: 'e35',source: '5', target: '10', data: {} }, - { id: 'e36',source: '10', target: '0', data: {} }, - { id: 'e37',source: '10', target: '0', data: {} }, + + { id: 'e14', source: '5', target: '6', data: { weight: 5 } }, + { id: 'e15', source: '5', target: '7', data: {} }, + { id: 'e16', source: '5', target: '8', data: {} }, + { id: 'e17', source: '5', target: '9', data: {} }, + { id: 'e18', source: '6', target: '7', data: {} }, + { id: 'e19', source: '6', target: '8', data: {} }, + { id: 'e20', source: '6', target: '9', data: {} }, + { id: 'e21', source: '7', target: '8', data: {} }, + { id: 'e22', source: '7', target: '9', data: {} }, + { id: 'e23', source: '8', target: '9', data: {} }, + + { id: 'e24', source: '10', target: '11', data: {} }, + { id: 'e25', source: '10', target: '12', data: {} }, + { id: 'e26', source: '10', target: '13', data: {} }, + { id: 'e27', source: '10', target: '14', data: {} }, + { id: 'e28', source: '11', target: '12', data: {} }, + { id: 'e29', source: '11', target: '13', data: {} }, + { id: 'e30', source: '11', target: '14', data: {} }, + { id: 'e31', source: '12', target: '13', data: {} }, + { id: 'e32', source: '12', target: '14', data: {} }, + { id: 'e33', source: '13', target: '14', data: { weight: 5 } }, + + { id: 'e34', source: '0', target: '5', data: {} }, + { id: 'e35', source: '5', target: '10', data: {} }, + { id: 'e36', source: '10', target: '0', data: {} }, + { id: 'e37', source: '10', target: '0', data: {} }, ], }); const clusteredData = louvain(graph, false, 'weight'); @@ -64,4 +94,4 @@ describe('Louvain', () => { expect(clusteredData.clusters[2].sumTot).toBe(4); expect(clusteredData.clusterEdges.length).toBe(7); }); -}); \ No newline at end of file +}); diff --git a/__tests__/utils/data.ts b/__tests__/utils/data.ts index 246c69a..5d33b89 100644 --- a/__tests__/utils/data.ts +++ b/__tests__/utils/data.ts @@ -1,19 +1,23 @@ -import { NodeID, INode, IEdge } from "../../packages/graph/src/types"; +import { ID } from '@antv/graphlib'; +import { INode, IEdge } from '../../packages/graph/src/types'; /** * Convert the old version of the data format to the new version * @param data old data * @return {{nodes:INode[],edges:IEdge[]}} new data */ -export const dataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => { - const { nodes, edges } = data; - return { - nodes: nodes.map((n) => { - const { id, ...rest } = n; - return { id, data: rest ? rest : {} }; - }), - edges: edges.map((e, i) => { - const { id, source, target, ...rest } = e; - return { id: id ? id : `edge-${i}`, target, source, data: rest }; - }), - }; +export const dataTransformer = (data: { + nodes: { id: ID; [key: string]: any }[]; + edges: { source: ID; target: ID; [key: string]: any }[]; +}): { nodes: INode[]; edges: IEdge[] } => { + const { nodes, edges } = data; + return { + nodes: nodes.map((n) => { + const { id, ...rest } = n; + return { id, data: rest ? rest : {} }; + }), + edges: edges.map((e, i) => { + const { id, source, target, ...rest } = e; + return { id: id ? id : `edge-${i}`, target, source, data: rest }; + }), + }; }; diff --git a/packages/graph/src/bfs.ts b/packages/graph/src/bfs.ts index 209b8e1..94e8931 100644 --- a/packages/graph/src/bfs.ts +++ b/packages/graph/src/bfs.ts @@ -1,5 +1,6 @@ +import { ID } from '@antv/graphlib'; import Queue from './structs/queue'; -import { Graph, IAlgorithmCallbacks, NodeID } from './types'; +import { Graph, IAlgorithmCallbacks } from './types'; /** * @param startNodeId The ID of the bfs traverse starting node. @@ -8,11 +9,14 @@ import { Graph, IAlgorithmCallbacks, NodeID } from './types'; - enterNode: Called when BFS visits a node. - leaveNode: Called after BFS visits the node. */ -function initCallbacks(callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks) { +function initCallbacks( + callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks +) { const initiatedCallback = callbacks; - const stubCallback = () => { }; + const stubCallback = () => {}; const allowTraversalCallback = () => true; - initiatedCallback.allowTraversal = callbacks.allowTraversal || allowTraversalCallback; + initiatedCallback.allowTraversal = + callbacks.allowTraversal || allowTraversalCallback; initiatedCallback.enter = callbacks.enter || stubCallback; initiatedCallback.leave = callbacks.leave || stubCallback; return initiatedCallback; @@ -26,19 +30,19 @@ Performs breadth-first search (BFS) traversal on a graph. */ export const breadthFirstSearch = ( graph: Graph, - startNodeId: NodeID, - originalCallbacks?: IAlgorithmCallbacks, + startNodeId: ID, + originalCallbacks?: IAlgorithmCallbacks ) => { - const visit = new Set(); + const visit = new Set(); const callbacks = initCallbacks(originalCallbacks); - const nodeQueue = new Queue(); + const nodeQueue = new Queue(); // init Queue. Enqueue node ID. nodeQueue.enqueue(startNodeId); visit.add(startNodeId); - let previousNodeId: NodeID = ''; + let previousNodeId: ID = ''; // 遍历队列中的所有顶点 while (!nodeQueue.isEmpty()) { - const currentNodeId: NodeID = nodeQueue.dequeue(); + const currentNodeId: ID = nodeQueue.dequeue(); callbacks.enter({ current: currentNodeId, previous: previousNodeId, @@ -52,7 +56,8 @@ export const breadthFirstSearch = ( previous: previousNodeId, current: currentNodeId, next: nextNodeId, - }) && !visit.has(nextNodeId) + }) && + !visit.has(nextNodeId) ) { visit.add(nextNodeId); nodeQueue.enqueue(nextNodeId); diff --git a/packages/graph/src/connected-component.ts b/packages/graph/src/connected-component.ts index ce00e0f..32476d4 100644 --- a/packages/graph/src/connected-component.ts +++ b/packages/graph/src/connected-component.ts @@ -1,4 +1,5 @@ -import { Graph, INode, NodeID } from './types'; +import { ID } from '@antv/graphlib'; +import { Graph, INode } from './types'; /** * Generate all connected components for an undirected graph * @param graph @@ -6,7 +7,7 @@ import { Graph, INode, NodeID } from './types'; export const detectConnectedComponents = (graph: Graph): INode[][] => { const nodes = graph.getAllNodes(); const allComponents: INode[][] = []; - const visited: { [key: NodeID]: boolean } = {}; + const visited: { [key: ID]: boolean } = {}; const nodeStack: INode[] = []; const getComponent = (node: INode) => { nodeStack.push(node); @@ -49,9 +50,9 @@ export const detectStrongConnectComponents = (graph: Graph): INode[][] => { const nodes = graph.getAllNodes(); const nodeStack: INode[] = []; // Assist to determine whether it is already in the stack to reduce the search overhead - const inStack: { [key: NodeID]: boolean } = {}; - const indices: { [key: NodeID]: number } = {}; - const lowLink: { [key: NodeID]: number } = {}; + const inStack: { [key: ID]: boolean } = {}; + const indices: { [key: ID]: number } = {}; + const lowLink: { [key: ID]: number } = {}; const allComponents: INode[][] = []; let index = 0; const getComponent = (node: INode) => { @@ -61,7 +62,7 @@ export const detectStrongConnectComponents = (graph: Graph): INode[][] => { index += 1; nodeStack.push(node); inStack[node.id] = true; - const relatedEdges = graph.getRelatedEdges(node.id, "out"); + const relatedEdges = graph.getRelatedEdges(node.id, 'out'); for (let i = 0; i < relatedEdges.length; i++) { const targetNodeID = relatedEdges[i].target; if (!indices[targetNodeID] && indices[targetNodeID] !== 0) { @@ -98,7 +99,10 @@ export const detectStrongConnectComponents = (graph: Graph): INode[][] => { return allComponents; }; -export function getConnectedComponents(graph: Graph, directed?: boolean): INode[][] { +export function getConnectedComponents( + graph: Graph, + directed?: boolean +): INode[][] { if (directed) return detectStrongConnectComponents(graph); return detectConnectedComponents(graph); } diff --git a/packages/graph/src/detect-cycle.ts b/packages/graph/src/detect-cycle.ts index 0af03f8..dbd3f59 100644 --- a/packages/graph/src/detect-cycle.ts +++ b/packages/graph/src/detect-cycle.ts @@ -1,7 +1,10 @@ +import { ID, Node } from '@antv/graphlib'; import { depthFirstSearch } from './dfs'; -import { getConnectedComponents, detectStrongConnectComponents } from './connected-component'; -import { Graph, IAlgorithmCallbacks, INode, NodeData, NodeID } from './types'; -import { Node } from '@antv/graphlib'; +import { + getConnectedComponents, + detectStrongConnectComponents, +} from './connected-component'; +import { Graph, IAlgorithmCallbacks, INode, NodeData } from './types'; /** * Detects a directed cycle in a graph. @@ -9,58 +12,60 @@ import { Node } from '@antv/graphlib'; * @param graph The graph to detect the directed cycle in. * @returns An object representing the detected directed cycle, where each key-value pair represents a node ID and its parent node ID in the cycle. */ -export const detectDirectedCycle = (graph: Graph): { - [key: NodeID]: NodeID; +export const detectDirectedCycle = ( + graph: Graph +): { + [key: ID]: ID; } => { - let cycle: { - [key: NodeID]: NodeID; - } = null; - const nodes = graph.getAllNodes(); - const dfsParentMap: { [key: NodeID]: NodeID } = {}; - // The set of all nodes that are not being accessed - const unvisitedSet: { [key: NodeID]: Node } = {}; - // The set of nodes being accessed - const visitingSet: { [key: NodeID]: NodeID } = {}; - // The set of all nodes that have been accessed - const visitedSet: { [key: NodeID]: NodeID } = {}; - // init unvisitedSet - nodes.forEach((node) => { - unvisitedSet[node.id] = node; - }); - const callbacks: IAlgorithmCallbacks = { - enter: ({ current: currentNodeId, previous: previousNodeId }) => { - if (visitingSet[currentNodeId]) { - // 如果当前节点正在访问中,则说明检测到环路了 - cycle = {}; - let currentCycleNodeId = currentNodeId; - let previousCycleNodeId = previousNodeId; - while (previousCycleNodeId !== currentNodeId) { - cycle[currentCycleNodeId] = previousCycleNodeId; - currentCycleNodeId = previousCycleNodeId; - previousCycleNodeId = dfsParentMap[previousCycleNodeId]; - } - cycle[currentCycleNodeId] = previousCycleNodeId; - } else { - visitingSet[currentNodeId] = currentNodeId; - delete unvisitedSet[currentNodeId]; - dfsParentMap[currentNodeId] = previousNodeId; - } - }, - leave: ({ current: currentNodeId }) => { - visitedSet[currentNodeId] = currentNodeId; - delete visitingSet[currentNodeId]; - }, - allowTraversal: () => { - if (cycle) { - return false; - } - return true; - }, - }; - for (let key of Object.keys(unvisitedSet)) { - depthFirstSearch(graph, key, callbacks, true, false); - } - return cycle; + let cycle: { + [key: ID]: ID; + } = null; + const nodes = graph.getAllNodes(); + const dfsParentMap: { [key: ID]: ID } = {}; + // The set of all nodes that are not being accessed + const unvisitedSet: { [key: ID]: Node } = {}; + // The set of nodes being accessed + const visitingSet: { [key: ID]: ID } = {}; + // The set of all nodes that have been accessed + const visitedSet: { [key: ID]: ID } = {}; + // init unvisitedSet + nodes.forEach((node) => { + unvisitedSet[node.id] = node; + }); + const callbacks: IAlgorithmCallbacks = { + enter: ({ current: currentNodeId, previous: previousNodeId }) => { + if (visitingSet[currentNodeId]) { + // 如果当前节点正在访问中,则说明检测到环路了 + cycle = {}; + let currentCycleNodeId = currentNodeId; + let previousCycleNodeId = previousNodeId; + while (previousCycleNodeId !== currentNodeId) { + cycle[currentCycleNodeId] = previousCycleNodeId; + currentCycleNodeId = previousCycleNodeId; + previousCycleNodeId = dfsParentMap[previousCycleNodeId]; + } + cycle[currentCycleNodeId] = previousCycleNodeId; + } else { + visitingSet[currentNodeId] = currentNodeId; + delete unvisitedSet[currentNodeId]; + dfsParentMap[currentNodeId] = previousNodeId; + } + }, + leave: ({ current: currentNodeId }) => { + visitedSet[currentNodeId] = currentNodeId; + delete visitingSet[currentNodeId]; + }, + allowTraversal: () => { + if (cycle) { + return false; + } + return true; + }, + }; + for (let key of Object.keys(unvisitedSet)) { + depthFirstSearch(graph, key, callbacks, true, false); + } + return cycle; }; /** @@ -70,71 +75,81 @@ export const detectDirectedCycle = (graph: Graph): { * @param include Specifies whether the filtered cycles should be included (true) or excluded (false). * @returns An array of objects representing the detected cycles in the graph. */ -export const detectAllUndirectedCycle = (graph: Graph, nodeIds?: NodeID[], include = true) => { - const allCycles: { [key: NodeID]: INode }[] = []; - const components = getConnectedComponents(graph, false); - // loop through all connected components - for (const component of components) { - if (!component.length) continue; - const root = component[0]; - const rootId = root.id; - const stack = [root]; - const parent = { [rootId]: root }; - const used = { [rootId]: new Set() }; - // walk a spanning tree to find cycles - while (stack.length > 0) { - const curNode = stack.pop(); - const curNodeId = curNode.id; - const neighbors = graph.getNeighbors(curNodeId); - // const neighbors = getNeighbors(curNodeId, graphData.edges); - for (let i = 0; i < neighbors.length; i += 1) { - const neighborId = neighbors[i].id; - const neighbor = graph.getAllNodes().find(node => node.id === neighborId); - if (neighborId === curNodeId) { - allCycles.push({ [neighborId]: curNode }); - } else if (!(neighborId in used)) { - // visit a new node - parent[neighborId] = curNode; - stack.push(neighbor); - used[neighborId] = new Set([curNode]); - } else if (!used[curNodeId].has(neighbor)) { - // a cycle found - let cycleValid = true; - const cyclePath = [neighbor, curNode]; - let p = parent[curNodeId]; - while (used[neighborId].size && !used[neighborId].has(p)) { - cyclePath.push(p); - if (p === parent[p.id]) break; - else p = parent[p.id]; - } - cyclePath.push(p); - if (nodeIds && include) { - cycleValid = false; - if (cyclePath.findIndex((node) => nodeIds.indexOf(node.id) > -1) > -1) { - cycleValid = true; - } - } else if (nodeIds && !include) { - if (cyclePath.findIndex((node) => nodeIds.indexOf(node.id) > -1) > -1) { - cycleValid = false; - } - } - // Format node list to cycle - if (cycleValid) { - const cycle: { [key: NodeID]: INode } = {}; - for (let index = 1; index < cyclePath.length; index += 1) { - cycle[cyclePath[index - 1].id] = cyclePath[index]; - } - if (cyclePath.length) { - cycle[cyclePath[cyclePath.length - 1].id] = cyclePath[0]; - } - allCycles.push(cycle); - } - used[neighborId].add(curNode); - } +export const detectAllUndirectedCycle = ( + graph: Graph, + nodeIds?: ID[], + include = true +) => { + const allCycles: { [key: ID]: INode }[] = []; + const components = getConnectedComponents(graph, false); + // loop through all connected components + for (const component of components) { + if (!component.length) continue; + const root = component[0]; + const rootId = root.id; + const stack = [root]; + const parent = { [rootId]: root }; + const used = { [rootId]: new Set() }; + // walk a spanning tree to find cycles + while (stack.length > 0) { + const curNode = stack.pop(); + const curNodeId = curNode.id; + const neighbors = graph.getNeighbors(curNodeId); + // const neighbors = getNeighbors(curNodeId, graphData.edges); + for (let i = 0; i < neighbors.length; i += 1) { + const neighborId = neighbors[i].id; + const neighbor = graph + .getAllNodes() + .find((node) => node.id === neighborId); + if (neighborId === curNodeId) { + allCycles.push({ [neighborId]: curNode }); + } else if (!(neighborId in used)) { + // visit a new node + parent[neighborId] = curNode; + stack.push(neighbor); + used[neighborId] = new Set([curNode]); + } else if (!used[curNodeId].has(neighbor)) { + // a cycle found + let cycleValid = true; + const cyclePath = [neighbor, curNode]; + let p = parent[curNodeId]; + while (used[neighborId].size && !used[neighborId].has(p)) { + cyclePath.push(p); + if (p === parent[p.id]) break; + else p = parent[p.id]; + } + cyclePath.push(p); + if (nodeIds && include) { + cycleValid = false; + if ( + cyclePath.findIndex((node) => nodeIds.indexOf(node.id) > -1) > -1 + ) { + cycleValid = true; + } + } else if (nodeIds && !include) { + if ( + cyclePath.findIndex((node) => nodeIds.indexOf(node.id) > -1) > -1 + ) { + cycleValid = false; + } + } + // Format node list to cycle + if (cycleValid) { + const cycle: { [key: ID]: INode } = {}; + for (let index = 1; index < cyclePath.length; index += 1) { + cycle[cyclePath[index - 1].id] = cyclePath[index]; } + if (cyclePath.length) { + cycle[cyclePath[cyclePath.length - 1].id] = cyclePath[0]; + } + allCycles.push(cycle); + } + used[neighborId].add(curNode); } + } } - return allCycles; + } + return allCycles; }; /** @@ -144,147 +159,163 @@ export const detectAllUndirectedCycle = (graph: Graph, nodeIds?: NodeID[], inclu * @param include Specifies whether the filtered cycles should be included (true) or excluded (false). * @returns An array of objects representing the detected cycles in the graph. */ -export const detectAllDirectedCycle = (graph: Graph, nodeIds?: NodeID[], include = true) => { - const path: INode[] = []; // stack of nodes in current pate - const blocked = new Set(); - const B: { [key: NodeID]: Set } = {}; // remember portions of the graph that yield no elementary circuit - const allCycles: { [key: NodeID]: INode }[] = []; - const idx2Node: { - [key: number]: INode; - } = {}; - const node2Idx: { [key: NodeID]: number } = {}; - // unblock all blocked nodes - const unblock = (thisNode: INode) => { - const stack = [thisNode]; - while (stack.length > 0) { - const node = stack.pop(); - if (blocked.has(node)) { - blocked.delete(node); - B[node.id].forEach((n) => { - stack.push(n); - }); - B[node.id].clear(); - } - } - }; +export const detectAllDirectedCycle = ( + graph: Graph, + nodeIds?: ID[], + include = true +) => { + const path: INode[] = []; // stack of nodes in current pate + const blocked = new Set(); + const B: { [key: ID]: Set } = {}; // remember portions of the graph that yield no elementary circuit + const allCycles: { [key: ID]: INode }[] = []; + const idx2Node: { + [key: number]: INode; + } = {}; + const node2Idx: { [key: ID]: number } = {}; + // unblock all blocked nodes + const unblock = (thisNode: INode) => { + const stack = [thisNode]; + while (stack.length > 0) { + const node = stack.pop(); + if (blocked.has(node)) { + blocked.delete(node); + B[node.id].forEach((n) => { + stack.push(n); + }); + B[node.id].clear(); + } + } + }; - const circuit = (node: INode, start: INode, adjList: { [key: NodeID]: number[] }) => { - let closed = false; // whether a path is closed - if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1) return closed; - path.push(node); - blocked.add(node); - const neighbors = adjList[node.id]; - for (let i = 0; i < neighbors.length; i += 1) { - const neighbor = idx2Node[neighbors[i]]; - if (neighbor === start) { - const cycle: { [key: NodeID]: INode } = {}; - for (let index = 1; index < path.length; index += 1) { - cycle[path[index - 1].id] = path[index]; - } - if (path.length) { - cycle[path[path.length - 1].id] = path[0]; - } - allCycles.push(cycle); - closed = true; - } else if (!blocked.has(neighbor)) { - if (circuit(neighbor, start, adjList)) { - closed = true; - } - } + const circuit = ( + node: INode, + start: INode, + adjList: { [key: ID]: number[] } + ) => { + let closed = false; // whether a path is closed + if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1) + return closed; + path.push(node); + blocked.add(node); + const neighbors = adjList[node.id]; + for (let i = 0; i < neighbors.length; i += 1) { + const neighbor = idx2Node[neighbors[i]]; + if (neighbor === start) { + const cycle: { [key: ID]: INode } = {}; + for (let index = 1; index < path.length; index += 1) { + cycle[path[index - 1].id] = path[index]; } - if (closed) { - unblock(node); - } else { - for (let i = 0; i < neighbors.length; i += 1) { - const neighbor = idx2Node[neighbors[i]]; - if (!B[neighbor.id].has(node)) { - B[neighbor.id].add(node); - } - } + if (path.length) { + cycle[path[path.length - 1].id] = path[0]; } - path.pop(); - return closed; - }; - - const nodes = graph.getAllNodes(); - - // Johnson's algorithm, sort nodes - for (let i = 0; i < nodes.length; i += 1) { - const node = nodes[i]; - const nodeId = node.id; - node2Idx[nodeId] = i; - idx2Node[i] = node; + allCycles.push(cycle); + closed = true; + } else if (!blocked.has(neighbor)) { + if (circuit(neighbor, start, adjList)) { + closed = true; + } + } } - // If there are specified included nodes, the specified nodes are sorted first in order to end the search early - if (nodeIds && include) { - for (let i = 0; i < nodeIds.length; i++) { - const nodeId = nodeIds[i]; - node2Idx[nodes[i].id] = node2Idx[nodeId]; - node2Idx[nodeId] = 0; - idx2Node[0] = nodes.find(node => node.id === nodeId); - idx2Node[node2Idx[nodes[i].id]] = nodes[i]; + if (closed) { + unblock(node); + } else { + for (let i = 0; i < neighbors.length; i += 1) { + const neighbor = idx2Node[neighbors[i]]; + if (!B[neighbor.id].has(node)) { + B[neighbor.id].add(node); } + } } + path.pop(); + return closed; + }; - // Returns the adjList of the strongly connected component of the node (order > = nodeOrder) - const getMinComponentAdj = (components: INode[][]) => { - let minCompIdx; - let minIdx = Infinity; - // Find least component and the lowest node - for (let i = 0; i < components.length; i += 1) { - const comp = components[i]; - for (let j = 0; j < comp.length; j++) { - const nodeIdx = node2Idx[comp[j].id]; - if (nodeIdx < minIdx) { - minIdx = nodeIdx; - minCompIdx = i; - } - } + const nodes = graph.getAllNodes(); + + // Johnson's algorithm, sort nodes + for (let i = 0; i < nodes.length; i += 1) { + const node = nodes[i]; + const nodeId = node.id; + node2Idx[nodeId] = i; + idx2Node[i] = node; + } + // If there are specified included nodes, the specified nodes are sorted first in order to end the search early + if (nodeIds && include) { + for (let i = 0; i < nodeIds.length; i++) { + const nodeId = nodeIds[i]; + node2Idx[nodes[i].id] = node2Idx[nodeId]; + node2Idx[nodeId] = 0; + idx2Node[0] = nodes.find((node) => node.id === nodeId); + idx2Node[node2Idx[nodes[i].id]] = nodes[i]; + } + } + + // Returns the adjList of the strongly connected component of the node (order > = nodeOrder) + const getMinComponentAdj = (components: INode[][]) => { + let minCompIdx; + let minIdx = Infinity; + // Find least component and the lowest node + for (let i = 0; i < components.length; i += 1) { + const comp = components[i]; + for (let j = 0; j < comp.length; j++) { + const nodeIdx = node2Idx[comp[j].id]; + if (nodeIdx < minIdx) { + minIdx = nodeIdx; + minCompIdx = i; } - const component = components[minCompIdx]; - const adjList: { [key: NodeID]: number[] } = {}; - for (let i = 0; i < component.length; i += 1) { - const node = component[i]; - adjList[node.id] = []; - for (const neighbor of graph.getRelatedEdges(node.id, "out").map(n => n.target).filter((n) => component.map(c => c.id).indexOf(n) > -1)) { - // 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList - if (neighbor === node.id && !(include === false && nodeIds.indexOf(node.id) > -1)) { - allCycles.push({ [node.id]: node }); - } else { - adjList[node.id].push(node2Idx[neighbor]); - } - } + } + } + const component = components[minCompIdx]; + const adjList: { [key: ID]: number[] } = {}; + for (let i = 0; i < component.length; i += 1) { + const node = component[i]; + adjList[node.id] = []; + for (const neighbor of graph + .getRelatedEdges(node.id, 'out') + .map((n) => n.target) + .filter((n) => component.map((c) => c.id).indexOf(n) > -1)) { + // 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList + if ( + neighbor === node.id && + !(include === false && nodeIds.indexOf(node.id) > -1) + ) { + allCycles.push({ [node.id]: node }); + } else { + adjList[node.id].push(node2Idx[neighbor]); } - return { - component, - adjList, - minIdx, - }; + } + } + return { + component, + adjList, + minIdx, }; + }; - let nodeIdx = 0; - while (nodeIdx < nodes.length) { - const sccs = detectStrongConnectComponents(graph).filter( - (component) => component.length > 1, - ); - if (sccs.length === 0) break; - const scc = getMinComponentAdj(sccs); - const { minIdx, adjList, component } = scc; - if (component.length > 1) { - component.forEach((node) => { - B[node.id] = new Set(); - }); - const startNode = idx2Node[minIdx]; - // StartNode is not in the specified node to include. End the search ahead of time. - if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1) return allCycles; - circuit(startNode, startNode, adjList); - nodeIdx = minIdx + 1; - } else { - break; - } - break; + let nodeIdx = 0; + while (nodeIdx < nodes.length) { + const sccs = detectStrongConnectComponents(graph).filter( + (component) => component.length > 1 + ); + if (sccs.length === 0) break; + const scc = getMinComponentAdj(sccs); + const { minIdx, adjList, component } = scc; + if (component.length > 1) { + component.forEach((node) => { + B[node.id] = new Set(); + }); + const startNode = idx2Node[minIdx]; + // StartNode is not in the specified node to include. End the search ahead of time. + if (nodeIds && include && nodeIds.indexOf(startNode.id) === -1) + return allCycles; + circuit(startNode, startNode, adjList); + nodeIdx = minIdx + 1; + } else { + break; } - return allCycles; + break; + } + return allCycles; }; /** @@ -296,12 +327,11 @@ export const detectAllDirectedCycle = (graph: Graph, nodeIds?: NodeID[], include * @returns An array of objects representing the detected cycles in the graph. */ export const detectAllCycles = ( - graph: Graph, - directed?: boolean, - nodeIds?: string[], - include = true, + graph: Graph, + directed?: boolean, + nodeIds?: string[], + include = true ) => { - if (directed) return detectAllDirectedCycle(graph, nodeIds, include); - return detectAllUndirectedCycle(graph, nodeIds, include); + if (directed) return detectAllDirectedCycle(graph, nodeIds, include); + return detectAllUndirectedCycle(graph, nodeIds, include); }; - diff --git a/packages/graph/src/dfs.ts b/packages/graph/src/dfs.ts index 723f44c..a53896b 100644 --- a/packages/graph/src/dfs.ts +++ b/packages/graph/src/dfs.ts @@ -1,15 +1,19 @@ -import { Graph, IAlgorithmCallbacks, NodeID } from './types'; +import { ID } from '@antv/graphlib'; +import { Graph, IAlgorithmCallbacks } from './types'; /** * Initializes the callback functions for the depth-first search algorithm. * @param callbacks (Optional) The original callbacks object containing custom callback functions. * @returns The initialized callbacks object. */ -function initCallbacks(callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks) { +function initCallbacks( + callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks +) { const initiatedCallback = callbacks; - const stubCallback = () => { }; + const stubCallback = () => {}; const allowTraversalCallback = () => true; - initiatedCallback.allowTraversal = callbacks.allowTraversal || allowTraversalCallback; + initiatedCallback.allowTraversal = + callbacks.allowTraversal || allowTraversalCallback; initiatedCallback.enter = callbacks.enter || stubCallback; initiatedCallback.leave = callbacks.leave || stubCallback; return initiatedCallback; @@ -27,42 +31,48 @@ function initCallbacks(callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallback */ function depthFirstSearchRecursive( graph: Graph, - currentNodeId: NodeID, - previousNodeId: NodeID, + currentNodeId: ID, + previousNodeId: ID, callbacks: IAlgorithmCallbacks, - visit: Set, + visit: Set, directed: boolean, - visitOnce: boolean, + visitOnce: boolean ) { callbacks.enter({ current: currentNodeId, previous: previousNodeId, }); const neighbors = directed - ? - graph.getRelatedEdges(currentNodeId, "out").map(e => graph.getNode(e.target)) - : - graph.getNeighbors(currentNodeId) - ; + ? graph + .getRelatedEdges(currentNodeId, 'out') + .map((e) => graph.getNode(e.target)) + : graph.getNeighbors(currentNodeId); neighbors.forEach((nextNode) => { const nextNodeId = nextNode.id; // `Visit` is not considered when judging recursive conditions if ( - visitOnce ? - (callbacks.allowTraversal({ - previous: previousNodeId, - current: currentNodeId, - next: nextNodeId, - }) && !visit.has(nextNodeId)) - : - callbacks.allowTraversal({ - previous: previousNodeId, - current: currentNodeId, - next: nextNodeId, - }) + visitOnce + ? callbacks.allowTraversal({ + previous: previousNodeId, + current: currentNodeId, + next: nextNodeId, + }) && !visit.has(nextNodeId) + : callbacks.allowTraversal({ + previous: previousNodeId, + current: currentNodeId, + next: nextNodeId, + }) ) { visit.add(nextNodeId); - depthFirstSearchRecursive(graph, nextNodeId, currentNodeId, callbacks, visit, directed, visitOnce); + depthFirstSearchRecursive( + graph, + nextNodeId, + currentNodeId, + callbacks, + visit, + directed, + visitOnce + ); } }); callbacks.leave({ @@ -81,12 +91,20 @@ function depthFirstSearchRecursive( */ export function depthFirstSearch( graph: Graph, - startNodeId: NodeID, + startNodeId: ID, originalCallbacks?: IAlgorithmCallbacks, directed: boolean = false, visitOnce: boolean = true ) { - const visit = new Set(); + const visit = new Set(); visit.add(startNodeId); - depthFirstSearchRecursive(graph, startNodeId, '', initCallbacks(originalCallbacks), visit, directed, visitOnce); + depthFirstSearchRecursive( + graph, + startNodeId, + '', + initCallbacks(originalCallbacks), + visit, + directed, + visitOnce + ); } diff --git a/packages/graph/src/louvain.ts b/packages/graph/src/louvain.ts index f37ba6d..705bbb4 100644 --- a/packages/graph/src/louvain.ts +++ b/packages/graph/src/louvain.ts @@ -1,9 +1,9 @@ -import { ID, Node } from "@antv/graphlib"; -import { clone } from "@antv/util"; -import { Cluster, ClusterData, ClusterMap, Graph, NodeData } from "./types"; -import { getAllProperties, oneHot } from "./utils"; -import { graph2AdjacencyMatrix } from "./adjMatrix"; -import { Vector } from "./vector"; +import { ID, Node } from '@antv/graphlib'; +import { clone } from '@antv/util'; +import { Cluster, ClusterData, ClusterMap, Graph, NodeData } from './types'; +import { getAllProperties, oneHot } from './utils'; +import { graph2AdjacencyMatrix } from './adjMatrix'; +import { Vector } from './vector'; /** * The quality of the communities referred as partitions hereafter is measured by Modularity of the partition. @@ -13,23 +13,24 @@ function getModularity( nodes: Node[], adjMatrix: number[][], ks: number[], - m: number + m: number, + nodeToCluster: Map ) { const length = adjMatrix.length; const param = 2 * m; // number if links let modularity = 0; for (let i = 0; i < length; i++) { - const clusteri = nodes[i].data.clusterId as string; + const clusteri = nodeToCluster.get(nodes[i].id); for (let j = 0; j < length; j++) { - const clusterj = nodes[j].data.clusterId as string; + const clusterj = nodeToCluster.get(nodes[j].id); if (clusteri !== clusterj) continue; // 1 if x = y and 0 otherwise const entry = adjMatrix[i][j] || 0; // Aij: the weightof the edge between i & j const ki = ks[i] || 0; // Ki: degree of the node const kj = ks[j] || 0; - modularity += (entry - ki * kj / param); + modularity += entry - (ki * kj) / param; } } - modularity *= (1 / param); + modularity *= 1 / param; return modularity; } @@ -37,6 +38,7 @@ function getModularity( function getInertialModularity( nodes: Node[] = [], allPropertiesWeight: number[][], + nodeToCluster: Map ) { const length = nodes.length; let totalProperties = new Vector([]); @@ -51,7 +53,8 @@ function getInertialModularity( let variance: number = 0; for (let i = 0; i < length; i++) { const propertiesi = new Vector(allPropertiesWeight[i]); - const squareEuclideanDistance = propertiesi.squareEuclideanDistance(avgProperties); + const squareEuclideanDistance = + propertiesi.squareEuclideanDistance(avgProperties); variance += squareEuclideanDistance; } @@ -60,17 +63,21 @@ function getInertialModularity( nodes.forEach(() => { squareEuclideanDistanceInfo.push([]); }); + const clusterInertialMap = new Map(); for (let i = 0; i < length; i++) { const propertiesi = new Vector(allPropertiesWeight[i]); - nodes[i].data['clusterInertial'] = 0; + clusterInertialMap.set(nodes[i].id, 0); for (let j = 0; j < length; j++) { - if ( i === j) { + if (i === j) { squareEuclideanDistanceInfo[i][j] = 0; continue; } const propertiesj = new Vector(allPropertiesWeight[j]); - squareEuclideanDistanceInfo[i][j] = propertiesi.squareEuclideanDistance(propertiesj); - (nodes[i].data['clusterInertial'] as number) += squareEuclideanDistanceInfo[i][j]; + squareEuclideanDistanceInfo[i][j] = + propertiesi.squareEuclideanDistance(propertiesj); + let clusterInertial = clusterInertialMap.get(nodes[i].id); + clusterInertial += squareEuclideanDistanceInfo[i][j]; + clusterInertialMap.set(nodes[i].id, clusterInertial); } } @@ -78,12 +85,15 @@ function getInertialModularity( let inertialModularity = 0; const param = 2 * length * variance; for (let i = 0; i < length; i++) { - const clusteri = nodes[i].data.clusterId; + const clusteri = nodeToCluster.get(nodes[i].id); for (let j = 0; j < length; j++) { - const clusterj = nodes[j].data.clusterId; - if ( i === j || clusteri !== clusterj) continue; - const inertial = ((nodes[i].data.clusterInertial as number) * (nodes[j].data.clusterInertial as number)) - / Math.pow(param, 2) - squareEuclideanDistanceInfo[i][j] / param; + const clusterj = nodeToCluster.get(nodes[j].id); + if (i === j || clusteri !== clusterj) continue; + const inertial = + (clusterInertialMap.get(nodes[i].id) * + clusterInertialMap.get(nodes[j].id)) / + Math.pow(param, 2) - + squareEuclideanDistanceInfo[i][j] / param; inertialModularity += inertial; } } @@ -110,31 +120,40 @@ export function louvain( inertialModularity: boolean = false, involvedKeys: string[] = [], uninvolvedKeys: string[] = ['id'], - inertialWeight: number = 1, + inertialWeight: number = 1 ): ClusterData { const nodes = graph.getAllNodes(); const edges = graph.getAllEdges(); let allPropertiesWeight: number[][] = []; + const originIndexMap = new Map(); if (inertialModularity) { nodes.forEach((node, index) => { - node.data.originIndex = index; + originIndexMap.set(node.id, index); }); - + let nodeTypeInfo: string[] = []; if (nodes.every((node) => 'nodeType' in node.data)) { - nodeTypeInfo = Array.from(new Set(nodes.map((node) => node.data.nodeType as string))); + nodeTypeInfo = Array.from( + new Set(nodes.map((node) => node.data.nodeType as string)) + ); nodes.forEach((node) => { - node.data.nodeType = nodeTypeInfo.findIndex((nodeType) => nodeType === node.data.nodeType); + node.data.nodeType = nodeTypeInfo.findIndex( + (nodeType) => nodeType === node.data.nodeType + ); }); } // 所有节点属性集合 const properties = getAllProperties(nodes); - + // 所有节点属性one-hot特征向量集合 - allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys) as number[][]; + allPropertiesWeight = oneHot( + properties, + involvedKeys, + uninvolvedKeys + ) as number[][]; } - + /** * 1. To start with each node is assigned to a different community or partition. * The number of partitions is equal to number of nodes N. @@ -142,16 +161,17 @@ export function louvain( let uniqueId = 1; const clusters: ClusterMap = {}; const nodeMap: Record; idx: number }> = {}; + const nodeToCluster = new Map(); nodes.forEach((node, i) => { const cid: string = String(uniqueId++); - node.data.clusterId = cid; + nodeToCluster.set(node.id, cid); clusters[cid] = { id: cid, - nodes: [node] + nodes: [node], }; nodeMap[node.id] = { node, - idx: i + idx: i, }; }); // the adjacent matrix of calNodes inside clusters @@ -191,12 +211,18 @@ export function louvain( let finalNodes: Node[] = []; let finalClusters: ClusterMap = {}; while (true) { - if (inertialModularity && nodes.every((node) => node.hasOwnProperty('properties'))) { - totalModularity = getModularity(nodes, adjMatrix, ks, m) + getInertialModularity(nodes, allPropertiesWeight) * inertialWeight; + if ( + inertialModularity && + nodes.every((node) => node.hasOwnProperty('properties')) + ) { + totalModularity = + getModularity(nodes, adjMatrix, ks, m, nodeToCluster) + + getInertialModularity(nodes, allPropertiesWeight, nodeToCluster) * + inertialWeight; } else { - totalModularity = getModularity(nodes, adjMatrix, ks, m); + totalModularity = getModularity(nodes, adjMatrix, ks, m, nodeToCluster); } - + // 第一次迭代previousModularity直接赋值 if (iter === 0) { previousModularity = totalModularity; @@ -204,7 +230,10 @@ export function louvain( finalClusters = clusters; } - const increaseWithinThreshold = totalModularity > 0 && totalModularity > previousModularity && totalModularity - previousModularity < threshold; + const increaseWithinThreshold = + totalModularity > 0 && + totalModularity > previousModularity && + totalModularity - previousModularity < threshold; // 总模块度增加才更新最优解 if (totalModularity > previousModularity) { finalClusters = clone(clusters); @@ -222,20 +251,21 @@ export function louvain( let sumTot = 0; edges.forEach((edge) => { const { source, target } = edge; - const sourceClusterId = nodeMap[source].node.data.clusterId; - const targetClusterId = nodeMap[target].node.data.clusterId; - if ((sourceClusterId === clusterId && targetClusterId !== clusterId) - || (targetClusterId === clusterId && sourceClusterId !== clusterId)) { - sumTot = sumTot + (edge.data[weightPropertyName] as number || 1); + const sourceClusterId = nodeToCluster.get(source); + const targetClusterId = nodeToCluster.get(target); + if ( + (sourceClusterId === clusterId && targetClusterId !== clusterId) || + (targetClusterId === clusterId && sourceClusterId !== clusterId) + ) { + sumTot = sumTot + ((edge.data[weightPropertyName] as number) || 1); } }); clusters[clusterId].sumTot = sumTot; }); - // move the nodes to increase the delta modularity nodes.forEach((node, i) => { - const selfCluster = clusters[node.data.clusterId]; + const selfCluster = clusters[nodeToCluster.get(node.id)]; let bestIncrease = 0; let bestCluster: Cluster; @@ -251,22 +281,31 @@ export function louvain( // the modurarity for **removing** the node i from the origin cluster of node i const removeModurarity = kiin - selfCluster.sumTot * commonParam; // nodes for **removing** node i into this neighbor cluster - const selfClusterNodesAfterRemove = selfClusterNodes.filter((scNode) => scNode.id !== node.id); + const selfClusterNodesAfterRemove = selfClusterNodes.filter( + (scNode) => scNode.id !== node.id + ); const propertiesWeightRemove: number[][] = []; selfClusterNodesAfterRemove.forEach((nodeRemove, index) => { - propertiesWeightRemove[index] = allPropertiesWeight[nodeRemove.data.originIndex as number]; + propertiesWeightRemove[index] = + allPropertiesWeight[originIndexMap.get(nodeRemove.id)]; }); // the inertialModularity for **removing** the node i from the origin cluster of node i - const removeInertialModularity = inertialModularity ? getInertialModularity(selfClusterNodesAfterRemove, allPropertiesWeight) * inertialWeight : 0; + const removeInertialModularity = inertialModularity + ? getInertialModularity( + selfClusterNodesAfterRemove, + allPropertiesWeight, + nodeToCluster + ) * inertialWeight + : 0; // the neightbors of the node const nodeNeighborIds = neighbors[node.id]; Object.keys(nodeNeighborIds).forEach((neighborNodeId) => { const neighborNode = nodeMap[neighborNodeId].node; - const neighborClusterId = neighborNode.data.clusterId; + const neighborClusterId = nodeToCluster.get(neighborNode.id); // if the node and the neighbor of node are in the same cluster, reutrn - if (neighborClusterId === node.data.clusterId) return; + if (neighborClusterId === nodeToCluster.get(node.id)) return; const neighborCluster = clusters[neighborClusterId]; const clusterNodes = neighborCluster.nodes; @@ -281,20 +320,31 @@ export function louvain( }); // the modurarity for **adding** node i into this neighbor cluster - const addModurarity = neighborClusterKiin - neighborCluster.sumTot * commonParam; + const addModurarity = + neighborClusterKiin - neighborCluster.sumTot * commonParam; // nodes for **adding** node i into this neighbor cluster - const clusterNodesAfterAdd= clusterNodes.concat([node]); + const clusterNodesAfterAdd = clusterNodes.concat([node]); const propertiesWeightAdd: number[][] = []; clusterNodesAfterAdd.forEach((nodeAdd, index) => { - propertiesWeightAdd[index] = allPropertiesWeight[nodeAdd.data.originIndex as number]; + propertiesWeightAdd[index] = + allPropertiesWeight[originIndexMap.get(nodeAdd.id)]; }); // the inertialModularity for **adding** node i into this neighbor cluster - const addInertialModularity = inertialModularity ? getInertialModularity(clusterNodesAfterAdd, allPropertiesWeight) * inertialWeight : 0; + const addInertialModularity = inertialModularity + ? getInertialModularity( + clusterNodesAfterAdd, + allPropertiesWeight, + nodeToCluster + ) * inertialWeight + : 0; // the increase modurarity is the difference between addModurarity and removeModurarity let increase = addModurarity - removeModurarity; if (inertialModularity) { - increase = (addModurarity + addInertialModularity) - (removeModurarity + removeInertialModularity); + increase = + addModurarity + + addInertialModularity - + (removeModurarity + removeInertialModularity); } // find the best cluster to move node i into @@ -307,8 +357,8 @@ export function louvain( // if found a best cluster to move into if (bestIncrease > 0) { bestCluster.nodes.push(node); - const previousClusterId = node.data.clusterId; - node.data.clusterId = bestCluster.id; + const previousClusterId = nodeToCluster.get(node.id); + nodeToCluster.set(node.id, bestCluster.id); // move the node to the best cluster const nodeInSelfClusterIdx = selfCluster.nodes.indexOf(node); // remove from origin cluster @@ -319,15 +369,27 @@ export function louvain( let selfClusterSumTot = 0; edges.forEach((edge) => { const { source, target } = edge; - const sourceClusterId = nodeMap[source].node.data.clusterId; - const targetClusterId = nodeMap[target].node.data.clusterId; - if ((sourceClusterId === bestCluster.id && targetClusterId !== bestCluster.id) - || (targetClusterId === bestCluster.id && sourceClusterId !== bestCluster.id)) { - neighborClusterSumTot = neighborClusterSumTot + (edge.data[weightPropertyName] as number || 1); + const sourceClusterId = nodeToCluster.get(source); + const targetClusterId = nodeToCluster.get(target); + if ( + (sourceClusterId === bestCluster.id && + targetClusterId !== bestCluster.id) || + (targetClusterId === bestCluster.id && + sourceClusterId !== bestCluster.id) + ) { + neighborClusterSumTot = + neighborClusterSumTot + + ((edge.data[weightPropertyName] as number) || 1); } - if ((sourceClusterId === previousClusterId && targetClusterId !== previousClusterId) - || (targetClusterId === previousClusterId && sourceClusterId !== previousClusterId)) { - selfClusterSumTot = selfClusterSumTot + (edge.data[weightPropertyName] as number || 1); + if ( + (sourceClusterId === previousClusterId && + targetClusterId !== previousClusterId) || + (targetClusterId === previousClusterId && + sourceClusterId !== previousClusterId) + ) { + selfClusterSumTot = + selfClusterSumTot + + ((edge.data[weightPropertyName] as number) || 1); } }); @@ -358,12 +420,13 @@ export function louvain( finalClusters[newId] = cluster; newClusterIdMap[clusterId] = newId; delete finalClusters[clusterId]; - clusterIdx ++; + clusterIdx++; }); // restore node clusterId finalNodes.forEach((node) => { - if (node.data.clusterId && newClusterIdMap[node.data.clusterId]) { - node.data.clusterId = newClusterIdMap[node.data.clusterId]; + const clusterId = nodeToCluster.get(node.id); + if (clusterId && newClusterIdMap[clusterId]) { + nodeToCluster.set(node.id, newClusterIdMap[clusterId]); } }); // get the cluster edges @@ -374,22 +437,25 @@ export function louvain( data: { weight: number; count: number; - } + }; }[] = []; - const clusterEdgeMap: Record = {}; + > = {}; edges.forEach((edge) => { const { source, target } = edge; - const weight = edge.data[weightPropertyName] as number || 1; - const sourceClusterId = nodeMap[source].node.data.clusterId; - const targetClusterId = nodeMap[target].node.data.clusterId; + const weight = (edge.data[weightPropertyName] as number) || 1; + const sourceClusterId = nodeToCluster.get(source); + const targetClusterId = nodeToCluster.get(target); if (!sourceClusterId || !targetClusterId) return; const newEdgeId = `${sourceClusterId}---${targetClusterId}`; if (clusterEdgeMap[newEdgeId]) { @@ -402,8 +468,8 @@ export function louvain( target: targetClusterId, data: { weight, - count: 1 - } + count: 1, + }, }; clusterEdgeMap[newEdgeId] = newEdge; clusterEdges.push(newEdge); @@ -415,6 +481,7 @@ export function louvain( }); return { clusters: clustersArray, - clusterEdges + clusterEdges, + nodeToCluster, }; } diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index d4aeeee..df9e364 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -1,4 +1,4 @@ -import { Edge, Graph as IGraph, Node, PlainObject } from '@antv/graphlib'; +import { ID, Edge, Graph as IGraph, Node, PlainObject } from '@antv/graphlib'; // Map of attribute / eigenvalue distribution in dataset export interface KeyValueMap { @@ -22,6 +22,7 @@ export interface Cluster { export interface ClusterData { clusters: Cluster[]; clusterEdges: Edge[]; + nodeToCluster: Map; } export interface ClusterMap { @@ -32,17 +33,15 @@ export type Graph = IGraph; export type Matrix = number[]; export interface IAlgorithmCallbacks { - enter?: (param: { current: NodeID; previous: NodeID }) => void; - leave?: (param: { current: NodeID; previous?: NodeID }) => void; + enter?: (param: { current: ID; previous: ID }) => void; + leave?: (param: { current: ID; previous?: ID }) => void; allowTraversal?: (param: { - previous?: NodeID; - current?: NodeID; - next: NodeID; + previous?: ID; + current?: ID; + next: ID; }) => boolean; } -export type NodeID = string | number; - export type NodeSimilarity = Node & { data: { cosineSimilarity?: number; @@ -59,6 +58,6 @@ export type IEdge = Edge; export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[]; export interface IMSTAlgorithmOpt { - 'prim': IMSTAlgorithm; - 'kruskal': IMSTAlgorithm; -} \ No newline at end of file + prim: IMSTAlgorithm; + kruskal: IMSTAlgorithm; +}