Java中的ForkJoinPool使用详解(带案例)


在这里插入图片描述

一、ForkJoinPool 是什么?

ForkJoinPool 是 Java 7 引入的一个线程池实现,位于 java.util.concurrent 包下,专为分治算法(Divide-and-Conquer)设计,特别适合处理可分解为多个子任务的复杂任务。它的核心思想是将大任务 “拆分”(fork)为小任务,并行执行后再 “合并”(join)结果,同时通过工作窃取(Work Stealing)机制提高线程利用率。

二、核心特点

  1. 工作窃取机制
    每个线程都有自己的任务队列(双端队列),当线程完成自身任务后,会主动从其他线程的队列中 "窃取"任务执行,减少线程空闲时间。 (简单理解:有两个线程做两个任务,另一个线程做完了,会帮另一个线程分担任务,从而提高效率)
  2. 分治任务模型
    适合处理可拆分的任务:通过 fork() 拆分任务,通过 join()等待子任务完成并合并结果。
  3. 并行度控制
    默认并行度为 CPU 核心数(Runtime.getRuntime().availableProcessors()),也可手动指定。
  4. 轻量级线程管理
    相比传统线程池,ForkJoinPool 对线程的调度更高效,尤其适合计算密集型任务。

三、核心类与接口

  1. ForkJoinPool:线程池核心类,负责管理线程和任务调度。
  2. ForkJoinTask< v >:任务抽象类,是所有可在 ForkJoinPool 中执行的任务的父类,主要子类有:
    RecursiveAction:无返回值的任务(void)。
    RecursiveTask:有返回值的任务(泛型 V)。

四、ForkJoinPool 方法介绍

1、构造方法

用于创建 ForkJoinPool 实例,控制并行度、线程工厂和异常处理器等:

  1. ForkJoinPool()
    默认构造器:并行度为 CPU 核心数(Runtime.getRuntime().availableProcessors()),使用默认线程工厂和异常处理器。

  2. ForkJoinPool(int parallelism)
    指定并行度(期望的线程数),其他参数使用默认值。
    parallelism:并行级别,通常设为 CPU 核心数(非强制,实际线程数可能动态调整)。

  3. ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode)
    全参数构造器:
    3.1、factory:创建工作线程的工厂(默认 DefaultForkJoinWorkerThreadFactory)。
    3.2、handler:线程未捕获异常的处理器(默认 null)。
    3.3、asyncMode:任务队列是否为异步模式(false 为 LIFO 模式,true 为 FIFO 模式,影响任务调度顺序)。

2、提交任务的核心方法

用于向线程池提交 ForkJoinTask 任务(或其他类型任务):

  1. T invoke(ForkJoinTask task)
    提交任务并阻塞等待其完成,返回任务结果。
    示例:
ForkJoinPool pool = new ForkJoinPool();
Integer result = pool.invoke(new SumTask(array, 0, array.length)); // 阻塞至任务完成
  1. void execute(ForkJoinTask<?> task)
    提交任务异步执行(无返回值,不阻塞),任务结果需通过 task.join() 获取。
    示例:
SumTask task = new SumTask(array, 0, array.length);
pool.execute(task);
// 后续通过 task.join() 获取结果(会阻塞)
  1. ForkJoinTask submit(ForkJoinTask task)
    提交任务异步执行,返回 ForkJoinTask 对象,可通过其 get() 或 join() 方法获取结果。
    与 execute() 类似,但返回任务本身,更灵活。

  2. List<Future> invokeAll(Collection<? extends Callable> tasks)
    批量提交 Callable 任务,阻塞等待所有任务完成,返回结果列表(类似普通线程池的 invokeAll)。

3、任务执行相关方法

  1. boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException
    阻塞等待线程池终止(所有任务完成且关闭),超时后返回 false。
    需配合 shutdown() 使用。

  2. ForkJoinTask<?> pollSubmission()
    从提交队列中获取并移除一个未执行的任务(主要用于内部调度,一般不直接调用)。

  3. int getQueuedSubmissionCount()
    返回等待执行的提交任务数量。

4、线程池状态管理

  1. void shutdown()
    平缓关闭线程池:不再接受新任务,等待已提交任务完成后终止线程。

  2. List shutdownNow()
    立即关闭线程池:尝试中断正在执行的任务,返回未执行的任务列表。

  3. boolean isShutdown()
    判断线程池是否已关闭(调用过 shutdown() 或 shutdownNow())。

  4. boolean isTerminated()
    判断线程池是否已终止(所有任务完成且线程已退出)。
    static ForkJoinPool commonPool()
    返回公共线程池(静态单例),适用于轻量级并行任务,无需手动创建线程池。
    示例:

ForkJoinPool.commonPool().invoke(new SumTask(array, 0, array.length));

5、获取线程池信息

  1. int getParallelism()
    返回创建时设置的并行度。
  2. int getPoolSize()
    返回当前线程池中的工作线程数量。
  3. int getActiveThreadCount()
    返回正在执行任务的活跃线程数量。
  4. long getStealCount()
    返回工作窃取的总次数(反映线程利用效率,次数越高说明负载越均衡)。
  5. int getQueuedTaskCount()
    返回所有工作线程的任务队列中待执行的任务总数。

6、异常处理

ForkJoinPool 中的任务异常会被包装在 ExecutionException 中,可通过以下方式捕获:

  • 对于 RecursiveTask:调用 join() 或 get() 时,异常会被抛出(join() 抛出未检查异常,get() 抛出受检异常)。
  • 对于 RecursiveAction:需重写 completedAbnormally() 方法或通过 getException() 获取异常。

五、使用步骤

  1. 定义任务:继承 RecursiveTask(有返回值)或 RecursiveAction(无返回值),重写 compute() 方法。
  2. 在 compute() 中实现任务拆分逻辑
    若任务小到无需拆分,则直接执行并返回结果。
    若需要拆分,则创建子任务,通过 fork() 提交子任务,通过 join() 获取结果并合并。
  3. 创建 ForkJoinPool 实例,提交任务并获取结果

1.计算数组总和(RecursiveTask)

下面用 ForkJoinPool 并行计算数组总和,展示分治思想:

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// 有返回值的任务:计算数组指定范围的和
class SumTask extends RecursiveTask<Integer> {
    // 任务拆分阈值(小于该值则不拆分)
    private static final int THRESHOLD = 1000;
    private int[] array;
    private int start;
    private int end;

    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        // 若任务足够小,直接计算
        if (end - start <= THRESHOLD) {
            int sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        } else {
            // 拆分任务:分为左右两个子任务
            int mid = (start + end) / 2;
            SumTask leftTask = new SumTask(array, start, mid); // 左半部分任务
            SumTask rightTask = new SumTask(array, mid, end); // 右半部分任务

            //  执行子任务(fork() 会将任务提交到线程池,可能并行执行(非阻塞)
            leftTask.fork();
            // 直接执行右任务(当前线程可参与执行,减少线程创建)
            int rightSum = rightTask.compute();
            // 等待左任务完成并获取结果
            int leftSum = leftTask.join();

            // 合并结果
            return leftSum + rightSum;
        }
    }
}

public class ForkJoinPoolExample {
    public static void main(String[] args) {
        // 创建测试数组
        int[] array = new int[10_000_000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1; // 1~10,000,000
        }

        // 创建ForkJoinPool(默认并行度为CPU核心数)
        try (ForkJoinPool pool = new ForkJoinPool()) {
            // 提交任务
            SumTask task = new SumTask(array, 0, array.length);
            Integer result = pool.invoke(task); // 阻塞等待结果

            System.out.println("数组总和:" + result); // 预期结果:50000005000000
        }
    }
}

2.示例:无返回值任务(RecursiveAction)

如果任务无需返回值,可继承 RecursiveAction,重写 compute() 方法:
代码如下(示例):

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// 无返回值的任务:打印数组指定范围元素
class PrintTask extends RecursiveAction {
    private static final int THRESHOLD = 10;
    private int[] array;
    private int start;
    private int end;

    public PrintTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            // 直接打印
            for (int i = start; i < end; i++) {
                System.out.print(array[i] + " ");
            }
        } else {
            // 拆分任务
            int mid = (start + end) / 2;
            PrintTask leftTask = new PrintTask(array, start, mid);
            PrintTask rightTask = new PrintTask(array, mid, end);

            leftTask.fork();
            rightTask.fork(); // 也可直接执行rightTask.compute()

            // 等待子任务完成(无返回值,仅等待)
            leftTask.join();
            rightTask.join();
        }
    }
}

public class RecursiveActionExample {
    public static void main(String[] args) {
        int[] array = new int[30];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }

        try (ForkJoinPool pool = new ForkJoinPool()) {
            pool.invoke(new PrintTask(array, 0, array.length));
        }
    }
}

3.示例:分批量保存数据(RecursiveAction)

