likes
comments
collection
share

深度解析ForkJoinPool源码:Java并行计算的内部机制

作者站长头像
站长
· 阅读数 24

引言

Java的ForkJoinPool是一个强大的多线程并行计算工具,特别适用于解决分治问题和递归算法。本文将深入探讨ForkJoinPool的源码,解析其内部机制,以帮助您更好地理解并合理使用这一重要的并行计算工具。

ForkJoinPool的基本结构

ForkJoinPool的核心设计思想是将一个大任务划分成多个小任务,然后并行执行这些小任务,最后将结果合并起来。让我们从ForkJoinPool的基本结构开始解析。

1. Worker线程和WorkQueue

ForkJoinPool包括一组Worker线程,每个Worker线程都有一个关联的WorkQueue。Worker线程负责执行任务,WorkQueue用于存储任务。每个Worker线程都有一个本地WorkQueue,同时还可以从其他Worker的队列中窃取任务。

2. ForkJoinTask

任务是通过ForkJoinTask的子类来表示的。ForkJoinTask有两个主要子类:RecursiveTask和RecursiveAction,分别用于有返回值和无返回值的任务。任务通过fork()方法分割成子任务,然后通过join()方法等待子任务的完成。

3. 线程调度

ForkJoinPool中的线程采用工作窃取算法(Work-Stealing)执行任务。每个Worker线程都有一个工作窃取队列,当一个线程完成自己的任务后,它会尝试从其他线程的队列中窃取任务执行。这个机制保持了负载均衡,确保所有线程都在工作。

ForkJoinPool的核心方法

ForkJoinPool的源代码中包含了一些关键方法,用于实现任务的提交、执行和调度。

1. submit()

submit()方法用于提交一个任务到ForkJoinPool中,并返回一个Future对象,可以用于获取任务的结果。

    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }

将任务添加到提交队列中

    final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        //获取当前线程的随机探测值r和运行状态rs。
        int rs = runState;
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
            ForkJoinTask<?>[] a; int am, n, s;
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP, s + 1);
                U.putIntVolatile(q, QLOCK, 0);
                //如果任务数组只有一个任务(n <= 1),则调用signalWork方法通知其他线程有新任务可执行。
                if (n <= 1)
                    signalWork(ws, q);
                return;
            }
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        //调用externalSubmit方法将任务提交给ForkJoinPool的外部队列中。
        externalSubmit(task);
    }
    private void externalSubmit(ForkJoinTask<?> task) {
        //首先,初始化一个整数变量r,用于保存调用者的探测值。
        int r; 
        //如果r的值为0,说明调用者还没有进行过探测操作,需要进行初始化操作。首先调用ThreadLocalRandom.localInit()方法初始化线程本地随机数生成器,然后再次获取探测值r。
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            // rs < 0 线程池终止
            if ((rs = runState) < 0) {
                tryTerminate(false, false);     // help terminate
                throw new RejectedExecutionException();
            }
            else if ((rs & STARTED) == 0 ||     // initialize
                     ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
                int ns = 0;
                // 锁定 rs 并尝试初始化线程池
                rs = lockRunState();
                try {
                    //如果线程池还没有启动或者工作队列数组为空,需要创建一个新的工作队列数组。
                    if ((rs & STARTED) == 0) {
                        U.compareAndSwapObject(this, STEALCOUNTER, null,
                                               new AtomicLong());
                        // create workQueues array with size a power of two
                        int p = config & SMASK; // ensure at least 2 slots
                        int n = (p > 1) ? p - 1 : 1;
                        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                        n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                        workQueues = new WorkQueue[n];
                        ns = STARTED;
                    }
                } finally {
                    //更新线程池的工作队列数组,并将运行状态设置为STARTED。
                    unlockRunState(rs, (rs & ~RSLOCK) | ns);
                }
            }
            //如果线程池已经初始化,并且当前线程的探测值对应的工作队列不为空
            else if ((q = ws[k = r & m & SQMASK]) != null) {
                //如果当前工作队列没有被锁定,并且成功将工作队列的锁定状态从0修改为1
                if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                    ForkJoinTask<?>[] a = q.array;
                    int s = q.top;
                    boolean submitted = false; // initial submission or resizing
                    try {                      // locked version of push
                        if ((a != null && a.length > s + 1 - q.base) ||
                            (a = q.growArray()) != null) {
                            //计算任务在数组中的索引位置j,并将任务放入数组中
                            int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                            U.putOrderedObject(a, j, task);
                            U.putOrderedInt(q, QTOP, s + 1);
                            submitted = true;
                        }
                    } finally {
                        U.compareAndSwapInt(q, QLOCK, 1, 0);
                    }
                    //如果成功提交任务到工作队列,发送信号通知其他线程有新的任务可执行,并返回。
                    if (submitted) {
                        signalWork(ws, q);
                        return;
                    }
                }
                move = true;                   // move on failure
            }
            //如果当前线程的探测值对应的工作队列为空,并且运行状态的锁定位为0,说明需要创建一个新的工作队列
            else if (((rs = runState) & RSLOCK) == 0) { // create new queue
                q = new WorkQueue(this, null);
                q.hint = r;
                q.config = k | SHARED_QUEUE;
                q.scanState = INACTIVE;
                rs = lockRunState();           // publish index
                if (rs > 0 &&  (ws = workQueues) != null &&
                    k < ws.length && ws[k] == null)
                    ws[k] = q;                 // else terminated
                unlockRunState(rs, rs & ~RSLOCK);
            }
            else
                move = true;                   // move if busy
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }

