Fork/Join Parallelism与ForkJoinPool

Fork/Join 算法简介

Fork/Join并行计算,是一种基于多核处理器、以分治的算法思想为基础、以尽可能利用硬件计算资源为目标的并行编程方式。
它的思想其实很简单:一个大型的任务可以被划分成小型任务,而这些小型任务的计算结果可以被合并成整体的结果。只要这些小任务是独立的,他们就可以被并行计算。

这十分类似于经典算法的分治算法,实际上For
k/Join框架就是并行的分治算法,典型的伪代码如下:

Result solve(Problem problem) {
    if (problem is small) 
        directly solve
    else {
        split problem into independent
        fork new subtasks to solve each part
        join all subtasks
        compose result from subresults
    }
}

其中,fork操作会开启一个新的并行的子任务。
Forkjoin操作会使当前任务阻塞,直到子任务完成。ForkFork/Join算法和其他的分治算法一样总是递归的,大任务会重复地被分割成小任务,而小任务也会被持续地分割成小小任务,直到一个任务足够的小,然后使用常规算法来计算。

Fork/Join 框架设计

Fork/Join框架建立在传统的线程池之上,其基本设计思想是:

  1. 任务(Task)不是线程(Thread),任务比线程更小更轻量;
  2. 每一个worker线程都维护着一个任务队列。默认状况下,worker线程的个数等于CPU的核心数;
  3. 任务队列是基于双端队列(deque)的,支持从顶部插入任务(push),从顶部弹出任务(pop)和从尾部取出任务(take);
  4. 一个任务只能将分割后的子任务插入到自己线程所维护的队列中,一个worker线程只能从自己队列的顶部以“后进先出”的方式取出一个任务;
  5. 如果一个worker线程的队列里没有可执行的任务了,那么它会随机地从其他线程的队列中以“先进先出”的方式从尾部取走(“窃取”, “steal”)一个并执行。
    Work-Stealing

Java中的ForkJoinPool

在Java 1.7中,ForkJoinPool被引入,常被计算密集型的并行程序当作线程池。
该线程池一般使用其默认配置的commonPool即可满足大部分的计算需求。
Java 1.8中的并行流(parallel stream)就是基于ForkJoinPool的commonPool实现的。
如果需要自定义线程池中线程的数量,可以使用其构造器构建线程池,只需要1个参数——并行级别(parallelism),即线程的数量。

举个例子,并行找出一个Integer型数组中最大的元素:

MainApp.java

import java.util.concurrent.ForkJoinPool;
import java.util.stream.Stream;

public class MainApp {
    public static void main(String[] args) {
        Integer[] data = Stream.iterate(1, i->i+1)
                .limit(1000000).toArray(Integer[]::new);

        ForkJoinPool pool = ForkJoinPool.commonPool();
        FindMaxTask rootTask = new FindMaxTask(data, 0, data.length);

        Integer result = pool.invoke(rootTask);

        System.out.println(result);
    }
}

FindMaxTask.java

import java.util.concurrent.RecursiveTask;

public class FindMaxTask extends RecursiveTask<Integer> {

    static final int THRESHOLD = 4;

    private Integer[] arr;
    private int start, end;

    public FindMaxTask(Integer[] arr, int start, int end) {
        this.arr = arr;
        this.start = start;
        this.end = end;
    }


    @Override
    protected Integer compute() {
        if (end - start <= THRESHOLD) {
            return computeDirectly();
        } else {
            int mid = (start + end) / 2;
            FindMaxTask left = new FindMaxTask(arr, start, mid);
            FindMaxTask right = new FindMaxTask(arr, mid, end);

            invokeAll(left, right);

            return Math.max(left.join(), right.join());
        }
    }

    private Integer computeDirectly() {
        int max = Integer.MIN_VALUE;
        for (int i = start; i < end; i++) {
            if (arr[i] > max) {
                max = arr[i];
            }
        }
        return max;
    }
}