@Service
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {

    @Resource
    private UserService userService;

    /**
     * 批量新增用户数据
     */
    private void createBomPackage(List<User> users) {
        AtomicReference<Exception> exceptionRef = new AtomicReference<>();
        if (CollUtil.isNotEmpty(users)) {
            EXECUTOR.execute(() -> {
                ForkJoinPool forkJoinPool = new ForkJoinPool();
                try {
                    forkJoinPool.submit(new SumRecursiveTask(users, userService));
                } catch (Exception e) {
                    log.error("新增用户数据失败", e);
                    exceptionRef.set(e);
                } finally {
                    if (!forkJoinPool.isShutdown()) {
                        forkJoinPool.shutdown();
                    }
                }
            });
            if (exceptionRef.get() != null) {
                exceptionRef.get().printStackTrace();
                throw new BaseException("异步任务执行新增用户异常");
            }
        }
    }

    /**
     * 分批创建类
     */
    private static class SumRecursiveTask extends RecursiveAction {
        private static final int MAX_STRIDE = 200;      //可以处理任务
        @Resource
        private final UserService userService;
        private final List<User> users;
        
        public SumRecursiveTask(List<User> users, UserService userService) {
            this.userService= userService; //保存的service接口
            this.users= users; //你要保存的List集合
        }
        
        @Override
        protected void compute() {
        
            int size = users.size(); //本次处理大小
            
            if (size <= MAX_STRIDE) { //如果小于设定阈值,直接保存
                userService.saveBatch(temp);
            } else {        // 如果大于,将任务拆分,分而治之。
                int middle = (size) / 2;
                
                SumRecursiveTask left = new SumRecursiveTask(temp.subList(0, middle), userService);
                SumRecursiveTask right = new SumRecursiveTask(temp.subList(middle, size), userService);
               
               left.fork();   // 提交左任务(非阻塞)
                right.fork(); // 也可直接执行rightTask.compute()
            }
        }
    }

六、工作原理

ForkJoinPool 的核心工作原理可以概括为 “分治(Fork)+ 合并(Join)+ 工作窃取(Work Stealing)”。下面通过一个具体的例子(计算数组总和)来详细说明其运作过程。

1、场景:计算大数组的总和

假设我们有一个包含 8 个元素的数组 [1, 2, 3, 4, 5, 6, 7, 8],需要计算所有元素的总和。如果用单线程计算,直接遍历累加即可;但用 ForkJoinPool 时,会通过 “分治” 将大任务拆分成小任务,并行执行后再合并结果,同时通过 “工作窃取” 提高线程利用率。

2、步骤 1:定义分治任务(RecursiveTask)

首先需要创建一个继承 RecursiveTask(有返回值的任务)的子类,重写 compute() 方法,实现 “拆分任务” 和 “合并结果” 的逻辑。

import java.util.concurrent.RecursiveTask;

// 计算数组总和的任务(有返回值,用 RecursiveTask)
class SumTask extends RecursiveTask<Integer> {
    // 任务拆分的阈值:当数组长度 <= 2 时,直接计算(不再拆分)
    private static final int THRESHOLD = 2;
    private int[] array; // 目标数组
    private int start;   // 起始索引
    private int end;     // 结束索引(不包含)

    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        int length = end - start;
        // 若任务足够小(小于阈值),直接计算结果
        if (length <= THRESHOLD) {
            int sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            System.out.println("直接计算:" + start + "~" + end + ",结果=" + sum);
            return sum;
        }

        // 否则,拆分任务(Fork)
        int mid = (start + end) / 2;
        SumTask leftTask = new SumTask(array, start, mid);  // 左半部分任务
        SumTask rightTask = new SumTask(array, mid, end);   // 右半部分任务

        // 执行子任务(fork() 会将任务提交到线程池,可能并行执行)
        leftTask.fork();  // 拆分左任务
        rightTask.fork(); // 拆分右任务

        // 合并子任务结果(join() 会等待子任务完成并获取结果)
        int leftResult = leftTask.join();
        int rightResult = rightTask.join();
        int total = leftResult + rightResult;
        System.out.println("合并:" + start + "~" + end + ",左=" + leftResult + ",右=" + rightResult + ",总=" + total);
        return total;
    }
}

3、步骤 2:用 ForkJoinPool 执行任务

创建 ForkJoinPool 实例,提交任务并获取结果:

import java.util.concurrent.ForkJoinPool;

public class ForkJoinDemo {
    public static void main(String[] args) {
        int[] array = {1, 2, 3, 4, 5, 6, 7, 8};
        // 创建 ForkJoinPool(默认并行度为 CPU 核心数,这里假设为 2 个核心)
        try (ForkJoinPool pool = new ForkJoinPool()) {
            // 提交根任务(计算整个数组的和)
            SumTask rootTask = new SumTask(array, 0, array.length);
            int result = pool.invoke(rootTask); // 阻塞等待结果
            System.out.println("最终总和:" + result); // 输出 36
        }
    }
}