激活一个空闲的工作队列,以便执行任务

    final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        while ((c = ctl) < 0L) {                       // too few active
            if ((sp = (int)c) == 0) {                  // no idle workers
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    //没有空闲的工作者线程 或工作者线程数太少 尝试添加新的工作线程
                    tryAddWorker(c);
                break;
            }
            if (ws == null)                            // unstarted/terminated
                break;
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            if ((v = ws[i]) == null)                   // terminating
                break;
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int d = sp - v.scanState;                  // screen CAS
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }

2. fork()和join()

fork()方法是ForkJoinTask类的成员方法,用于将当前任务分割成子任务,并将子任务推入适当的工作队列中。如果当前线程是ForkJoinWorkerThread,则将子任务推入该工作线程的工作队列中。否则,调用externalPush方法将任务推入适当的工作队列。而join()方法用于等待子任务的完成并获取结果。这是ForkJoinTask的核心操作。

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }
    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

3. execute()

execute()方法用于提交一个任务,但不返回结果。通常用于提交RecursiveAction类型的任务。

    public void execute(ForkJoinTask<?> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
    }

ForkJoinPool的关键源码解析

1. Worker线程的创建

ForkJoinPool会创建一组Worker线程:

     private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }

这段代码创建了一组Worker线程,并启动它们。每个Worker线程都有自己的工作窃取队列。

2. Worker(ForkJoinWorkerThread) 的执行

    public void run() {
        //检查workQueue的数组是否为空 空只跑一次
        if (workQueue.array == null) { // only run once
            Throwable exception = null;
            try {
                onStart();
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                try {
                    onTermination(exception);
                } catch (Throwable ex) {
                    if (exception == null)
                        exception = ex;
                } finally {
                    // 销毁 worker
                    pool.deregisterWorker(this, exception);
                }
            }
        }
    }

执行工作队列中的任务

     final void runWorker(WorkQueue w) {
        w.growArray();                   // allocate queue
        int seed = w.hint;               // initially holds randomization hint
        int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
        //进入循环,不断扫描工作队列。
        for (ForkJoinTask<?> t;;) {
            //如果扫描到有任务,就执行该任务
            if ((t = scan(w, r)) != null)
                w.runTask(t);
            //如果没有任务,就等待工作队列中有任务为止
            else if (!awaitWork(w, r))
                break;
            r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
        }
    }
    final void execLocalTasks() {
            int b = base, m, s;
            ForkJoinTask<?>[] a = array;
            if (b - (s = top - 1) <= 0 && a != null &&
                (m = a.length - 1) >= 0) {
                // 先进先出
                if ((config & FIFO_QUEUE) == 0) {
                    for (ForkJoinTask<?> t;;) {
                        if ((t = (ForkJoinTask<?>)U.getAndSetObject
                             (a, ((m & s) << ASHIFT) + ABASE, null)) == null)
                            break;
                        U.putOrderedInt(this, QTOP, s);
                        // 执行任务
                        t.doExec();
                        if (base - (s = top - 1) > 0)
                            break;
                    }
                }
                //先进后出
                else
                    pollAndExecAll();
            }
        }


3. 工作窃取

private ForkJoinTask<?> scan(WorkQueue w, int r) {
  WorkQueue[] ws; int m;
  if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
    //获取当前扫描状态
    int ss = w.scanState;                     // initially non-negative
    
    //随机一个起始位置,并赋值给k
    for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
      WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
      int b, n; long c;
      //如果k槽位不为空
      if ((q = ws[k]) != null) {
        //base-top小于零,并且任务q不为空
        if ((n = (b = q.base) - q.top) < 0 &&
            (a = q.array) != null) {      // non-empty
          //获取base的偏移量,赋值给i
          long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
          //从base端获取任务,从base端steal
          if ((t = ((ForkJoinTask<?>)
                    U.getObjectVolatile(a, i))) != null &&
              q.base == b) {
            //是active状态
            if (ss >= 0) {
              //更新WorkQueue中数组i索引位置为空,并且更新base的值
              if (U.compareAndSwapObject(a, i, t, null)) {
                q.base = b + 1;
                //n<-1,说明当前队列还有剩余任务,继续唤醒可能存在的其他线程
                if (n < -1)       // signal others
                  signalWork(ws, q);
                //直接返回任务
                return t;
              }
            }
            else if (oldSum == 0 &&   // try to activate
                     w.scanState < 0)
              tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
          }
          
          //如果获取任务失败,则准备换位置扫描
          if (ss < 0)                   // refresh
            ss = w.scanState;
          r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
          origin = k = r & m;           // move and rescan
          oldSum = checkSum = 0;
          continue;
        }
        checkSum += b;
      }
      
      //k到最后,如果等于origin,说明已经扫描了一圈还没扫描到任务
      if ((k = (k + 1) & m) == origin) {    // continue until stable
        if ((ss >= 0 || (ss == (ss = w.scanState))) &&
            oldSum == (oldSum = checkSum)) {
          if (ss < 0 || w.qlock < 0)    // already inactive
            break;
          //准备inactive当前工作队列
          int ns = ss | INACTIVE;       // try to inactivate
          //活动线程数AC减1
          long nc = ((SP_MASK & ns) |
                     (UC_MASK & ((c = ctl) - AC_UNIT)));
          w.stackPred = (int)c;         // hold prev stack top
          U.putInt(w, QSCANSTATE, ns);
          if (U.compareAndSwapLong(this, CTL, c, nc))
            ss = ns;
          else
            w.scanState = ss;         // back out
        }
        checkSum = 0;
      }
    }
  }
  return null;
}

结语

通过深入分析ForkJoinPool的源代码,我们可以更好地理解其内部机制,包括任务的提交和调度、工作窃取算法、Worker线程的管理等方面的细节。这有助于开发者更好地利用ForkJoinPool来实现高效的并行计算。当然,要真正掌握ForkJoinPool,还需要不断实践和深入研究其源码,以满足特定的并发处理需求。