一、ForkJoinPool 是什么?
ForkJoinPool 是 Java 7 引入的一个线程池实现,位于 java.util.concurrent 包下,专为分治算法(Divide-and-Conquer)设计,特别适合处理可分解为多个子任务的复杂任务。它的核心思想是将大任务 “拆分”(fork)为小任务,并行执行后再 “合并”(join)结果,同时通过工作窃取(Work Stealing)机制提高线程利用率。
二、核心特点
- 工作窃取机制
每个线程都有自己的任务队列(双端队列),当线程完成自身任务后,会主动从其他线程的队列中 "窃取"任务执行,减少线程空闲时间。 (简单理解:有两个线程做两个任务,另一个线程做完了,会帮另一个线程分担任务,从而提高效率) - 分治任务模型
适合处理可拆分的任务:通过 fork() 拆分任务,通过 join()等待子任务完成并合并结果。 - 并行度控制
默认并行度为 CPU 核心数(Runtime.getRuntime().availableProcessors()),也可手动指定。 - 轻量级线程管理
相比传统线程池,ForkJoinPool 对线程的调度更高效,尤其适合计算密集型任务。
三、核心类与接口
- ForkJoinPool:线程池核心类,负责管理线程和任务调度。
- ForkJoinTask< v >:任务抽象类,是所有可在 ForkJoinPool 中执行的任务的父类,主要子类有:
RecursiveAction:无返回值的任务(void)。
RecursiveTask:有返回值的任务(泛型 V)。
四、ForkJoinPool 方法介绍
1、构造方法
用于创建 ForkJoinPool 实例,控制并行度、线程工厂和异常处理器等:
-
ForkJoinPool():
默认构造器:并行度为 CPU 核心数(Runtime.getRuntime().availableProcessors()),使用默认线程工厂和异常处理器。 -
ForkJoinPool(int parallelism):
指定并行度(期望的线程数),其他参数使用默认值。
parallelism:并行级别,通常设为 CPU 核心数(非强制,实际线程数可能动态调整)。 -
ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode)
全参数构造器:
3.1、factory:创建工作线程的工厂(默认 DefaultForkJoinWorkerThreadFactory)。
3.2、handler:线程未捕获异常的处理器(默认 null)。
3.3、asyncMode:任务队列是否为异步模式(false 为 LIFO 模式,true 为 FIFO 模式,影响任务调度顺序)。
2、提交任务的核心方法
用于向线程池提交 ForkJoinTask 任务(或其他类型任务):
- T invoke(ForkJoinTask task)
提交任务并阻塞等待其完成,返回任务结果。
示例:
ForkJoinPool pool = new ForkJoinPool();
Integer result = pool.invoke(new SumTask(array, 0, array.length)); // 阻塞至任务完成
- void execute(ForkJoinTask<?> task)
提交任务异步执行(无返回值,不阻塞),任务结果需通过 task.join() 获取。
示例:
SumTask task = new SumTask(array, 0, array.length);
pool.execute(task);
// 后续通过 task.join() 获取结果(会阻塞)
-
ForkJoinTask submit(ForkJoinTask task)
提交任务异步执行,返回 ForkJoinTask 对象,可通过其 get() 或 join() 方法获取结果。
与 execute() 类似,但返回任务本身,更灵活。 -
List<Future> invokeAll(Collection<? extends Callable> tasks)
批量提交 Callable 任务,阻塞等待所有任务完成,返回结果列表(类似普通线程池的 invokeAll)。
3、任务执行相关方法
-
boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException
阻塞等待线程池终止(所有任务完成且关闭),超时后返回 false。
需配合 shutdown() 使用。 -
ForkJoinTask<?> pollSubmission()
从提交队列中获取并移除一个未执行的任务(主要用于内部调度,一般不直接调用)。 -
int getQueuedSubmissionCount()
返回等待执行的提交任务数量。
4、线程池状态管理
-
void shutdown()
平缓关闭线程池:不再接受新任务,等待已提交任务完成后终止线程。 -
List shutdownNow()
立即关闭线程池:尝试中断正在执行的任务,返回未执行的任务列表。 -
boolean isShutdown()
判断线程池是否已关闭(调用过 shutdown() 或 shutdownNow())。 -
boolean isTerminated()
判断线程池是否已终止(所有任务完成且线程已退出)。
static ForkJoinPool commonPool()
返回公共线程池(静态单例),适用于轻量级并行任务,无需手动创建线程池。
示例:
ForkJoinPool.commonPool().invoke(new SumTask(array, 0, array.length));
5、获取线程池信息
- int getParallelism()
返回创建时设置的并行度。 - int getPoolSize()
返回当前线程池中的工作线程数量。 - int getActiveThreadCount()
返回正在执行任务的活跃线程数量。 - long getStealCount()
返回工作窃取的总次数(反映线程利用效率,次数越高说明负载越均衡)。 - int getQueuedTaskCount()
返回所有工作线程的任务队列中待执行的任务总数。
6、异常处理
ForkJoinPool 中的任务异常会被包装在 ExecutionException 中,可通过以下方式捕获:
- 对于 RecursiveTask:调用 join() 或 get() 时,异常会被抛出(join() 抛出未检查异常,get() 抛出受检异常)。
- 对于 RecursiveAction:需重写 completedAbnormally() 方法或通过 getException() 获取异常。
五、使用步骤
- 定义任务:继承 RecursiveTask(有返回值)或 RecursiveAction(无返回值),重写 compute() 方法。
- 在 compute() 中实现任务拆分逻辑:
若任务小到无需拆分,则直接执行并返回结果。
若需要拆分,则创建子任务,通过 fork() 提交子任务,通过 join() 获取结果并合并。 - 创建 ForkJoinPool 实例,提交任务并获取结果。
1.计算数组总和(RecursiveTask)
下面用 ForkJoinPool 并行计算数组总和,展示分治思想:
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
// 有返回值的任务:计算数组指定范围的和
class SumTask extends RecursiveTask<Integer> {
// 任务拆分阈值(小于该值则不拆分)
private static final int THRESHOLD = 1000;
private int[] array;
private int start;
private int end;
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
// 若任务足够小,直接计算
if (end - start <= THRESHOLD) {
int sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
} else {
// 拆分任务:分为左右两个子任务
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid); // 左半部分任务
SumTask rightTask = new SumTask(array, mid, end); // 右半部分任务
// 执行子任务(fork() 会将任务提交到线程池,可能并行执行(非阻塞)
leftTask.fork();
// 直接执行右任务(当前线程可参与执行,减少线程创建)
int rightSum = rightTask.compute();
// 等待左任务完成并获取结果
int leftSum = leftTask.join();
// 合并结果
return leftSum + rightSum;
}
}
}
public class ForkJoinPoolExample {
public static void main(String[] args) {
// 创建测试数组
int[] array = new int[10_000_000];
for (int i = 0; i < array.length; i++) {
array[i] = i + 1; // 1~10,000,000
}
// 创建ForkJoinPool(默认并行度为CPU核心数)
try (ForkJoinPool pool = new ForkJoinPool()) {
// 提交任务
SumTask task = new SumTask(array, 0, array.length);
Integer result = pool.invoke(task); // 阻塞等待结果
System.out.println("数组总和:" + result); // 预期结果:50000005000000
}
}
}
2.示例:无返回值任务(RecursiveAction)
如果任务无需返回值,可继承 RecursiveAction,重写 compute() 方法:
代码如下(示例):
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
// 无返回值的任务:打印数组指定范围元素
class PrintTask extends RecursiveAction {
private static final int THRESHOLD = 10;
private int[] array;
private int start;
private int end;
public PrintTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected void compute() {
if (end - start <= THRESHOLD) {
// 直接打印
for (int i = start; i < end; i++) {
System.out.print(array[i] + " ");
}
} else {
// 拆分任务
int mid = (start + end) / 2;
PrintTask leftTask = new PrintTask(array, start, mid);
PrintTask rightTask = new PrintTask(array, mid, end);
leftTask.fork();
rightTask.fork(); // 也可直接执行rightTask.compute()
// 等待子任务完成(无返回值,仅等待)
leftTask.join();
rightTask.join();
}
}
}
public class RecursiveActionExample {
public static void main(String[] args) {
int[] array = new int[30];
for (int i = 0; i < array.length; i++) {
array[i] = i + 1;
}
try (ForkJoinPool pool = new ForkJoinPool()) {
pool.invoke(new PrintTask(array, 0, array.length));
}
}
}
3.示例:分批量保存数据(RecursiveAction)
@Service
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {
@Resource
private UserService userService;
/**
* 批量新增用户数据
*/
private void createBomPackage(List<User> users) {
AtomicReference<Exception> exceptionRef = new AtomicReference<>();
if (CollUtil.isNotEmpty(users)) {
EXECUTOR.execute(() -> {
ForkJoinPool forkJoinPool = new ForkJoinPool();
try {
forkJoinPool.submit(new SumRecursiveTask(users, userService));
} catch (Exception e) {
log.error("新增用户数据失败", e);
exceptionRef.set(e);
} finally {
if (!forkJoinPool.isShutdown()) {
forkJoinPool.shutdown();
}
}
});
if (exceptionRef.get() != null) {
exceptionRef.get().printStackTrace();
throw new BaseException("异步任务执行新增用户异常");
}
}
}
/**
* 分批创建类
*/
private static class SumRecursiveTask extends RecursiveAction {
private static final int MAX_STRIDE = 200; //可以处理任务
@Resource
private final UserService userService;
private final List<User> users;
public SumRecursiveTask(List<User> users, UserService userService) {
this.userService= userService; //保存的service接口
this.users= users; //你要保存的List集合
}
@Override
protected void compute() {
int size = users.size(); //本次处理大小
if (size <= MAX_STRIDE) { //如果小于设定阈值,直接保存
userService.saveBatch(temp);
} else { // 如果大于,将任务拆分,分而治之。
int middle = (size) / 2;
SumRecursiveTask left = new SumRecursiveTask(temp.subList(0, middle), userService);
SumRecursiveTask right = new SumRecursiveTask(temp.subList(middle, size), userService);
left.fork(); // 提交左任务(非阻塞)
right.fork(); // 也可直接执行rightTask.compute()
}
}
}
六、工作原理
ForkJoinPool 的核心工作原理可以概括为 “分治(Fork)+ 合并(Join)+ 工作窃取(Work Stealing)”。下面通过一个具体的例子(计算数组总和)来详细说明其运作过程。
1、场景:计算大数组的总和
假设我们有一个包含 8 个元素的数组 [1, 2, 3, 4, 5, 6, 7, 8],需要计算所有元素的总和。如果用单线程计算,直接遍历累加即可;但用 ForkJoinPool 时,会通过 “分治” 将大任务拆分成小任务,并行执行后再合并结果,同时通过 “工作窃取” 提高线程利用率。
2、步骤 1:定义分治任务(RecursiveTask)
首先需要创建一个继承 RecursiveTask(有返回值的任务)的子类,重写 compute() 方法,实现 “拆分任务” 和 “合并结果” 的逻辑。
import java.util.concurrent.RecursiveTask;
// 计算数组总和的任务(有返回值,用 RecursiveTask)
class SumTask extends RecursiveTask<Integer> {
// 任务拆分的阈值:当数组长度 <= 2 时,直接计算(不再拆分)
private static final int THRESHOLD = 2;
private int[] array; // 目标数组
private int start; // 起始索引
private int end; // 结束索引(不包含)
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
int length = end - start;
// 若任务足够小(小于阈值),直接计算结果
if (length <= THRESHOLD) {
int sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
System.out.println("直接计算:" + start + "~" + end + ",结果=" + sum);
return sum;
}
// 否则,拆分任务(Fork)
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid); // 左半部分任务
SumTask rightTask = new SumTask(array, mid, end); // 右半部分任务
// 执行子任务(fork() 会将任务提交到线程池,可能并行执行)
leftTask.fork(); // 拆分左任务
rightTask.fork(); // 拆分右任务
// 合并子任务结果(join() 会等待子任务完成并获取结果)
int leftResult = leftTask.join();
int rightResult = rightTask.join();
int total = leftResult + rightResult;
System.out.println("合并:" + start + "~" + end + ",左=" + leftResult + ",右=" + rightResult + ",总=" + total);
return total;
}
}
3、步骤 2:用 ForkJoinPool 执行任务
创建 ForkJoinPool 实例,提交任务并获取结果:
import java.util.concurrent.ForkJoinPool;
public class ForkJoinDemo {
public static void main(String[] args) {
int[] array = {1, 2, 3, 4, 5, 6, 7, 8};
// 创建 ForkJoinPool(默认并行度为 CPU 核心数,这里假设为 2 个核心)
try (ForkJoinPool pool = new ForkJoinPool()) {
// 提交根任务(计算整个数组的和)
SumTask rootTask = new SumTask(array, 0, array.length);
int result = pool.invoke(rootTask); // 阻塞等待结果
System.out.println("最终总和:" + result); // 输出 36
}
}
}
4、步骤 3:详解工作原理(分治 + 工作窃取)
假设 ForkJoinPool 的并行度为 2(即初始有 2 个工作线程 Thread-1 和 Thread-2),整个执行过程如下:
- 分治(Fork):拆分任务
-
根任务:SumTask(0~8)(计算整个数组的和)。由于数组长度为 8(> 阈值 2),根任务会拆分成两个子任务:
- 左任务:SumTask(0~4)(计算前 4 个元素)。
- 右任务:SumTask(4~8)(计算后 4 个元素)。
-
根任务调用 leftTask.fork() 和 rightTask.fork(),将两个子任务提交到线程池,由 Thread-1 和 Thread-2 分别执行。
- 子任务继续拆分
- Thread-1 执行 SumTask(0~4):
数组长度为 4(> 阈值 2),继续拆分:- 左子任务:SumTask(0~2)(元素 [1,2])。
- 右子任务:SumTask(2~4)(元素 [3,4])。
- Thread-1 调用 fork() 提交这两个子任务,然后等待它们的结果(join())。
- Thread-2 执行 SumTask(4~8):
数组长度为 4(> 阈值 2),继续拆分:- 左子任务:SumTask(4~6)(元素 [5,6])。
- 右子任务:SumTask(6~8)(元素 [7,8])。
- Thread-2 调用 fork() 提交这两个子任务,然后等待结果(join())。
- 执行最小任务(达到阈值)
当子任务的数组长度 ≤ 阈值(2)时,不再拆分,直接计算结果:
- SumTask(0~2) 计算 1+2=3(假设由 Thread-1 执行)。
- SumTask(2~4) 计算 3+4=7(假设 Thread-1 执行完左子任务后,继续执行右子任务)。
- SumTask(4~6) 计算 5+6=11(假设由 Thread-2 执行)。
- SumTask(6~8) 计算 7+8=15(假设 Thread-2 执行完左子任务后,继续执行右子任务)。
- 合并结果(Join)
子任务完成后,上层任务会合并结果:
- SumTask(0~4) 合并 3 + 7 = 10。
- SumTask(4~8) 合并 11 + 15 = 26。
- 根任务 SumTask(0~8) 合并 10 + 26 = 36(最终结果)。
- 工作窃取(Work Stealing):提升效率的关键
假设 Thread-1 先完成了自己的所有子任务(0~2 和 2~4),而 Thread-2 还在处理 6~8。此时 Thread-1 的任务队列已空,它会主动 “窃取”Thread-2 队列中未执行的任务(比如 6~8),帮助 Thread-2 执行,避免线程空闲。
- 窃取规则:线程的任务队列是 “双端队列”(Deque)。自己的任务从队列头部(LIFO)取,窃取的任务从其他队列的尾部(FIFO)取,减少竞争。
- 效果:避免某一线程任务过多而其他线程空闲,充分利用多核 CPU 资源。
5、核心原理总结
- 分治(Fork):将大任务递归拆分成小任务,直到任务足够小(达到阈值),适合并行处理。
- 合并(Join):等待所有子任务完成后,汇总结果得到最终答案。
- 工作窃取(Work Stealing):空闲线程主动窃取其他线程的任务,平衡负载,最大化利用 CPU 核心。
七、注意事项
- 任务拆分粒度:拆分阈值(THRESHOLD)需合理设置,过细会增加任务调度开销,过粗则无法充分利用并行性。
- 异常处理:任务中抛出的异常会被包装在 ExecutionException 中,需通过 get() 或 join() 捕获。
- 资源管理:ForkJoinPool 实现了 AutoCloseable 接口,建议用 try-with-resources 自动关闭。
- 适用场景:适合计算密集型任务(如大数据量计算、排序等),不适合 I/O 密集型任务(线程等待时间长,工作窃取效率低)。
- 与 Executors.newWorkStealingPool() 的关系:newWorkStealingPool() 本质上是 ForkJoinPool 的封装(Java 8+),默认并行度为 CPU 核心数。
ForkJoinPool 是 Java 并行计算的重要工具,尤其在处理大规模数据并行任务时,能有效利用多核 CPU 提升效率。