From 3dd025a77df4d8a75c352d41d746873f8d8aa42c Mon Sep 17 00:00:00 2001 From: zqqcee Date: Mon, 2 Oct 2023 17:47:52 +0800 Subject: [PATCH 01/10] feat: v5 algorithm k-means --- packages/graph/src/index.ts | 1 + packages/graph/src/k-means.ts | 215 ++++++++++++++++++++++++++++++++++ packages/graph/src/types.ts | 7 +- packages/graph/src/utils.ts | 15 ++- 4 files changed, 236 insertions(+), 2 deletions(-) create mode 100644 packages/graph/src/k-means.ts diff --git a/packages/graph/src/index.ts b/packages/graph/src/index.ts index 816e8b7..b0e22c8 100644 --- a/packages/graph/src/index.ts +++ b/packages/graph/src/index.ts @@ -11,3 +11,4 @@ export * from './nodes-cosine-similarity'; export * from './gaddi'; export * from './connected-component'; export * from './mst'; +export * from './k-means'; \ No newline at end of file diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts new file mode 100644 index 0000000..7f35673 --- /dev/null +++ b/packages/graph/src/k-means.ts @@ -0,0 +1,215 @@ +import { isEqual, uniq } from '@antv/util'; +import { Edge } from '@antv/graphlib'; +import { getAllProperties, oneHot, getDistance } from './utils'; +import { Vector } from "./vector"; +import { ClusterData, DistanceType, Graph, EdgeData, Cluster } from './types'; + +/** + * Calculates the centroid based on the distance type and the given index. + * @param distanceType The distance type to use for centroid calculation. + * @param allPropertiesWeight The weight matrix of all properties. + * @param index The index of the centroid. + * @returns The centroid. + */ +const getCentroid = (distanceType: DistanceType, allPropertiesWeight: number[][], index: number) => { + let centroid: number[] = []; + switch (distanceType) { + case DistanceType.EuclideanDistance: + centroid = allPropertiesWeight[index]; + break; + default: + centroid = []; + break; + } + return centroid; +} + +/** + * Performs the k-means clustering algorithm on a graph. 
+ * @param graph The graph to perform clustering on. + * @param k The number of clusters. + * @param propertyKey The property key to use for clustering. Default is undefined. + * @param involvedKeys The keys of properties to be considered for clustering. Default is an empty array. + * @param uninvolvedKeys The keys of properties to be ignored for clustering. Default is ['id']. + * @param distanceType The distance type to use for clustering. Default is DistanceType.EuclideanDistance. + * @returns The cluster data containing the clusters and cluster edges. + */ +export const kMeans = ( + graph: Graph, + k: number = 3, + propertyKey: string = undefined, + involvedKeys: string[] = [], + uninvolvedKeys: string[] = ['id'], + distanceType: DistanceType = DistanceType.EuclideanDistance, +): ClusterData => { + const nodes = graph.getAllNodes(); + const edges = graph.getAllEdges(); + const defaultClusterInfo: ClusterData = { + clusters: [ + { + id: "0", + nodes, + } + ], + clusterEdges: [] + }; + + // When the distance type is Euclidean distance and there are no attributes in data, return directly + if (distanceType === DistanceType.EuclideanDistance && !nodes.every(node => node.data.hasOwnProperty(propertyKey))) { + return defaultClusterInfo; + } + let properties = []; + let allPropertiesWeight: number[][] = []; + if (distanceType === DistanceType.EuclideanDistance) { + properties = getAllProperties(nodes); + allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys) as number[][]; + } + if (!allPropertiesWeight.length) { + return defaultClusterInfo; + } + const allPropertiesWeightUniq = uniq(allPropertiesWeight.map(item => item.join(''))); + // When the number of nodes or the length of the attribute set is less than k, k will be adjusted to the smallest of them + const finalK = Math.min(k, nodes.length, allPropertiesWeightUniq.length); + for (let i = 0; i < nodes.length; i++) { + nodes[i].data.originIndex = i; + } + const centroids: number[][] = []; + const 
centroidIndexList: number[] = []; + const clusters: Cluster[] = []; + for (let i = 0; i < finalK; i++) { + if (i === 0) { + // random choose centroid + const randomIndex = Math.floor(Math.random() * nodes.length); + switch (distanceType) { + case DistanceType.EuclideanDistance: + centroids[i] = allPropertiesWeight[randomIndex]; + break; + default: + centroids[i] = []; + break; + } + centroidIndexList.push(randomIndex); + clusters[i].nodes = [nodes[randomIndex]]; + nodes[randomIndex].data.clusterId = String(i); + } else { + let maxDistance = -Infinity; + let maxDistanceNodeIndex = 0; + // Select the point with the farthest average distance from the existing centroid as the new centroid + for (let m = 0; m < nodes.length; m++) { + if (!centroidIndexList.includes(m)) { + let totalDistance = 0; + for (let j = 0; j < centroids.length; j++) { + // Find the distance from the node to the centroid (Euclidean distance of the default node attribute) + let distance = 0; + switch (distanceType) { + case DistanceType.EuclideanDistance: + distance = getDistance(allPropertiesWeight[nodes[m].data.originIndex as number], centroids[j], distanceType); + break; + default: + break; + } + totalDistance += distance; + } + // The average distance from the node to each centroid (default Euclidean distance) + const avgDistance = totalDistance / centroids.length; + // Record the distance and node index to the farthest centroid + if (avgDistance > maxDistance && + !centroids.find(centroid => isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, nodes[m].data.originIndex as number)))) { + maxDistance = avgDistance; + maxDistanceNodeIndex = m; + } + } + } + centroids[i] = getCentroid(distanceType, allPropertiesWeight, maxDistanceNodeIndex); + centroidIndexList.push(maxDistanceNodeIndex); + clusters[i].nodes = [nodes[maxDistanceNodeIndex]]; + nodes[maxDistanceNodeIndex].data.clusterId = String(i); + } + } + + + let iterations = 0; + while (true) { + for (let i = 0; i < nodes.length; 
i++) { + let minDistanceIndex = 0; + let minDistance = Infinity; + if (!(iterations === 0 && centroidIndexList.includes(i))) { + for (let j = 0; j < centroids.length; j++) { + let distance = 0; + switch (distanceType) { + case DistanceType.EuclideanDistance: + distance = getDistance(allPropertiesWeight[i], centroids[j], distanceType); + break; + default: + break; + } + if (distance < minDistance) { + minDistance = distance; + minDistanceIndex = j; + } + } + // delete node + if (nodes[i].data.clusterId !== undefined) { + for (let n = clusters[Number(nodes[i].data.clusterId)].nodes.length - 1; n >= 0; n--) { + if (clusters[Number(nodes[i].data.clusterId)].nodes[n].id === nodes[i].id) { + clusters[Number(nodes[i].data.clusterId)].nodes.splice(n, 1); + } + } + } + // Divide the node into the class corresponding to the centroid (cluster center) with the smallest distance. + nodes[i].data.clusterId = String(minDistanceIndex); + clusters[minDistanceIndex].nodes.push(nodes[i]); + } + } + // Determine if there is a centroid (cluster center) movement + let centroidsEqualAvg = false; + for (let i = 0; i < clusters.length; i++) { + const clusterNodes = clusters[i].nodes; + let totalVector = new Vector([]); + for (let j = 0; j < clusterNodes.length; j++) { + totalVector = totalVector.add(new Vector(allPropertiesWeight[clusterNodes[j].data.originIndex as number])); + } + // Calculates the mean vector for each category + const avgVector = totalVector.avg(clusterNodes.length); + // If the mean vector is not equal to the centroid vector + if (!avgVector.equal(new Vector(centroids[i]))) { + centroidsEqualAvg = true; + // Move/update the centroid (cluster center) of each category to this mean vector + centroids[i] = avgVector.getArr(); + } + } + iterations++; + // Stop if each node belongs to a category and there is no centroid (cluster center) movement or the number of iterations exceeds 1000 + if (nodes.every(node => node.data.clusterId !== undefined) && centroidsEqualAvg || 
iterations >= 1000) { + break; + } + } + + // get the cluster edges + const clusterEdges: Edge[] = []; + const clusterEdgeMap: { + [key: string]: Edge + } = {}; + let edgeIndex = 0; + edges.forEach(edge => { + const { source, target } = edge; + const sourceClusterId = nodes.find(node => node.id === source)?.data.clusterId; + const targetClusterId = nodes.find(node => node.id === target)?.data.clusterId; + const newEdgeId = `${sourceClusterId}---${targetClusterId}`; + if (clusterEdgeMap[newEdgeId]) { + (clusterEdgeMap[newEdgeId].data.count as number)++; + } else { + const newEdge = { + id: edgeIndex++, + source: sourceClusterId, + target: targetClusterId, + data: { count: 1 }, + }; + clusterEdgeMap[newEdgeId] = newEdge; + clusterEdges.push(newEdge); + } + }); + + return { clusters, clusterEdges }; +} + diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index d4aeeee..f9cbcaf 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -28,6 +28,7 @@ export interface ClusterMap { [key: string]: Cluster; } + export type Graph = IGraph; export type Matrix = number[]; @@ -61,4 +62,8 @@ export type IMSTAlgorithm = (graph: Graph, weightProps?: string) => IEdge[]; export interface IMSTAlgorithmOpt { 'prim': IMSTAlgorithm; 'kruskal': IMSTAlgorithm; -} \ No newline at end of file +} + +export enum DistanceType { + EuclideanDistance = 'euclideanDistance', +} diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index e383a1f..4991232 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -1,5 +1,6 @@ import { Node, PlainObject } from "@antv/graphlib"; -import { KeyValueMap, NodeData } from "./types"; +import { Vector } from "./vector"; +import { DistanceType, KeyValueMap, NodeData } from "./types"; import { uniq } from "@antv/util"; export const getAllProperties = (nodes: Node[]) => { @@ -77,3 +78,15 @@ export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvol }); return 
oneHotCode; }; + +export const getDistance = (item: number[], otherItem: number[], distanceType: DistanceType = DistanceType.EuclideanDistance) => { + let distance = 0; + switch (distanceType) { + case DistanceType.EuclideanDistance: + distance = new Vector(item).euclideanDistance(new Vector(otherItem)); + break; + default: + break; + } + return distance; +} \ No newline at end of file From b6a8f3c2392831cc65a5f24bb4e3338420bc15f6 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sat, 7 Oct 2023 12:19:47 +0800 Subject: [PATCH 02/10] fix: remove the properties param, because of the 'data' property in node --- packages/graph/src/k-means.ts | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts index 7f35673..f87d9a4 100644 --- a/packages/graph/src/k-means.ts +++ b/packages/graph/src/k-means.ts @@ -28,7 +28,6 @@ const getCentroid = (distanceType: DistanceType, allPropertiesWeight: number[][] * Performs the k-means clustering algorithm on a graph. * @param graph The graph to perform clustering on. * @param k The number of clusters. - * @param propertyKey The property key to use for clustering. Default is undefined. * @param involvedKeys The keys of properties to be considered for clustering. Default is an empty array. * @param uninvolvedKeys The keys of properties to be ignored for clustering. Default is ['id']. * @param distanceType The distance type to use for clustering. Default is DistanceType.EuclideanDistance. 
@@ -37,9 +36,8 @@ const getCentroid = (distanceType: DistanceType, allPropertiesWeight: number[][] export const kMeans = ( graph: Graph, k: number = 3, - propertyKey: string = undefined, involvedKeys: string[] = [], - uninvolvedKeys: string[] = ['id'], + uninvolvedKeys: string[] = [], distanceType: DistanceType = DistanceType.EuclideanDistance, ): ClusterData => { const nodes = graph.getAllNodes(); @@ -55,7 +53,7 @@ export const kMeans = ( }; // When the distance type is Euclidean distance and there are no attributes in data, return directly - if (distanceType === DistanceType.EuclideanDistance && !nodes.every(node => node.data.hasOwnProperty(propertyKey))) { + if (distanceType === DistanceType.EuclideanDistance && !nodes.every(node => node.data)) { return defaultClusterInfo; } let properties = []; @@ -89,8 +87,11 @@ export const kMeans = ( break; } centroidIndexList.push(randomIndex); - clusters[i].nodes = [nodes[randomIndex]]; nodes[randomIndex].data.clusterId = String(i); + clusters[i] = { + id: `${i}`, + nodes: [nodes[randomIndex]] + }; } else { let maxDistance = -Infinity; let maxDistanceNodeIndex = 0; @@ -122,7 +123,10 @@ export const kMeans = ( } centroids[i] = getCentroid(distanceType, allPropertiesWeight, maxDistanceNodeIndex); centroidIndexList.push(maxDistanceNodeIndex); - clusters[i].nodes = [nodes[maxDistanceNodeIndex]]; + clusters[i] = { + id: `${i}`, + nodes: [nodes[maxDistanceNodeIndex]] + }; nodes[maxDistanceNodeIndex].data.clusterId = String(i); } } From 30d7c3bb0d111bb1b92529dc256be0e323543e07 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sat, 7 Oct 2023 12:20:27 +0800 Subject: [PATCH 03/10] test: unit test for k-means --- __tests__/unit/k-means.spec.ts | 389 +++++++++++++++++++++++++++++++++ __tests__/utils/data.ts | 29 +++ package.json | 2 +- 3 files changed, 419 insertions(+), 1 deletion(-) create mode 100644 __tests__/unit/k-means.spec.ts diff --git a/__tests__/unit/k-means.spec.ts b/__tests__/unit/k-means.spec.ts new file mode 100644 index 
0000000..1b89e20 --- /dev/null +++ b/__tests__/unit/k-means.spec.ts @@ -0,0 +1,389 @@ +import { kMeans } from '../../packages/graph/src' +import propertiesGraphData from '../data/cluster-origin-properties-data.json'; +import { Graph } from "@antv/graphlib"; +import { dataPropertiesTransformer, dataLabelDataTransformer } from '../utils/data'; + + +describe('kMeans abnormal demo', () => { + it('no properties demo: ', () => { + const noPropertiesData = { + nodes: [ + { + id: 'node-0', + data: {}, + }, + { + id: 'node-1', + data: {}, + }, + { + id: 'node-2', + data: {}, + }, + { + id: 'node-3', + data: {}, + } + ], + } + const graph = new Graph(noPropertiesData); + const { clusters, clusterEdges } = kMeans(graph, 2); + expect(clusters.length).toBe(1); + expect(clusterEdges.length).toBe(0); + }); +}); + + +describe('kMeans normal demo', () => { + it('simple data demo: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const graph = new Graph(data); + const { clusters } = kMeans(graph, 3); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); + 
expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + }); + + + it('complex data demo: ', () => { + const data = dataLabelDataTransformer(propertiesGraphData); + const graph = new Graph(data); + const { clusters } = kMeans(graph, 3); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); + expect(nodes[0].data.clusterId).toEqual(nodes[2].data.clusterId); + expect(nodes[0].data.clusterId).toEqual(nodes[3].data.clusterId); + expect(nodes[0].data.clusterId).toEqual(nodes[4].data.clusterId); + expect(nodes[5].data.clusterId).toEqual(nodes[6].data.clusterId); + expect(nodes[5].data.clusterId).toEqual(nodes[7].data.clusterId); + expect(nodes[5].data.clusterId).toEqual(nodes[8].data.clusterId); + expect(nodes[5].data.clusterId).toEqual(nodes[9].data.clusterId); + expect(nodes[5].data.clusterId).toEqual(nodes[10].data.clusterId); + expect(nodes[11].data.clusterId).toEqual(nodes[12].data.clusterId); + expect(nodes[11].data.clusterId).toEqual(nodes[13].data.clusterId); + expect(nodes[11].data.clusterId).toEqual(nodes[14].data.clusterId); + expect(nodes[11].data.clusterId).toEqual(nodes[15].data.clusterId); + expect(nodes[11].data.clusterId).toEqual(nodes[16].data.clusterId); + }); + + it('demo use involvedKeys: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 
'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const involvedKeys = ['amount']; + const graph = new Graph(data); + const { clusters } = kMeans(graph, 3, involvedKeys); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); + expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); + expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + }); + + it('demo use uninvolvedKeys: ', () => { + const simpleGraphData = { + nodes: [ + { + id: 'node-0', + properties: { + amount: 10, + city: '10001', + } + }, + { + id: 'node-1', + properties: { + amount: 10000, + city: '10002', + } + }, + { + id: 'node-2', + properties: { + amount: 3000, + city: '10003', + } + }, + { + id: 'node-3', + properties: { + amount: 3200, + city: '10003', + } + }, + { + id: 'node-4', + properties: { + amount: 2000, + city: '10003', + } + } + ], + edges: [ + { + id: 'edge-0', + source: 'node-0', + target: 'node-1', + }, + { + id: 'edge-1', + source: 'node-0', + target: 'node-2', + }, + { + id: 'edge-4', + source: 'node-3', + target: 'node-2', + }, + { + id: 'edge-5', + source: 'node-2', + target: 'node-1', + }, + { + id: 'edge-6', + source: 'node-4', + target: 'node-1', + }, + ] + } + const data = dataPropertiesTransformer(simpleGraphData); + const graph = new Graph(data); + const uninvolvedKeys = ['id', 'city']; + const { clusters } = kMeans(graph, 3, [], uninvolvedKeys); + expect(clusters.length).toBe(3); + const nodes = graph.getAllNodes(); data + expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); + expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + }); + +}); + +describe('kMeans All properties values are numeric demo', () => { + it('all properties values are numeric demo: ', () => { + const allPropertiesValuesNumericData = { + nodes: [ + { + id: 'node-0', + properties: { + max: 1000000, + mean: 900000, + min: 
800000, + } + }, + { + id: 'node-1', + properties: { + max: 1600000, + mean: 1100000, + min: 600000, + } + }, + { + id: 'node-2', + properties: { + max: 5000, + mean: 3500, + min: 2000, + } + }, + { + id: 'node-3', + properties: { + max: 9000, + mean: 7500, + min: 6000, + } + } + ], + edges: [], + } + const data = dataPropertiesTransformer(allPropertiesValuesNumericData); + const graph = new Graph(data); + const { clusters, clusterEdges } = kMeans(graph, 2); + expect(clusters.length).toBe(2); + expect(clusterEdges.length).toBe(0); + const nodes = graph.getAllNodes(); + expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); + expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); + }); + it('only one property and the value are numeric demo: ', () => { + const allPropertiesValuesNumericData = { + nodes: [ + { + id: 'node-0', + properties: { + num: 10, + } + }, + { + id: 'node-1', + properties: { + num: 12, + } + }, + { + id: 'node-2', + properties: { + num: 56, + } + }, + { + id: 'node-3', + properties: { + num: 300, + } + }, + { + id: 'node-4', + properties: { + num: 350, + } + } + ], + edges: [], + } + const data = dataPropertiesTransformer(allPropertiesValuesNumericData); + const graph = new Graph(data); + const { clusters, clusterEdges } = kMeans(graph, 2); + expect(clusters.length).toBe(2); + expect(clusterEdges.length).toBe(0); + const nodes = graph.getAllNodes(); + expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); + expect(nodes[0].data.clusterId).toEqual(nodes[2].data.clusterId); + expect(nodes[3].data.clusterId).toEqual(nodes[4].data.clusterId); + }); + +}); + diff --git a/__tests__/utils/data.ts b/__tests__/utils/data.ts index 9c131d7..9c9035c 100644 --- a/__tests__/utils/data.ts +++ b/__tests__/utils/data.ts @@ -17,3 +17,32 @@ export const dataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }), }; }; + +export const dataPropertiesTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: 
{ source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => { + const { nodes, edges } = data; + return { + nodes: nodes.map((n) => { + const { id, properties, ...rest } = n; + return { id, data: { ...properties, ...rest } }; + }), + edges: edges.map((e, i) => { + const { id, source, target, ...rest } = e; + return { id: id ? id : `edge-${i}`, target, source, data: rest }; + }), + }; +}; + + +export const dataLabelDataTransformer = (data: { nodes: { id: NodeID, [key: string]: any }[], edges: { source: NodeID, target: NodeID, [key: string]: any }[] }): { nodes: INode[], edges: IEdge[] } => { + const { nodes, edges } = data; + return { + nodes: nodes.map((n) => { + const { id, label, data } = n; + return { id, data: { label, ...data } }; + }), + edges: edges.map((e, i) => { + const { id, source, target, ...rest } = e; + return { id: id ? id : `edge-${i}`, target, source, data: rest }; + }), + }; +}; \ No newline at end of file diff --git a/package.json b/package.json index e8f2ba9..2ee91e3 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "build:ci": "pnpm -r run build:ci", "prepare": "husky install", "test": "jest", - "test_one": "jest ./__tests__/unit/mst.spec.ts", + "test_one": "jest ./__tests__/unit/k-means.spec.ts", "coverage": "jest --coverage", "build:site": "vite build", "deploy": "gh-pages -d site/dist", From 360d63933af2f02d1f4d62aa5516bdc92f843175 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sat, 7 Oct 2023 12:28:34 +0800 Subject: [PATCH 04/10] fix: fix lint --- packages/graph/src/k-means.ts | 18 +++++++++--------- packages/graph/src/utils.ts | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts index f87d9a4..748c43d 100644 --- a/packages/graph/src/k-means.ts +++ b/packages/graph/src/k-means.ts @@ -22,7 +22,7 @@ const getCentroid = (distanceType: DistanceType, allPropertiesWeight: number[][] break; } return centroid; -} +}; /** 
* Performs the k-means clustering algorithm on a graph. @@ -53,7 +53,7 @@ export const kMeans = ( }; // When the distance type is Euclidean distance and there are no attributes in data, return directly - if (distanceType === DistanceType.EuclideanDistance && !nodes.every(node => node.data)) { + if (distanceType === DistanceType.EuclideanDistance && !nodes.every((node) => node.data)) { return defaultClusterInfo; } let properties = []; @@ -65,7 +65,7 @@ export const kMeans = ( if (!allPropertiesWeight.length) { return defaultClusterInfo; } - const allPropertiesWeightUniq = uniq(allPropertiesWeight.map(item => item.join(''))); + const allPropertiesWeightUniq = uniq(allPropertiesWeight.map((item) => item.join(''))); // When the number of nodes or the length of the attribute set is less than k, k will be adjusted to the smallest of them const finalK = Math.min(k, nodes.length, allPropertiesWeightUniq.length); for (let i = 0; i < nodes.length; i++) { @@ -115,7 +115,7 @@ export const kMeans = ( const avgDistance = totalDistance / centroids.length; // Record the distance and node index to the farthest centroid if (avgDistance > maxDistance && - !centroids.find(centroid => isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, nodes[m].data.originIndex as number)))) { + !centroids.find((centroid) => isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, nodes[m].data.originIndex as number)))) { maxDistance = avgDistance; maxDistanceNodeIndex = m; } @@ -184,7 +184,7 @@ export const kMeans = ( } iterations++; // Stop if each node belongs to a category and there is no centroid (cluster center) movement or the number of iterations exceeds 1000 - if (nodes.every(node => node.data.clusterId !== undefined) && centroidsEqualAvg || iterations >= 1000) { + if (nodes.every((node) => node.data.clusterId !== undefined) && centroidsEqualAvg || iterations >= 1000) { break; } } @@ -195,10 +195,10 @@ export const kMeans = ( [key: string]: Edge } = {}; let edgeIndex = 0; 
- edges.forEach(edge => { + edges.forEach((edge) => { const { source, target } = edge; - const sourceClusterId = nodes.find(node => node.id === source)?.data.clusterId; - const targetClusterId = nodes.find(node => node.id === target)?.data.clusterId; + const sourceClusterId = nodes.find((node) => node.id === source)?.data.clusterId; + const targetClusterId = nodes.find((node) => node.id === target)?.data.clusterId; const newEdgeId = `${sourceClusterId}---${targetClusterId}`; if (clusterEdgeMap[newEdgeId]) { (clusterEdgeMap[newEdgeId].data.count as number)++; @@ -215,5 +215,5 @@ export const kMeans = ( }); return { clusters, clusterEdges }; -} +}; diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 4991232..9bb4b60 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -89,4 +89,4 @@ export const getDistance = (item: number[], otherItem: number[], distanceType: D break; } return distance; -} \ No newline at end of file +}; \ No newline at end of file From aa60fa6a05e3e69a20d6bb0ef2c5f7e5476da25d Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sun, 15 Oct 2023 21:37:12 +0800 Subject: [PATCH 05/10] fix: replace vector with num array --- packages/graph/src/utils.ts | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 9bb4b60..74d4fd9 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -83,10 +83,21 @@ export const getDistance = (item: number[], otherItem: number[], distanceType: D let distance = 0; switch (distanceType) { case DistanceType.EuclideanDistance: - distance = new Vector(item).euclideanDistance(new Vector(otherItem)); + distance = euclideanDistance(item, otherItem); break; default: break; } return distance; -}; \ No newline at end of file +}; + + +function euclideanDistance(source: number[], target: number[]) { + if (source.length !== target.length) return 0; + let res = 0; + source.forEach((s, i) 
=> { + res += Math.pow(s - target[i], 2) + }) + return Math.sqrt(res); +} + From 49ff9478c503063582661e4fadc22f0b06126260 Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sun, 15 Oct 2023 21:39:16 +0800 Subject: [PATCH 06/10] fix: replace originIndex in data field with a map --- packages/graph/src/k-means.ts | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts index 748c43d..b7cc05c 100644 --- a/packages/graph/src/k-means.ts +++ b/packages/graph/src/k-means.ts @@ -2,7 +2,7 @@ import { isEqual, uniq } from '@antv/util'; import { Edge } from '@antv/graphlib'; import { getAllProperties, oneHot, getDistance } from './utils'; import { Vector } from "./vector"; -import { ClusterData, DistanceType, Graph, EdgeData, Cluster } from './types'; +import { ClusterData, DistanceType, Graph, EdgeData, Cluster, NodeID } from './types'; /** * Calculates the centroid based on the distance type and the given index. @@ -42,6 +42,8 @@ export const kMeans = ( ): ClusterData => { const nodes = graph.getAllNodes(); const edges = graph.getAllEdges(); + const nodeToOriginIdx = new Map(); + const nodeToCluster = new Map(); const defaultClusterInfo: ClusterData = { clusters: [ { @@ -69,7 +71,7 @@ export const kMeans = ( // When the number of nodes or the length of the attribute set is less than k, k will be adjusted to the smallest of them const finalK = Math.min(k, nodes.length, allPropertiesWeightUniq.length); for (let i = 0; i < nodes.length; i++) { - nodes[i].data.originIndex = i; + nodeToOriginIdx.set(nodes[i].id, i); } const centroids: number[][] = []; const centroidIndexList: number[] = []; @@ -104,7 +106,7 @@ export const kMeans = ( let distance = 0; switch (distanceType) { case DistanceType.EuclideanDistance: - distance = getDistance(allPropertiesWeight[nodes[m].data.originIndex as number], centroids[j], distanceType); + distance = getDistance(allPropertiesWeight[nodeToOriginIdx.get(nodes[m].id)], 
centroids[j], distanceType); break; default: break; @@ -115,7 +117,7 @@ export const kMeans = ( const avgDistance = totalDistance / centroids.length; // Record the distance and node index to the farthest centroid if (avgDistance > maxDistance && - !centroids.find((centroid) => isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, nodes[m].data.originIndex as number)))) { + !centroids.find((centroid) => isEqual(centroid, getCentroid(distanceType, allPropertiesWeight, nodeToOriginIdx.get(nodes[m].id))))) { maxDistance = avgDistance; maxDistanceNodeIndex = m; } @@ -171,7 +173,7 @@ export const kMeans = ( const clusterNodes = clusters[i].nodes; let totalVector = new Vector([]); for (let j = 0; j < clusterNodes.length; j++) { - totalVector = totalVector.add(new Vector(allPropertiesWeight[clusterNodes[j].data.originIndex as number])); + totalVector = totalVector.add(new Vector(allPropertiesWeight[nodeToOriginIdx.get(clusterNodes[j].id)])); } // Calculates the mean vector for each category const avgVector = totalVector.avg(clusterNodes.length); @@ -214,6 +216,8 @@ export const kMeans = ( } }); + console.log(clusters); + return { clusters, clusterEdges }; }; From 60ddfc2dd2f2139b90392fde26f131b0ce4fe92d Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sun, 15 Oct 2023 22:02:57 +0800 Subject: [PATCH 07/10] fix: move clusterId in data field to a new map named nodeToCluster, and return it from k-means func --- __tests__/unit/k-means.spec.ts | 64 ++++++++++++++++++---------------- packages/graph/src/k-means.ts | 28 +++++++-------- packages/graph/src/types.ts | 1 + 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/__tests__/unit/k-means.spec.ts b/__tests__/unit/k-means.spec.ts index 1b89e20..92b8e06 100644 --- a/__tests__/unit/k-means.spec.ts +++ b/__tests__/unit/k-means.spec.ts @@ -104,34 +104,36 @@ describe('kMeans normal demo', () => { } const data = dataPropertiesTransformer(simpleGraphData); const graph = new Graph(data); - const { clusters } = 
kMeans(graph, 3); + const { clusters, nodeToCluster } = kMeans(graph, 3); expect(clusters.length).toBe(3); const nodes = graph.getAllNodes(); - expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); - expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + + + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); }); it('complex data demo: ', () => { const data = dataLabelDataTransformer(propertiesGraphData); const graph = new Graph(data); - const { clusters } = kMeans(graph, 3); + const { clusters,nodeToCluster } = kMeans(graph, 3); expect(clusters.length).toBe(3); const nodes = graph.getAllNodes(); - expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); - expect(nodes[0].data.clusterId).toEqual(nodes[2].data.clusterId); - expect(nodes[0].data.clusterId).toEqual(nodes[3].data.clusterId); - expect(nodes[0].data.clusterId).toEqual(nodes[4].data.clusterId); - expect(nodes[5].data.clusterId).toEqual(nodes[6].data.clusterId); - expect(nodes[5].data.clusterId).toEqual(nodes[7].data.clusterId); - expect(nodes[5].data.clusterId).toEqual(nodes[8].data.clusterId); - expect(nodes[5].data.clusterId).toEqual(nodes[9].data.clusterId); - expect(nodes[5].data.clusterId).toEqual(nodes[10].data.clusterId); - expect(nodes[11].data.clusterId).toEqual(nodes[12].data.clusterId); - expect(nodes[11].data.clusterId).toEqual(nodes[13].data.clusterId); - expect(nodes[11].data.clusterId).toEqual(nodes[14].data.clusterId); - expect(nodes[11].data.clusterId).toEqual(nodes[15].data.clusterId); - expect(nodes[11].data.clusterId).toEqual(nodes[16].data.clusterId); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[2].id)); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[3].id)); + 
expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[4].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[6].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[7].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[8].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[9].id)); + expect(nodeToCluster.get(nodes[5].id)).toEqual(nodeToCluster.get(nodes[10].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[12].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[13].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[14].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[15].id)); + expect(nodeToCluster.get(nodes[11].id)).toEqual(nodeToCluster.get(nodes[16].id)); }); it('demo use involvedKeys: ', () => { @@ -204,11 +206,11 @@ describe('kMeans normal demo', () => { const data = dataPropertiesTransformer(simpleGraphData); const involvedKeys = ['amount']; const graph = new Graph(data); - const { clusters } = kMeans(graph, 3, involvedKeys); + const { clusters ,nodeToCluster} = kMeans(graph, 3, involvedKeys); expect(clusters.length).toBe(3); const nodes = graph.getAllNodes(); - expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); - expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); }); it('demo use uninvolvedKeys: ', () => { @@ -281,11 +283,11 @@ describe('kMeans normal demo', () => { const data = dataPropertiesTransformer(simpleGraphData); const graph = new Graph(data); const uninvolvedKeys = ['id', 'city']; - const { clusters } = kMeans(graph, 3, [], uninvolvedKeys); + const { clusters,nodeToCluster } = kMeans(graph, 3, [], uninvolvedKeys); 
expect(clusters.length).toBe(3); const nodes = graph.getAllNodes(); data - expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); - expect(nodes[2].data.clusterId).toEqual(nodes[4].data.clusterId); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[4].id)); }); }); @@ -331,12 +333,12 @@ describe('kMeans All properties values are numeric demo', () => { } const data = dataPropertiesTransformer(allPropertiesValuesNumericData); const graph = new Graph(data); - const { clusters, clusterEdges } = kMeans(graph, 2); + const { clusters, clusterEdges,nodeToCluster } = kMeans(graph, 2); expect(clusters.length).toBe(2); expect(clusterEdges.length).toBe(0); const nodes = graph.getAllNodes(); - expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); - expect(nodes[2].data.clusterId).toEqual(nodes[3].data.clusterId); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + expect(nodeToCluster.get(nodes[2].id)).toEqual(nodeToCluster.get(nodes[3].id)); }); it('only one property and the value are numeric demo: ', () => { const allPropertiesValuesNumericData = { @@ -376,13 +378,13 @@ describe('kMeans All properties values are numeric demo', () => { } const data = dataPropertiesTransformer(allPropertiesValuesNumericData); const graph = new Graph(data); - const { clusters, clusterEdges } = kMeans(graph, 2); + const { clusters, clusterEdges,nodeToCluster } = kMeans(graph, 2); expect(clusters.length).toBe(2); expect(clusterEdges.length).toBe(0); const nodes = graph.getAllNodes(); - expect(nodes[0].data.clusterId).toEqual(nodes[1].data.clusterId); - expect(nodes[0].data.clusterId).toEqual(nodes[2].data.clusterId); - expect(nodes[3].data.clusterId).toEqual(nodes[4].data.clusterId); + expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[1].id)); + 
expect(nodeToCluster.get(nodes[0].id)).toEqual(nodeToCluster.get(nodes[2].id)); + expect(nodeToCluster.get(nodes[3].id)).toEqual(nodeToCluster.get(nodes[4].id)); }); }); diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts index b7cc05c..3e35095 100644 --- a/packages/graph/src/k-means.ts +++ b/packages/graph/src/k-means.ts @@ -51,7 +51,8 @@ export const kMeans = ( nodes, } ], - clusterEdges: [] + clusterEdges: [], + nodeToCluster, }; // When the distance type is Euclidean distance and there are no attributes in data, return directly @@ -89,7 +90,7 @@ export const kMeans = ( break; } centroidIndexList.push(randomIndex); - nodes[randomIndex].data.clusterId = String(i); + nodeToCluster.set(nodes[randomIndex].id, `${i}`); clusters[i] = { id: `${i}`, nodes: [nodes[randomIndex]] @@ -129,7 +130,7 @@ export const kMeans = ( id: `${i}`, nodes: [nodes[maxDistanceNodeIndex]] }; - nodes[maxDistanceNodeIndex].data.clusterId = String(i); + nodeToCluster.set(nodes[maxDistanceNodeIndex].id, `${i}`); } } @@ -155,15 +156,16 @@ export const kMeans = ( } } // delete node - if (nodes[i].data.clusterId !== undefined) { - for (let n = clusters[Number(nodes[i].data.clusterId)].nodes.length - 1; n >= 0; n--) { - if (clusters[Number(nodes[i].data.clusterId)].nodes[n].id === nodes[i].id) { - clusters[Number(nodes[i].data.clusterId)].nodes.splice(n, 1); + const cId = nodeToCluster.get(nodes[i].id); + if (cId !== undefined) { + for (let n = clusters[Number(cId)].nodes.length - 1; n >= 0; n--) { + if (clusters[Number(cId)].nodes[n].id === nodes[i].id) { + clusters[Number(cId)].nodes.splice(n, 1); } } } // Divide the node into the class corresponding to the centroid (cluster center) with the smallest distance. 
- nodes[i].data.clusterId = String(minDistanceIndex); + nodeToCluster.set(nodes[i].id, `${minDistanceIndex}`); clusters[minDistanceIndex].nodes.push(nodes[i]); } } @@ -186,7 +188,7 @@ export const kMeans = ( } iterations++; // Stop if each node belongs to a category and there is no centroid (cluster center) movement or the number of iterations exceeds 1000 - if (nodes.every((node) => node.data.clusterId !== undefined) && centroidsEqualAvg || iterations >= 1000) { + if (nodes.every((node) => nodeToCluster.has(node.id)) && centroidsEqualAvg || iterations >= 1000) { break; } } @@ -199,8 +201,8 @@ export const kMeans = ( let edgeIndex = 0; edges.forEach((edge) => { const { source, target } = edge; - const sourceClusterId = nodes.find((node) => node.id === source)?.data.clusterId; - const targetClusterId = nodes.find((node) => node.id === target)?.data.clusterId; + const sourceClusterId = nodeToCluster.get(source); + const targetClusterId = nodeToCluster.get(target); const newEdgeId = `${sourceClusterId}---${targetClusterId}`; if (clusterEdgeMap[newEdgeId]) { (clusterEdgeMap[newEdgeId].data.count as number)++; @@ -216,8 +218,6 @@ export const kMeans = ( } }); - console.log(clusters); - - return { clusters, clusterEdges }; + return { clusters, clusterEdges, nodeToCluster }; }; diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index f9cbcaf..97b3485 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -22,6 +22,7 @@ export interface Cluster { export interface ClusterData { clusters: Cluster[]; clusterEdges: Edge[]; + nodeToCluster: Map } export interface ClusterMap { From a1b83b2860f6bed88d368be67ae8c7ac4f74ed8b Mon Sep 17 00:00:00 2001 From: zqqcee Date: Sun, 15 Oct 2023 22:03:51 +0800 Subject: [PATCH 08/10] fix: fix lint --- packages/graph/src/detect-cycle.ts | 12 ++++++------ packages/graph/src/dfs.ts | 2 +- packages/graph/src/types.ts | 2 +- packages/graph/src/utils.ts | 4 ++-- 4 files changed, 10 insertions(+), 10
deletions(-) diff --git a/packages/graph/src/detect-cycle.ts b/packages/graph/src/detect-cycle.ts index 0af03f8..5a81cdf 100644 --- a/packages/graph/src/detect-cycle.ts +++ b/packages/graph/src/detect-cycle.ts @@ -57,7 +57,7 @@ export const detectDirectedCycle = (graph: Graph): { return true; }, }; - for (let key of Object.keys(unvisitedSet)) { + for (const key of Object.keys(unvisitedSet)) { depthFirstSearch(graph, key, callbacks, true, false); } return cycle; @@ -89,7 +89,7 @@ export const detectAllUndirectedCycle = (graph: Graph, nodeIds?: NodeID[], inclu // const neighbors = getNeighbors(curNodeId, graphData.edges); for (let i = 0; i < neighbors.length; i += 1) { const neighborId = neighbors[i].id; - const neighbor = graph.getAllNodes().find(node => node.id === neighborId); + const neighbor = graph.getAllNodes().find((node) => node.id === neighborId); if (neighborId === curNodeId) { allCycles.push({ [neighborId]: curNode }); } else if (!(neighborId in used)) { @@ -170,7 +170,7 @@ export const detectAllDirectedCycle = (graph: Graph, nodeIds?: NodeID[], include const circuit = (node: INode, start: INode, adjList: { [key: NodeID]: number[] }) => { let closed = false; // whether a path is closed - if (nodeIds && include === false && nodeIds.indexOf(node.id) > -1) return closed; + if (nodeIds && !include && nodeIds.indexOf(node.id) > -1) return closed; path.push(node); blocked.add(node); const neighbors = adjList[node.id]; @@ -221,7 +221,7 @@ export const detectAllDirectedCycle = (graph: Graph, nodeIds?: NodeID[], include const nodeId = nodeIds[i]; node2Idx[nodes[i].id] = node2Idx[nodeId]; node2Idx[nodeId] = 0; - idx2Node[0] = nodes.find(node => node.id === nodeId); + idx2Node[0] = nodes.find((node) => node.id === nodeId); idx2Node[node2Idx[nodes[i].id]] = nodes[i]; } } @@ -246,9 +246,9 @@ export const detectAllDirectedCycle = (graph: Graph, nodeIds?: NodeID[], include for (let i = 0; i < component.length; i += 1) { const node = component[i]; adjList[node.id] = []; 
- for (const neighbor of graph.getRelatedEdges(node.id, "out").map(n => n.target).filter((n) => component.map(c => c.id).indexOf(n) > -1)) { + for (const neighbor of graph.getRelatedEdges(node.id, "out").map((n) => n.target).filter((n) => component.map((c) => c.id).indexOf(n) > -1)) { // 对自环情况 (点连向自身) 特殊处理:记录自环,但不加入adjList - if (neighbor === node.id && !(include === false && nodeIds.indexOf(node.id) > -1)) { + if (neighbor === node.id && !(!include && nodeIds.indexOf(node.id) > -1)) { allCycles.push({ [node.id]: node }); } else { adjList[node.id].push(node2Idx[neighbor]); diff --git a/packages/graph/src/dfs.ts b/packages/graph/src/dfs.ts index 723f44c..97883f8 100644 --- a/packages/graph/src/dfs.ts +++ b/packages/graph/src/dfs.ts @@ -40,7 +40,7 @@ function depthFirstSearchRecursive( }); const neighbors = directed ? - graph.getRelatedEdges(currentNodeId, "out").map(e => graph.getNode(e.target)) + graph.getRelatedEdges(currentNodeId, "out").map((e) => graph.getNode(e.target)) : graph.getNeighbors(currentNodeId) ; diff --git a/packages/graph/src/types.ts b/packages/graph/src/types.ts index 97b3485..3254e4c 100644 --- a/packages/graph/src/types.ts +++ b/packages/graph/src/types.ts @@ -22,7 +22,7 @@ export interface Cluster { export interface ClusterData { clusters: Cluster[]; clusterEdges: Edge[]; - nodeToCluster: Map + nodeToCluster: Map; } export interface ClusterMap { diff --git a/packages/graph/src/utils.ts b/packages/graph/src/utils.ts index 74d4fd9..93fa9a5 100644 --- a/packages/graph/src/utils.ts +++ b/packages/graph/src/utils.ts @@ -96,8 +96,8 @@ function euclideanDistance(source: number[], target: number[]) { if (source.length !== target.length) return 0; let res = 0; source.forEach((s, i) => { - res += Math.pow(s - target[i], 2) - }) + res += Math.pow(s - target[i], 2); + }); return Math.sqrt(res); } From be2a0b54a6df2fcbe30aadcf05804809289fadb3 Mon Sep 17 00:00:00 2001 From: "yuqi.pyq" Date: Mon, 16 Oct 2023 10:22:01 +0800 Subject: [PATCH 09/10] fix: the 
return value of louvain algorithm --- __tests__/unit/louvain.spec.ts | 15 +++++++++++++++ packages/graph/src/louvain.ts | 7 ++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/__tests__/unit/louvain.spec.ts b/__tests__/unit/louvain.spec.ts index 9f9185f..634cffa 100644 --- a/__tests__/unit/louvain.spec.ts +++ b/__tests__/unit/louvain.spec.ts @@ -43,6 +43,21 @@ describe('Louvain', () => { expect(clusteredData.clusterEdges[0].data.count).toBe(13); expect(clusteredData.clusterEdges[1].data.count).toBe(10); expect(clusteredData.clusterEdges[1].data.weight).toBe(14); + expect(clusteredData.nodeToCluster.get('0')).toBe('1'); + expect(clusteredData.nodeToCluster.get('1')).toBe('1'); + expect(clusteredData.nodeToCluster.get('2')).toBe('1'); + expect(clusteredData.nodeToCluster.get('3')).toBe('1'); + expect(clusteredData.nodeToCluster.get('4')).toBe('1'); + expect(clusteredData.nodeToCluster.get('5')).toBe('2'); + expect(clusteredData.nodeToCluster.get('6')).toBe('2'); + expect(clusteredData.nodeToCluster.get('7')).toBe('2'); + expect(clusteredData.nodeToCluster.get('8')).toBe('2'); + expect(clusteredData.nodeToCluster.get('9')).toBe('2'); + expect(clusteredData.nodeToCluster.get('10')).toBe('3'); + expect(clusteredData.nodeToCluster.get('11')).toBe('3'); + expect(clusteredData.nodeToCluster.get('12')).toBe('3'); + expect(clusteredData.nodeToCluster.get('13')).toBe('3'); + expect(clusteredData.nodeToCluster.get('14')).toBe('3'); }); // it('louvain with large graph', () => { // https://gw.alipayobjects.com/os/antvdemo/assets/data/relations.json diff --git a/packages/graph/src/louvain.ts b/packages/graph/src/louvain.ts index f37ba6d..5ad6a26 100644 --- a/packages/graph/src/louvain.ts +++ b/packages/graph/src/louvain.ts @@ -142,6 +142,7 @@ export function louvain( let uniqueId = 1; const clusters: ClusterMap = {}; const nodeMap: Record; idx: number }> = {}; + const nodeToCluster = new Map(); nodes.forEach((node, i) => { const cid: string = String(uniqueId++); 
node.data.clusterId = cid; @@ -412,9 +413,13 @@ export function louvain( const clustersArray: Cluster[] = []; Object.keys(finalClusters).forEach((clusterId) => { clustersArray.push(finalClusters[clusterId]); + finalClusters[clusterId].nodes.forEach((node) => { + nodeToCluster.set(node.id, clusterId); + }); }); return { clusters: clustersArray, - clusterEdges + clusterEdges, + nodeToCluster }; } From f1e3fc5c4935cd84e38bc71c083586f563405a5f Mon Sep 17 00:00:00 2001 From: "yuqi.pyq" Date: Mon, 16 Oct 2023 11:36:20 +0800 Subject: [PATCH 10/10] fix: use ID from graphlib instead of NodeID --- packages/graph/src/k-means.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/graph/src/k-means.ts b/packages/graph/src/k-means.ts index 3e35095..4ae1bdd 100644 --- a/packages/graph/src/k-means.ts +++ b/packages/graph/src/k-means.ts @@ -1,8 +1,8 @@ import { isEqual, uniq } from '@antv/util'; -import { Edge } from '@antv/graphlib'; +import { Edge, ID } from '@antv/graphlib'; import { getAllProperties, oneHot, getDistance } from './utils'; import { Vector } from "./vector"; -import { ClusterData, DistanceType, Graph, EdgeData, Cluster, NodeID } from './types'; +import { ClusterData, DistanceType, Graph, EdgeData, Cluster } from './types'; /** * Calculates the centroid based on the distance type and the given index. @@ -42,8 +42,8 @@ export const kMeans = ( ): ClusterData => { const nodes = graph.getAllNodes(); const edges = graph.getAllEdges(); - const nodeToOriginIdx = new Map(); - const nodeToCluster = new Map(); + const nodeToOriginIdx = new Map(); + const nodeToCluster = new Map(); const defaultClusterInfo: ClusterData = { clusters: [ {