likes
comments
collection
share

并发-AQS之CyclicBarrier源码解读

作者站长头像
站长
· 阅读数 10

CyclicBarrier是Java并发包中的一个类,它用于协调多个线程之间的同步。CyclicBarrier允许多个线程等待彼此达到一个共同的屏障点,然后同时继续执行。它是一种同步机制,用于控制多个线程的执行流程。

CyclicBarrier的主要特点如下:

  1. 它允许一组线程互相等待,直到达到一个共同的屏障点,然后同时继续执行。
  2. 它是可重用的,即在所有线程都通过屏障后,它可以被重置并再次使用。
  3. 它支持一个可选的回调方法,当屏障点被达到时,该方法将被调用。
  4. 它可以用于任何需要等待多个线程完成某个操作的场景,例如多线程数据处理或计算。

UML类图如下:

并发-AQS之CyclicBarrier源码解读

构造函数

CyclicBarrier(int parties, Runnable barrierAction)

带回调方法的构造函数

private static class Generation {
    boolean broken = false;
}

/** The lock for guarding barrier entry */
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
private final Condition trip = lock.newCondition();
/** The number of parties */
private final int parties;
/* The command to run when tripped */
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();

/**
 * Number of parties still waiting. Counts down from parties to 0
 * on each generation.  It is reset to parties on each new
 * generation or when broken.
 */
private int count;

public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    //需要等待的线程数
    this.parties = parties;
    //count变量表示当前还需要等待的线程数量
    this.count = parties;
    this.barrierCommand = barrierAction;
}

CyclicBarrier(int parties)

不带回调函数构造函数,传入等待的线程数

public CyclicBarrier(int parties) {
    this(parties, null);
}

核心方法

await() throws InterruptedException, BrokenBarrierException

该方法用于使当前线程等待其他线程到达屏障点。如果当前线程不是最后一个到达的线程,那么它将被阻塞,直到其他线程都到达屏障点。如果当前线程是最后一个到达的线程,那么它将唤醒所有被阻塞的线程,并继续执行。

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

dowait()方法是CyclicBarrier的核心方法,它会进行一些计数器的操作,以及等待和唤醒线程等操作。 解析如下

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;
        // broken为true,说明屏障点已经破坏
        if (g.broken)
            throw new BrokenBarrierException();
        // 线程中断
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        //剩余等待线程数减一
        int index = --count;
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                //创建一个新的代数,重置计数器,唤醒所有等待的线程
                nextGeneration();
                //等待所有线程都执行await返回
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            try {
                //不需要限时等待
                if (!timed)
                    //不超时在此等待
                    trip.await();
                else if (nanos > 0L)
                    //限时在此等待
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

breakBarrier设置破坏屏障点标记

private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

await(long timeout, TimeUnit unit)

超时时间的await()方法。该方法会使当前线程等待其他线程到达屏障点,如果等待超时,当前线程会抛出TimeoutException异常。 该方法接受两个参数:timeout和unit,分别表示等待的时间和时间单位

public int await(long timeout, TimeUnit unit)
    throws InterruptedException,
           BrokenBarrierException,
           TimeoutException {
    return dowait(true, unit.toNanos(timeout));
}

其他方法

getNumberWaiting()

获取当前等待的线程数

public int getNumberWaiting() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return parties - count;
    } finally {
        lock.unlock();
    }
}

reset()

重置等待线程数

public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        breakBarrier();   // break the current generation
        nextGeneration(); // start a new generation
    } finally {
        lock.unlock();
    }
}

isBroken()

h获取破坏屏障点标记

public boolean isBroken() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return generation.broken;
    } finally {
        lock.unlock();
    }
}

getParties()

获取等待线程数

public int getParties() {
    return parties;
}

其他

使用举例

public static void main(String[] args) {
    final int THREAD_COUNT = 5;

    CyclicBarrier barrier = new CyclicBarrier(THREAD_COUNT, () -> {
        System.out.println("All threads have reached the barrier");
    });

    for (int i = 0; i < THREAD_COUNT; i++) {
        new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + " is waiting at the barrier");
            try {
                barrier.await();
            } catch (InterruptedException | BrokenBarrierException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + " has passed the barrier");
        }).start();
    }
}

JDK中实现AQS简介

同步工具与AQS关联详细介绍
AQS原理讲解AQS原理介绍并发-AQS原理讲解
ReentrantLock使用AQS保存锁重复持有的次数。当一个线程获取锁时,ReentrantLock记录当前获得锁的线程标识,用于检测是否重复获取,以及错误线程试图解锁操作时异常情况的处理。AQS之Reentrantlonk源码解读
Semaphore使用AQS同步状态来保存信号量的当前计数。tryRelease会增加计数,acquireShared会减少计数。Semaphore 源码分析以及AQS共享加解锁
CountDownLatch在多线程并发执行任务时,有时需要让某些线程等待某些条件达成后再开始执行,这时就可以使用CountDownLatch来实现CountDownLatch 源码分析
ThreadPoolExecutor创建线程池中的工作线程worker继承AQS,实现独占资源参考 并发-AQS之ThreadPoolExecutor源码解读(一)
CyclicBarrier多个线程等待彼此达到一个共同的屏障点,然后同时继续执行。并发-AQS之CyclicBarrier源码解读
ReentrantReadWriteLock可重入读写锁,它允许多个线程同时读取一个共享资源,但只允许一个线程写入该共享资源。参考 并发-AQS之ReentrantReadWriteLock源码解读(一)