
Multithreaded implementation of the K-means clustering algorithm in Java

Code Review user
Asked 2021-02-26 18:50:28
1 answer, 331 views, 0 followers, 0 votes

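Hi, I wrote a multithreaded implementation of the K-means clustering algorithm. Its main goals are correctness and scalable performance on multi-core CPUs. I want the code to be free of race conditions and data races, and to scale well as more CPU cores are used.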

package bg.unisofia.fmi.rsa;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ParallelKmeans {

    private static CountDownLatch countDownLatch;
    private final int n;
    private final int k;
    public int numThreads = 1;
    List<Node> observations = new ArrayList<>();
    float[][] clusters;

    public ParallelKmeans(int n, int k) {
        this.n = n;
        this.k = k;
        clusters = new float[k][n];
        for (float[] cluster : clusters) {
            for (int i = 0; i < cluster.length; i++) {
                cluster[i] = (float) Math.random();
            }
        }
    }

    public void assignStep(ExecutorService executorService) throws InterruptedException {
        Runnable[] assignWorkers = new AssignWorker[numThreads];
        final int chunk = observations.size() / assignWorkers.length;
        countDownLatch = new CountDownLatch(numThreads);
        for (int j = 0; j < assignWorkers.length; j++) {
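            // The last worker also takes the remainder, so no observation is skipped.
            int end = (j == assignWorkers.length - 1) ? observations.size() : (j + 1) * chunk;
            assignWorkers[j] = new AssignWorker(j * chunk, end);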
            executorService.execute(assignWorkers[j]);
        }
        countDownLatch.await();

    }

    public void updateStep(ExecutorService executorService) throws InterruptedException {

        countDownLatch = new CountDownLatch(numThreads);

        UpdateWorker[] updateWorkers = new UpdateWorker[numThreads];
        final int chunk = observations.size() / updateWorkers.length;
        for (int j = 0; j < updateWorkers.length; j++) {
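            // The last worker also takes the remainder, so no observation is skipped.
            int end = (j == updateWorkers.length - 1) ? observations.size() : (j + 1) * chunk;
            updateWorkers[j] = new UpdateWorker(j * chunk, end);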
            executorService.execute(updateWorkers[j]);
        }
        countDownLatch.await();
        clusters = new float[k][n];
        int[] counts = new int[k];

        for (UpdateWorker u : updateWorkers) {
            VectorMath.add(counts, u.getCounts());
            for (int j = 0; j < k; j++) {
                VectorMath.add(clusters[j], u.getClusters()[j]);
            }
        }

        for (int j = 0; j < clusters.length; j++) {
            if (counts[j] != 0) {
                VectorMath.divide(clusters[j], counts[j]);
            }
        }
    }

    void cluster() throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        for (int i = 0; i < 50; i++) {
            assignStep(executorService);
            updateStep(executorService);
        }
        executorService.shutdown();
    }

    public static class Node {
        float[] vec;
        int cluster;
    }

    class AssignWorker implements Runnable {
        int l, r;

        public AssignWorker(int l, int r) {
            this.l = l;
            this.r = r;
        }

        @Override
        public void run() {
            List<Node> chunk = observations.subList(l, r);
            for (Node ob : chunk) {
                float minDist = Float.POSITIVE_INFINITY;
                int idx = 0;
                for (int i = 0; i < clusters.length; i++) {
                    // Compute the distance once per centroid instead of twice.
                    float d = VectorMath.dist(ob.vec, clusters[i]);
                    if (d < minDist) {
                        minDist = d;
                        idx = i;
                    }
                }
                ob.cluster = idx;
            }
            countDownLatch.countDown();
        }
    }

    class UpdateWorker implements Runnable {
        int[] counts;
        int l, r;
        float[][] clusters;

        UpdateWorker(int l, int r) {
            this.l = l;
            this.r = r;
        }

        int[] getCounts() {
            return counts;
        }

        public float[][] getClusters() {
            return clusters;
        }

        @Override
        public void run() {
            this.counts = new int[k];
            this.clusters = new float[k][n];
            for (Node ob : observations.subList(l, r)) {
                VectorMath.add(this.clusters[ob.cluster], ob.vec);
                this.counts[ob.cluster]++;
            }
            countDownLatch.countDown();
        }
    }

}
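The code above calls a VectorMath helper that the question does not show. The following is a hypothetical version inferred only from the call sites (element-wise add for int and float arrays, an in-place divide, and dist), so the class can be compiled and tried if it is placed in the same package. It assumes dist means Euclidean distance, which the question does not state.

```java
// Hypothetical helper: the question uses VectorMath but does not show it.
// Signatures are inferred from the call sites; dist is ASSUMED to be the
// Euclidean distance.
final class VectorMath {
    private VectorMath() {}

    // Element-wise a += b for count vectors.
    static void add(int[] a, int[] b) {
        for (int i = 0; i < a.length; i++) {
            a[i] += b[i];
        }
    }

    // Element-wise a += b for feature vectors.
    static void add(float[] a, float[] b) {
        for (int i = 0; i < a.length; i++) {
            a[i] += b[i];
        }
    }

    // In-place a /= divisor, used to turn a per-cluster sum into a mean.
    static void divide(float[] a, int divisor) {
        for (int i = 0; i < a.length; i++) {
            a[i] /= divisor;
        }
    }

    // Euclidean distance between two vectors of equal length.
    static float dist(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return (float) Math.sqrt(sum);
    }
}
```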

1 answer

Code Review user

Accepted answer

Posted 2021-03-07 17:48:49

Interface

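Your class interface is confusing. You have a package-private method, cluster, that appears to be the main entry point of the ParallelKmeans class. However, it then calls two public methods (assignStep and updateStep) that do the actual work. That doesn't seem right, especially since assignStep and updateStep cannot safely run at the same time.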
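A minimal sketch of the interface shape suggested here, using a hypothetical KmeansSketch class: cluster is the only public method, and the two phases are private, so callers cannot run them out of order or concurrently. To keep the focus on the API, both phases run sequentially, the centroids are seeded from the first k points instead of random values, and cluster is synchronized so two threads cannot run it on the same instance at once. These are choices made for this sketch, not part of the original code.

```java
// Hypothetical sketch: cluster() is the only public entry point; the two
// phases are private helpers. Phases are SEQUENTIAL here on purpose, to keep
// the focus on the API shape rather than on threading.
final class KmeansSketch {
    private final float[][] points;    // one row per observation
    private final float[][] centroids; // one row per cluster
    private final int[] labels;        // labels[p] = centroid index of point p

    // Assumption: the first k points seed the centroids (the question uses
    // random values), which keeps the sketch deterministic. Requires k <= points.length.
    KmeansSketch(float[][] points, int k) {
        this.points = points;
        this.centroids = new float[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[c].clone();
        }
        this.labels = new int[points.length];
    }

    // Runs a fixed number of assign/update rounds and returns a copy of the labels.
    public synchronized int[] cluster(int iterations) {
        for (int it = 0; it < iterations; it++) {
            assignStep();
            updateStep();
        }
        return labels.clone();
    }

    // Assign each point to its nearest centroid.
    private void assignStep() {
        for (int p = 0; p < points.length; p++) {
            int best = 0;
            float bestDist = Float.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                float d = squaredDist(points[p], centroids[c]);
                if (d < bestDist) {
                    bestDist = d;
                    best = c;
                }
            }
            labels[p] = best;
        }
    }

    // Move each centroid to the mean of its assigned points.
    private void updateStep() {
        int dim = points[0].length;
        float[][] sums = new float[centroids.length][dim];
        int[] counts = new int[centroids.length];
        for (int p = 0; p < points.length; p++) {
            counts[labels[p]]++;
            for (int i = 0; i < dim; i++) {
                sums[labels[p]][i] += points[p][i];
            }
        }
        for (int c = 0; c < centroids.length; c++) {
            if (counts[c] > 0) { // an empty cluster keeps its previous centroid
                for (int i = 0; i < dim; i++) {
                    centroids[c][i] = sums[c][i] / counts[c];
                }
            }
        }
    }

    // Squared distance has the same nearest centroid as Euclidean distance.
    private static float squaredDist(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }
}
```

For example, new KmeansSketch(new float[][] {{0f, 0f}, {10f, 10f}, {0f, 1f}, {10f, 11f}}, 2).cluster(5) returns the labels [0, 1, 0, 1]. Note that this sketch keeps an empty cluster's previous centroid, whereas the question's updateStep resets it to the zero vector because it allocates a fresh clusters array before averaging.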

countDownLatch

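You are using a static CountDownLatch, which you recreate in the assignStep and updateStep methods. This doesn't make much sense to me. Because it is static, it is shared among all instances of the ParallelKmeans class. Is that really the intended behavior? Since you reinitialize the static field in two public methods, it can change unexpectedly: for example, if two instances run at the same time, one can replace the latch while the other's workers are still running, so those workers count down the wrong latch and a step may return before its own work is done, or wait indefinitely. If you want to keep using a CountDownLatch, consider making it a local variable in each public method and passing it to the workers' constructors so they can access it.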
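One way to apply the latch suggestion, sketched with a hypothetical LocalLatchAssign class that works on float[][] arrays instead of the question's Node list: assignStep creates its own CountDownLatch, hands it to each AssignWorker through the constructor, and every worker counts down in a finally block. The last worker also takes the remainder when the number of points is not divisible by the thread count.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;

// Hypothetical sketch: the latch is a local variable of the step and reaches
// each worker through its constructor, so it is never shared between
// instances or overwritten while a step is in progress.
final class LocalLatchAssign {

    static final class AssignWorker implements Runnable {
        private final float[][] points;
        private final float[][] centroids; // read-only during the step
        private final int[] labels;
        private final int l, r;            // this worker owns indices [l, r)
        private final CountDownLatch done;

        AssignWorker(float[][] points, float[][] centroids, int[] labels,
                     int l, int r, CountDownLatch done) {
            this.points = points;
            this.centroids = centroids;
            this.labels = labels;
            this.l = l;
            this.r = r;
            this.done = done;
        }

        @Override
        public void run() {
            try {
                for (int p = l; p < r; p++) {
                    int best = 0;
                    float bestDist = Float.POSITIVE_INFINITY;
                    for (int c = 0; c < centroids.length; c++) {
                        float d = 0f; // squared Euclidean distance
                        for (int i = 0; i < points[p].length; i++) {
                            float diff = points[p][i] - centroids[c][i];
                            d += diff * diff;
                        }
                        if (d < bestDist) {
                            bestDist = d;
                            best = c;
                        }
                    }
                    labels[p] = best; // disjoint ranges, so no two workers write the same slot
                }
            } finally {
                // Count down even if the loop throws, so await() cannot hang.
                done.countDown();
            }
        }
    }

    // Each call creates its own latch and waits for exactly its own workers.
    static void assignStep(float[][] points, float[][] centroids, int[] labels,
                           int numThreads, ExecutorService pool)
            throws InterruptedException {
        CountDownLatch done = new CountDownLatch(numThreads);
        int chunk = points.length / numThreads;
        for (int j = 0; j < numThreads; j++) {
            int l = j * chunk;
            // The last worker also takes the remainder of the division.
            int r = (j == numThreads - 1) ? points.length : l + chunk;
            pool.execute(new AssignWorker(points, centroids, labels, l, r, done));
        }
        done.await();
    }
}
```

Because CountDownLatch guarantees that actions before countDown() happen-before a successful return from await() in another thread, the labels written by the workers are visible to the caller once assignStep returns. An alternative that avoids the latch entirely is ExecutorService.invokeAll, which blocks until every submitted Callable has completed.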

How many threads

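You create a thread pool sized from the number of processors on the machine (twice that number, in fact). However, your update and assign steps both use the member variable numThreads, which is hard-coded to 1. This disconnect is strange. Consider changing the startup code to calculate how many threads to use, and then use that one number both to construct the pool and to divide up the work.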
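A sketch of deciding the thread count once, using a hypothetical ThreadCountSketch class: the same numThreads sizes the pool and sets the number of chunks, and chunkBounds splits the items so that chunk sizes differ by at most one and none are dropped. Defaulting to availableProcessors() for CPU-bound work is an assumption worth benchmarking, not a rule, and the summing task only stands in for the assign and update workers.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: one thread count, used for both the pool size and
// the number of work chunks.
final class ThreadCountSketch {

    // Assumption: one thread per available processor as a starting point
    // for CPU-bound work.
    static int defaultThreadCount() {
        return Runtime.getRuntime().availableProcessors();
    }

    // [start, end) of chunk j when `size` items are split into `parts` chunks.
    // The first size % parts chunks get one extra item, so nothing is dropped.
    static int[] chunkBounds(int size, int parts, int j) {
        int base = size / parts;
        int extra = size % parts;
        int start = j * base + Math.min(j, extra);
        int end = start + base + (j < extra ? 1 : 0);
        return new int[] {start, end};
    }

    // Stand-in workload: a chunked parallel sum, one chunk per pool thread.
    static long parallelSum(int[] data, int numThreads)
            throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<Long>> partial = new ArrayList<>();
            for (int j = 0; j < numThreads; j++) {
                int[] b = chunkBounds(data.length, numThreads, j);
                Callable<Long> task = () -> {
                    long s = 0;
                    for (int i = b[0]; i < b[1]; i++) {
                        s += data[i];
                    }
                    return s;
                };
                partial.add(pool.submit(task));
            }
            long total = 0;
            for (Future<Long> f : partial) {
                total += f.get(); // waits for each chunk in turn
            }
            return total;
        } finally {
            pool.shutdown(); // runs even if a task fails or we are interrupted
        }
    }
}
```

For example, chunkBounds(10, 3, j) yields [0, 4), [4, 7) and [7, 10) for j = 0, 1, 2. Shutting the pool down in a finally block also covers the case where the question's cluster method would skip shutdown() because assignStep or updateStep threw InterruptedException.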

2 votes
Original page content provided by Code Review. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://codereview.stackexchange.com/questions/256483