4、步骤 3:详解工作原理(分治 + 工作窃取)

假设 ForkJoinPool 的并行度为 2(即初始有 2 个工作线程 Thread-1 和 Thread-2),整个执行过程如下:

  1. 分治(Fork):拆分任务
  • 根任务:SumTask(0~8)(计算整个数组的和)。由于数组长度为 8(> 阈值 2),根任务会拆分成两个子任务:

    • 左任务:SumTask(0~4)(计算前 4 个元素)。
    • 右任务:SumTask(4~8)(计算后 4 个元素)。
  • 根任务调用 leftTask.fork() 和 rightTask.fork(),将两个子任务提交到线程池,由 Thread-1 和 Thread-2 分别执行。

  1. 子任务继续拆分
  • Thread-1 执行 SumTask(0~4):
    数组长度为 4(> 阈值 2),继续拆分:
    • 左子任务:SumTask(0~2)(元素 [1,2])。
    • 右子任务:SumTask(2~4)(元素 [3,4])。
      • Thread-1 调用 fork() 提交这两个子任务,然后等待它们的结果(join())。
  • Thread-2 执行 SumTask(4~8):
    数组长度为 4(> 阈值 2),继续拆分:
    • 左子任务:SumTask(4~6)(元素 [5,6])。
    • 右子任务:SumTask(6~8)(元素 [7,8])。
      • Thread-2 调用 fork() 提交这两个子任务,然后等待结果(join())。
  1. 执行最小任务(达到阈值)
    当子任务的数组长度 ≤ 阈值(2)时,不再拆分,直接计算结果:
  • SumTask(0~2) 计算 1+2=3(假设由 Thread-1 执行)。
  • SumTask(2~4) 计算 3+4=7(假设 Thread-1 执行完左子任务后,继续执行右子任务)。
  • SumTask(4~6) 计算 5+6=11(假设由 Thread-2 执行)。
  • SumTask(6~8) 计算 7+8=15(假设 Thread-2 执行完左子任务后,继续执行右子任务)。
  1. 合并结果(Join)
    子任务完成后,上层任务会合并结果:
  • SumTask(0~4) 合并 3 + 7 = 10。
  • SumTask(4~8) 合并 11 + 15 = 26。
  • 根任务 SumTask(0~8) 合并 10 + 26 = 36(最终结果)。
  1. 工作窃取(Work Stealing):提升效率的关键
    假设 Thread-1 先完成了自己的所有子任务(0~2 和 2~4),而 Thread-2 还在处理 6~8。此时 Thread-1 的任务队列已空,它会主动 “窃取”Thread-2 队列中未执行的任务(比如 6~8),帮助 Thread-2 执行,避免线程空闲。
  • 窃取规则:线程的任务队列是 “双端队列”(Deque)。自己的任务从队列头部(LIFO)取,窃取的任务从其他队列的尾部(FIFO)取,减少竞争。
  • 效果:避免某一线程任务过多而其他线程空闲,充分利用多核 CPU 资源。

5、核心原理总结

  • 分治(Fork):将大任务递归拆分成小任务,直到任务足够小(达到阈值),适合并行处理。
  • 合并(Join):等待所有子任务完成后,汇总结果得到最终答案。
  • 工作窃取(Work Stealing):空闲线程主动窃取其他线程的任务,平衡负载,最大化利用 CPU 核心。

七、注意事项

  1. 任务拆分粒度:拆分阈值(THRESHOLD)需合理设置,过细会增加任务调度开销,过粗则无法充分利用并行性。
  2. 异常处理:任务中抛出的异常会被包装在 ExecutionException 中,需通过 get() 或 join() 捕获。
  3. 资源管理:ForkJoinPool 实现了 AutoCloseable 接口,建议用 try-with-resources 自动关闭。
  4. 适用场景:适合计算密集型任务(如大数据量计算、排序等),不适合 I/O 密集型任务(线程等待时间长,工作窃取效率低)。
  5. 与 Executors.newWorkStealingPool() 的关系:newWorkStealingPool() 本质上是 ForkJoinPool 的封装(Java 8+),默认并行度为 CPU 核心数。

ForkJoinPool 是 Java 并行计算的重要工具,尤其在处理大规模数据并行任务时,能有效利用多核 CPU 提升效率。

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

喝汽水的猫^

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值