likes
comments
collection
share

Go 瞧瞧WaitGroup

作者站长头像
站长
· 阅读数 15

前言

WaitGroup常用于主协程等待一组goroutine完成,才继续下一步任务。其源码也较为简单,那不妨通过业务推导方式,自己梳理出实现逻辑,这样以后就靠推导而无需记忆实现原理了。

情景分析

使用WaitGroup时,我们发现协程分为两种:

  • 一种是用Add方法,表明自己是要等待的协程
  • 一种是用Wait方法,表明自己是正在等待的协程

要等待的协程任务:

  • 将要等待的协程数量加一
  • 完成任务后调用Done方法,将要等待的协程数量减一
  • 如果最后一个要等待的协程也完成了,则唤醒正在等待的协程

正在等待的协程任务:

  • 将正在等待的协程数量加一
  • 陷入休眠,等待goroutine任务组完成
  • 任务组完成后被唤醒,执行后续任务

自己实现

这里面有两种角色:

  • 要等待的协程标记为counter,Add方法会使counter加一
  • 正在等待的协程标记为waiter,Wait方法会使waiter加一
  • 同时还有休眠唤醒操作,需要一个sema信号量

因此抽象出结构体如下:

type WaitGroup struct {
   counter int64
   waiter  int64
   sema    uint32
}

实现Add方法。

func (wg *WaitGroup) Add(delta int64) {
   v := atomic.AddInt64(&wg.counter, delta)
   w := wg.waiter
   if v > 0 || w == 0 {
      return
   }

   wg.waiter = 0
   for ; w > 0; w-- {
      runtime_Semrelease(&wg.sema, false, 1)
   }
}

Add方法中,如果有counter或者waiter数量为0则返回。否则counter肯定为0且有waiter,那么唤醒waiter。唤醒waiter前需要将waiter数量重置为0。

实现Wait方法。

func (wg *WaitGroup) Wait() {
   if wg.counter == 0 {
      return
   }
   atomic.AddInt64(&wg.waiter, 1)
   runtime_SemacquireMutex(&wg.sema, false, 1)
}

将waiter加一后休眠,被唤醒就说明counter为0了,结束等待。

Done方法直接调用Add(-1),意味着一个协程完成了。

func (wg *WaitGroup) Done() {
   wg.Add(-1)
}

以上实现方法只是简单的还原WaitGroup的核心实现,但对于并发问题没有过多考虑。接下来解析下源码,看严谨的实现是怎么写的。

源码解析

WaitGroup结构体

type WaitGroup struct {
   noCopy noCopy

   state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
   sema  uint32
}
  • noCopy字段是告诉编译器,该结构体不能被拷贝,否则容易引起并发问题
  • state是一个atomic.Uint64类型,高32为表示counter,低32位表示waiter
  • sema是信号量,用于协程的休眠和唤醒

图解一下:

Go 瞧瞧WaitGroup

Add方法

func (wg *WaitGroup) Add(delta int) {
   state := wg.state.Add(uint64(delta) << 32)
   v := int32(state >> 32)
   w := uint32(state)
   if v < 0 {
      panic("sync: negative WaitGroup counter")
   }
   if w != 0 && delta > 0 && v == int32(delta) {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   if v > 0 || w == 0 {
      return
   }

   if wg.state.Load() != state {
      panic("sync: WaitGroup misuse: Add called concurrently with Wait")
   }
   // Reset waiters count to 0.
   wg.state.Store(0)
   for ; w != 0; w-- {
      runtime_Semrelease(&wg.sema, false, 0) // 重点,唤醒等待协程
   }
}
  • wg.state.Add(uint64(delta) << 32),将delta左移32位,才能加入到counter中,因为counter是64位中前32位。
  • v代表计算后的counter,w代表waiter,之后进行了一些panic判断
  • 如果counter不为0或者waiter为0,则结束该函数,让协程继续执行任务
  • 否则counter肯定为0且waiter不为0,于是唤醒正在等待的协程

Done方法

func (wg *WaitGroup) Done() {
   wg.Add(-1)
}
  • 协程完成任务了,因此调用Add方法,将counter减一
  • 如果刚好counter为0,就会唤醒协程

Wait方法

func (wg *WaitGroup) Wait() {
   for {
      state := wg.state.Load()
      v := int32(state >> 32)
      w := uint32(state)
      if v == 0 {
         // Counter is 0, no need to wait.
         return
      }
      // Increment waiters count.
      if wg.state.CompareAndSwap(state, state+1) {
      
         runtime_Semacquire(&wg.sema) // 重点,休眠自己
         if wg.state.Load() != 0 {
            panic("sync: WaitGroup is reused before previous Wait has returned")
         }
         return
      }
   }
}
  • v代表计算后的counter,w代表waiter
  • counter为0就不需要等待了,直接结束函数
  • 否则要等待,将waiter加一,然后休眠自己
  • 等待唤醒后直接return

总结

WaitGroup把协程分为counter和waiter,我们结合各自的任务情景,就能自己推导出大致实现过程了。当然源码也较为简单,建议动手细读。