funkygao/cp-ddd-framework

View on GitHub
dddplus-visualization/src/main/java/io/github/dddplus/ast/algorithm/KMeans.java

Summary

Maintainability
B
4 hrs
Test Coverage
/*
 * Copyright DDDplus Authors.
 *
 * Licensed under the Apache License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
 */
package io.github.dddplus.ast.algorithm;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * K-Means clustering of labelled vectors (originally described as clustering java method names).
 *
 * <p>Each sample is a {@code double[]} paired with a {@code String} label. The algorithm
 * repeatedly assigns every sample to its nearest centroid (Euclidean distance) and then
 * recomputes each centroid as the mean of its members, until no centroid moves by more than
 * the configured threshold or the iteration limit is reached.
 */
public class KMeans {
    private static final int NUM_CLUSTERS = 3;
    private static final int MAX_ITERATIONS = 100;
    private static final double THRESHOLD = 0.001;

    private static final Random RANDOM = new Random();

    private final int numClusters; // number of clusters (k)
    private final int maxIterations; // maximum number of refinement iterations
    private final double threshold; // convergence threshold for centroid movement

    /**
     * Creates a clusterer with the defaults: 3 clusters, at most 100 iterations, threshold 0.001.
     */
    public KMeans() {
        this(NUM_CLUSTERS, MAX_ITERATIONS, THRESHOLD);
    }

    /**
     * Creates a clusterer with explicit settings.
     *
     * @param numClusters   number of clusters to produce, must be at least 1
     * @param maxIterations maximum number of assign/update iterations; when it is not positive,
     *                      no iteration runs and {@link #cluster} returns an empty list for
     *                      non-empty input
     * @param threshold     a centroid that moves by no more than this distance is left unchanged
     * @throws IllegalArgumentException if {@code numClusters} is less than 1
     */
    public KMeans(int numClusters, int maxIterations, double threshold) {
        if (numClusters < 1) {
            throw new IllegalArgumentException("numClusters must be at least 1, got " + numClusters);
        }
        this.numClusters = numClusters;
        this.maxIterations = maxIterations;
        this.threshold = threshold;
    }

    /**
     * Groups the labels into clusters according to the distance between their vectors.
     *
     * @param vectors   the sample vectors
     * @param labels    the label of each sample, parallel to {@code vectors}
     * @param vectorMap lookup from label to its vector, used when recomputing centroids
     * @return one list of labels per cluster; a cluster may be empty. When {@code vectors} is
     *         empty, every cluster is empty.
     */
    public List<List<String>> cluster(List<double[]> vectors, List<String> labels, Map<String, double[]> vectorMap) {
        if (vectors.isEmpty()) {
            // Nothing to sample centroids from: every cluster is trivially empty.
            return newEmptyClusters();
        }

        List<double[]> centroids = pickInitialCentroids(vectors);

        List<List<String>> clusters = new ArrayList<>();
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            clusters = assignToNearestCentroid(vectors, labels, centroids);
            boolean anyCentroidMoved = updateCentroids(clusters, centroids, vectorMap);
            if (!anyCentroidMoved) {
                // Centroids are stable, so further iterations would yield the same assignment.
                break;
            }
        }

        return clusters;
    }

    /** Creates {@code numClusters} empty, mutable label lists. */
    private List<List<String>> newEmptyClusters() {
        List<List<String>> clusters = new ArrayList<>(numClusters);
        for (int i = 0; i < numClusters; i++) {
            clusters.add(new ArrayList<>());
        }
        return clusters;
    }

    /**
     * Randomly picks the initial centroids from the samples.
     *
     * <p>Samples are drawn without replacement (partial Fisher-Yates shuffle of the indices) so
     * that the same sample is not chosen twice, which would leave a cluster permanently empty.
     * Only when more clusters than samples are requested are the surplus centroids drawn with
     * replacement.
     */
    private List<double[]> pickInitialCentroids(List<double[]> vectors) {
        int sampleCount = vectors.size();
        int[] indices = new int[sampleCount];
        for (int i = 0; i < sampleCount; i++) {
            indices[i] = i;
        }

        List<double[]> centroids = new ArrayList<>(numClusters);
        for (int i = 0; i < numClusters; i++) {
            int picked;
            if (i < sampleCount) {
                int swapWith = i + RANDOM.nextInt(sampleCount - i);
                picked = indices[swapWith];
                indices[swapWith] = indices[i];
                indices[i] = picked;
            } else {
                picked = RANDOM.nextInt(sampleCount);
            }
            centroids.add(vectors.get(picked));
        }
        return centroids;
    }

    /**
     * Assigns every sample's label to the cluster whose centroid is nearest.
     *
     * <p>Ties go to the lowest cluster index. A sample for which no distance compares as smaller
     * than infinity (for example because a coordinate is NaN) falls back to the first cluster
     * instead of failing with an index of -1.
     */
    private List<List<String>> assignToNearestCentroid(List<double[]> vectors, List<String> labels,
                                                        List<double[]> centroids) {
        List<List<String>> clusters = newEmptyClusters();
        for (int sample = 0; sample < vectors.size(); sample++) {
            double[] vector = vectors.get(sample);
            double minDistance = Double.POSITIVE_INFINITY;
            int nearestCluster = 0;
            for (int k = 0; k < numClusters; k++) {
                double distance = euclideanDistance(vector, centroids.get(k));
                if (distance < minDistance) {
                    minDistance = distance;
                    nearestCluster = k;
                }
            }
            clusters.get(nearestCluster).add(labels.get(sample));
        }
        return clusters;
    }

    /**
     * Recomputes each centroid as the mean of its cluster and replaces it when it moved by more
     * than {@code threshold}. An empty cluster keeps its previous centroid.
     *
     * @return {@code true} if at least one centroid was replaced
     */
    private boolean updateCentroids(List<List<String>> clusters, List<double[]> centroids,
                                    Map<String, double[]> vectorMap) {
        boolean anyCentroidMoved = false;
        for (int k = 0; k < numClusters; k++) {
            List<String> members = clusters.get(k);
            if (members.isEmpty()) {
                // The mean of nothing is undefined (0/0 = NaN); keep the old centroid explicitly.
                continue;
            }
            double[] oldCentroid = centroids.get(k);
            double[] newCentroid = calculateCentroid(members, vectorMap, oldCentroid.length);
            if (euclideanDistance(newCentroid, oldCentroid) > threshold) {
                centroids.set(k, newCentroid);
                anyCentroidMoved = true;
            }
        }
        return anyCentroidMoved;
    }

    /**
     * Computes the component-wise mean of the vectors of a non-empty cluster.
     *
     * @param cluster   labels of the cluster members, must not be empty
     * @param vectorMap lookup from label to its vector
     * @param dimension number of components of the resulting centroid
     * @throws NullPointerException if a label has no vector in {@code vectorMap}
     */
    private double[] calculateCentroid(List<String> cluster, Map<String, double[]> vectorMap, int dimension) {
        double[] centroid = new double[dimension];
        for (String label : cluster) {
            double[] vector = vectorMap.get(label);
            if (vector == null) {
                throw new NullPointerException("no vector found in vectorMap for label: " + label);
            }
            for (int i = 0; i < dimension; i++) {
                centroid[i] += vector[i];
            }
        }
        for (int i = 0; i < dimension; i++) {
            centroid[i] /= cluster.size();
        }
        return centroid;
    }

    /** Returns the Euclidean distance between {@code a} and {@code b} over the components of {@code a}. */
    private double euclideanDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }
}