当前位置 博文首页 > 阳阳的博客:【多线程】CyclicBarrier实现原理

    阳阳的博客:【多线程】CyclicBarrier实现原理

    作者:[db:作者] 时间:2021-08-04 11:58

    前言

    CyclicBarrier,字面意思“循环屏障”,用于多个线程一起到达屏障点后,多个线程再一起接着运行的情况。例如,线程1和线程2一起运行,线程1运行到屏障点a时,将会被阻塞,等到线程2运行到屏障点a后,线程1和线程2才可以打破屏障,接着运行。如果有屏障点b,则他们需要像打破屏障a一样打破屏障b,如此循环往复。


    常用方法

    public class CyclicBarrier {
     
        //构造方法,传入线程总数以及打破屏障点前的任务
        public CyclicBarrier(int parties, Runnable barrierAction);
    
        //构造方法,传入线程总数,不需要执行额外任务
        public CyclicBarrier(int parties);
         
        //在所有一起到达屏障点前,阻塞当前线程
        public int await();
    }
    

    下面给出一个简单的示例:

    package com.yang.testCB;
    
    import java.util.concurrent.BrokenBarrierException;
    import java.util.concurrent.CyclicBarrier;
    
    public class Main {
        public static void main(String[] args) {
            CyclicBarrier cb = new CyclicBarrier(2, () -> {
                System.out.println("即将打破屏障");
            });
    
            for (int i = 0; i < 2; i++) {
                new Thread(() -> {
                    System.out.println(Thread.currentThread().getName() + "开始运行");
                    try {
                        cb.await();
                        System.out.println(Thread.currentThread().getName() + "已经穿越了第1个屏障");
                        cb.await();
                        System.out.println(Thread.currentThread().getName() + "已经穿越了第2个屏障");
                    } catch (InterruptedException | BrokenBarrierException e) {
                        e.printStackTrace();
                    }
                }).start();
            }
        }
    }
    

    输出如下:

    由上面的例子,我们发现,线程0与线程1调用await()后,在所有线程到达第1个屏障点前,都会被阻塞。最后一个线程到达屏障点后,先执行额外任务,才能打破屏障,接着该线程唤醒其他阻塞的线程,此时所有线程继续运行。


    原理解析

    先看一下CyclicBarrier里面有哪些变量

    public class CyclicBarrier {
      
        //同步锁
        private final ReentrantLock lock = new ReentrantLock();
    
        //条件队列
        private final Condition trip = lock.newCondition();
    
        //传入构造方法的线程总数
        private final int parties;
    
        //打破屏障前需要执行的任务,或者称为换代任务
        private final Runnable barrierCommand;
    
        //当前代
        private Generation generation = new Generation();
        
        //计数器,代表当前还未达到屏障的线程数目。
        private int count;
    
    }
    
    • CyclicBarrier借助ReentranLock与Condition来对线程进行阻塞的。
    • parties是传入构造方法的线程总数,在该CyclicBarrier实例的整个生命周期内,该值保持不变,并且会在换代的时候,使得count=parties
    • barrierCommand,换代任务,打破屏障前需要执行的任务,任务执行完成后(不管成功还是失败),才会唤醒所有线程
    • generation代表当前代,两个屏障之间称为一代,原点与第一个屏障可以称为第一代
    • count,计数器,每有一个线程到达屏障时,count值就会减1。一旦减为0后,则先同步执行换代任务,接着打破屏障,开启下一代,然后唤醒所有阻塞的线程,最后将count重置为parties。

    CyclicBarrier内的主要方法:

    构造方法

     public CyclicBarrier(int parties, Runnable barrierAction) {
            if (parties <= 0) throw new IllegalArgumentException();
            this.parties = parties;
            this.count = parties;
            this.barrierCommand = barrierAction;
        }
    
        public CyclicBarrier(int parties) {
            this(parties, null);
        }

    CyclicBarrier提供两个构造方法,不过最核心的是两参数的构造方法。构造方法中设置了parties与count的值都是传入的线程总数,barrierAction为换代任务,当然也可以不指定换代任务。

    await()方法

     public int await() throws InterruptedException, BrokenBarrierException {
            try {
                return dowait(false, 0L);
            } catch (TimeoutException toe) {
                throw new Error(toe); // cannot happen
            }
        }
    
        public int await(long timeout, TimeUnit unit)
            throws InterruptedException,
                   BrokenBarrierException,
                   TimeoutException {
            return dowait(true, unit.toNanos(timeout));
        }
    

    CyclicBarrier同样提供了定时等待与非定时等待,不过都调用了dowait()方法,该方法是CyclicBarrier内最为核心的方法。

        private int dowait(boolean timed, long nanos)
            throws InterruptedException, BrokenBarrierException,
                   TimeoutException {
            final ReentrantLock lock = this.lock;
            lock.lock();
            try {
                //当前代
                final Generation g = generation;
                //判断当前代的状态,如果当前代后的屏障被打破,则g.broken返回true,否则返回false。
                if (g.broken)
                    throw new BrokenBarrierException();
                //判断当前线程是否被中断
                if (Thread.interrupted()) {
                    //如果当前线程已经被中断,则调用breakBarrier()
                    //该方法代码为generation.broken = true;count = parties;trip.signalAll();
                    //可见,只做了3件事:先将当前代的屏障变为打破状态,接着重置计数器的值,最后唤醒所有被阻塞的线程
                    breakBarrier();
                    //最后抛出中断异常
                    throw new InterruptedException();
                }
                //将计数器的值减1
                int index = --count;
                if (index == 0) {
                    //如果当前计数器的值为0
                    boolean ranAction = false;
                    try {
                        //则先执行换代任务,可以看得出来,是由最后一个到达屏障的线程执行的
                        final Runnable command = barrierCommand;
                        if (command != null)
                            command.run();
                        ranAction = true;
                        //开启下一代,这个方法的代码为trip.signalAll();count = parties;generation = new Generation();
                        //该代码唤醒所有被阻塞的线程,重置计数器的值,并且实例化下一代
                        nextGeneration();
                        return 0;
                    } finally {
                        //如果换代任务未执行成功,则先将当前代的屏障变为打破状态,接着重置计数器的值,最后唤醒所有被阻塞的线程
                        if (!ranAction)
                            breakBarrier();
                    }
                }
    
                //当前线程一直阻塞,直到“有parties个线程到达barrier” 或 “当前线程被中断” 或 “超时”这3者之一发生
                //死循环
                for (;;) {
                    try {
                        if (!timed)
                            //如果不是定时等待,则调用条件队列的await()进行阻塞
                            trip.await();
                        else if (nanos > 0L)
                            //如果是定时等待,则调用条件队列的awaitNanos进行等待
                            nanos = trip.awaitNanos(nanos);
                    } catch (InterruptedException ie) {
                        //如果在等待过程中,当前线程被打断
                        if (g == generation && ! g.broken) {
                            //被打断后,还处于当前代,且当前代的屏障也未被打破
                            //现在的情况是,最后一个线程还未到屏障,当前线程早早到了,并且在进行等待,但是在等待的过程中,被打断了。
                            //则打破当前代的屏障,唤醒所有被阻塞的线程
                            breakBarrier();
                            throw ie;
                        } else {
                            //如果已经换代,则手动进行打断
                            Thread.currentThread().interrupt();
                        }
                    }
                    //此时线程被唤醒,需要判断自己为什么被唤醒了                
                    
                    //如果是其他某个线程被打断或者是由于超时导致当前代的屏障被打破,则抛出异常
                    if (g.broken)
                        throw new BrokenBarrierException();
    
                    //如果是正常换代,则返回index值
                    if (g != generation)
                        return index;
    
                    //如果是定时等待,且时间已经到了,则打破屏障,唤醒所有阻塞的线程,最后抛出异常
                    if (timed && nanos <= 0L) {
                        breakBarrier();
                        throw new TimeoutException();
                    }
                }
            } finally {
                lock.unlock();
            }
        }
    

    用一开始的例子,简单说下整个流程:

    线程0首先运行,输出“Thread-0开始运行”,接着调用了await()方法,然后进入dowait()方法中,获得lock锁,此时count-1=1,count值不为0,因此进入了for循环中,最后调用了trip.await()方法,于是线程0释放了lock锁,被阻塞住了。

    线程1和线程0差不多同时运行,但线程0首先获取到了锁,线程1输出“Thread1-开始运行”后,需要等待线程0释放锁。此时线程0释放了lock锁,线程1可以进入到同步代码中,此时count-1=0,因此线程1首先执行换代任务,输出“即将打破屏障”。接着调用nextGeneration()方法开启下一代,最后直接返回0,然后输出“Thread1-已经穿越了第1个屏障”。其中nextGeneration()方法将会唤醒线程0,线程0继续从trip.await()处运行,由于已经发生了换代,因此直接返回1,最后输出“Thread0-已经穿越了第1个屏障”

    这个时候,线程0和线程1都已经穿越了第一层屏障,当再次调用await()方法时,将会进行第二次换代。


    CyclicBarrier与CountDownLatch的区别

    对CountDownLatch不熟悉的同学,可以先参考我的另外一篇文章CountDownLatch实现原理

    CountDownLatch,是一个线程或多个线程等待另外多个线程执行完毕之后才执行。内部维护一个计数器,每个线程调用一次countDown后,计数器减1,计数器减为0后,会唤醒因调用await()而阻塞的线程。

    CyclicBarrier,多个线程互相等待,直到所有的线程都达到屏障点,才可以一起接着执行。同样可以理解为内部有一个可重置的计数器,每个线程调用await()后,计数器减1,若计数器的值不为0,将会阻塞该线程。当最后一个线程调用await()后,计数器为0,将会唤醒所有阻塞的线程,并开启下一代,重置此计数器的值。

    当然两者也有共同点,调用对应的await()方法都会阻塞当前线程。

    cs