首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何利用Spliterator.trySplit在N核上进行并行计算?

如何利用Spliterator.trySplit在N核上进行并行计算?
EN

Stack Overflow用户
提问于 2022-08-19 09:46:54
回答 1查看 69关注 0票数 0

假设我有一个包含10000个元素的列表,并希望在6个核上处理它们。我不想使用现有的Stream,我希望自己从头开始使用它(为了学习)。Spliterator接口似乎很适合这个目的。但是,无论何时调用,它都会将集合分成两部分。我可以得到5000-5000分裂,然后再进行一次拆分,得到2500-2500-2500-2500,然后2500-2500-2500-1250-1250 -1250-1250-1250-1250,将我的初始收藏分成6个部分。它似乎不平衡,没有办法平衡它超过6个核心。

来自Java.Doc

API :一种理想的trySplit方法,有效地(不需要遍历)将其元素分割成两半,允许平衡的并行计算。

然而,Stream .parallel()似乎以某种方式解决了这个问题。我试着阅读资料来源,但仍然无法理解其中的要点。也许有人能向我解释一下高层次的方法。

EN

回答 1

Stack Overflow用户

发布于 2022-08-19 12:21:19

如果您想要重新实现并行流提供的功能,那么除了将任务划分为子任务之外,还需要注意这些堆栈的执行并加入结果的结果。

在这个框架下,并行流使用Fork/Join框架

只需要Spliterator将数据拆分为子任务。但是,与合并中间结果的任务和正确性一起分配工作线程的顺序是通过Fork/Join实现的。

如果您想自己做,可以扩展抽象类RecursiveTask并重写它的方法compute()。这将是您的任务的“容器”(还有RecursiveAction类,它用于执行一个操作,不产生值,但问题是计算,我们需要并获得一个结果- RecursiveTask更适合用于这个目的)。

为了使它更灵活,您可以添加一个类型为FunctionPredicate的属性,该属性将在实例化它时提供,但它与并行流的功能和灵活性不匹配。

在实现compute()时,您需要提供拆分任务的逻辑。您可以为此使用Spliterator,如果源允许访问随机元素(如列表或数组),则可以手动完成。

如果您选择使用Spliterator来划分数据集,则可以使用trySplit()方法,该方法返回Spliterator,即不能进一步拆分null数据。因此,如果trySplit()生成null,则需要处理当前分配器的其余元素。否则,您需要基于trySplit()返回的新分配器创建一个新任务,并在其上应用fork(),然后将通过处理当前分配器中的其余元素而产生的结果与应用于新任务的join()方法返回的值合并。

但是注意到Spliterator在处理数据时将面临一个问题。与Iterator相反,这个接口不声明允许直接访问元素的方法,这不是它的目的。

Spliterator只提供了几种允许使用其元素拨号的方法:forEachRemaining()tryAdvance()。第一个是void,第二个返回boolean值,两者都期望有一个Consumer作为参数。这意味着您将被迫使用有状态函数(这是,而不是,一个很好的实践),以便从compute()返回一个值。

想要在6核上处理它们

通过使用ForkJoinPool的参数化构造之一,我们可以指定所需的并行级别(同时占用的最大线程数)。或者,我们也可以使用来自Executors类的Java8工厂方法Executors

基于分配器的并行处理

RecursiveTask实现:

代码语言:javascript
复制
public static class Task<T> extends RecursiveTask<T> {
    
    private Spliterator<T> spliterator;
    private BinaryOperator<T> accumulator;
    private Predicate<T> predicate = t -> true;
    private T identity;
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity) {
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
    }
    
    public Task(Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity, Predicate<T> predicate) {
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
        this.predicate = predicate;
    }
    
    @Override
    protected T compute() {
        Spliterator<T> newSpliterator = spliterator.trySplit();            
        AtomicReference<T> result = new AtomicReference<>(identity);
        
        if (newSpliterator != null) {
            Task<T> newTask = new Task<>(newSpliterator, accumulator, identity, predicate);
            newTask.fork();
            forEachRemaining(spliterator, result);
            return accumulator.apply(result.get(), newTask.join());
        }
        forEachRemaining(spliterator, result);
        return result.get();
    }
    
    private void forEachRemaining(Spliterator<T> spliterator, AtomicReference<T> result) {
        spliterator.forEachRemaining(t -> {
            if (predicate.test(t)) {
                result.set(accumulator.apply(result.get(), t));
            }
        });
    }
}

main() --让我们生成给定集合中的所有数字,并分别将所有奇数和偶数元素相加。

代码语言:javascript
复制
public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(6); // required parallelism 6
    
    Set<Integer> test = Set.of(1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 == 0)));
    System.out.println(pool.invoke(new Task<>(test.spliterator(), Integer::sum, 0, t -> t % 2 != 0)));
}

输出:

代码语言:javascript
复制
45   // total of: 1, 2, 3, 4, 5, 6, 7, 8, 9
20   // total of: 2, 4, 6, 8
25   // total of: 1, 3, 5, 7, 9

复迭代器+迭代器

我们可以通过引入Iterator作为附加属性来改进上述方法。

这将使Spliterator只负责拆分任务,同时Iterator将用于处理数据。它还允许避免使用像前面示例中那样的有状态函数。

RecursiveTask实现:

代码语言:javascript
复制
public static class Task<T> extends RecursiveTask<T> {
    
    private Iterator<T> iterator;
    private Spliterator<T> spliterator;
    private BinaryOperator<T> accumulator;
    private Predicate<T> predicate = t -> true;
    private T identity;
    
    public Task(Iterator<T> iterator, Spliterator<T> spliterator, BinaryOperator<T> accumulator, T identity) {
        this.iterator = iterator;
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
    }
    
    public Task(Iterator<T> iterator, Spliterator<T> spliterator,
                BinaryOperator<T> accumulator, T identity, Predicate<T> predicate) {
        
        this.iterator = iterator;
        this.spliterator = spliterator;
        this.accumulator = accumulator;
        this.identity = identity;
        this.predicate = predicate;
    }
    
    @Override
    protected T compute() {
        Spliterator<T> newSpliterator = spliterator.trySplit();
        
        if (newSpliterator != null) {
            Task<T> newTask = new Task<>(iterator, newSpliterator, accumulator, identity, predicate);
            newTask.fork();
            T result = forEachRemaining(iterator);
            return accumulator.apply(result, newTask.join());
        }
        return forEachRemaining(iterator);
    }
    
    private T forEachRemaining(Iterator<T> iterator) {
        T result = identity;
        while (iterator.hasNext()) {
            T next = iterator.next();
            if (predicate.test(next)) {
                result = accumulator.apply(result, next);
            }
        }
        return result;
    }
}

main() -相同的样本数据。

代码语言:javascript
复制
public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(6); // required parallelism 6
    
    Set<Integer> test = Set.of(1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    System.out.println(pool.invoke(new Task<>(test.iterator(), test.spliterator(), Integer::sum, 0)));
    System.out.println(pool.invoke(new Task<>(test.iterator(), test.spliterator(), Integer::sum, 0, t -> t % 2 == 0)));
    System.out.println(pool.invoke(new Task<>(test.iterator(), test.spliterator(), Integer::sum, 0, t -> t % 2 != 0)));
}

输出:

代码语言:javascript
复制
45   // total of: 1, 2, 3, 4, 5, 6, 7, 8, 9
20   // total of: 2, 4, 6, 8
25   // total of: 1, 3, 5, 7, 9
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73414878

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档