Go并发编程 — sync.WaitGroup
简介
WaitGroup
可以解决一个 goroutine 等待多个 goroutine 同时结束的场景,常见的场景例如启动了多个worker goruntine 进行并发处理,然后某个goruntine需要汇总信息。
使用
WaitGroup
使用比较简单,主要有下面这3个方法。
func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()
- Add方法:增加WaitGroup的计数值
- Done方法:减少WaitGroup的计数值,表示goruntine完成了,内部实现就是调用了Add(-1)
- Wait方法:需要等待的goruntine可以调用Wait进行阻塞,直到WaitGroup的计数值变为0
看一下下面这个Demo,启动了3个 goruntine 来做 worker ,注意 Add 方法需要提前设置,我这里是在for循环里面设置的,然后完成之后需要调用 Done 方法,表示 goruntine 处理结束了,最后在 main goruntine 中调用 Wait 方法来等待 3 个 worker goruntine 处理完成。
package main
import (
"fmt"
"sync"
"time"
)
func main() {
var wg sync.WaitGroup
for i := 1; i <= 3; i++ {
wg.Add(1)
go func(index int) {
defer wg.Done()
fmt.Printf("【goruntine#%d】开始工作\n", index)
time.Sleep(time.Second * 2)
fmt.Printf("【goruntine#%d】结束工作\n", index)
}(i)
}
wg.Wait()
fmt.Printf("所有goruntine全部结束,可以处理数据了")
}
输出结果:
【goruntine#3】开始工作
【goruntine#1】开始工作
【goruntine#2】开始工作
【goruntine#3】结束工作
【goruntine#1】结束工作
【goruntine#2】结束工作
所有goruntine全部结束,可以处理数据了
实现
数据结构
来看一下 WaitGroup 的数据结构吧,主要由 noCopy 和 state1 组成。
- noCopy:保证 vet 工具检查是否 copy 复制这个 WaitGroup 实例
- state1:是一个复合字段,存储了 WaitGroup 的 counter(计数值),waiter数量和信号量
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
在来看一下内部的一个 state 方法,主要是获取 计数值、waiter数量和信号量的方法。有没有发现这里面还有一段判断逻辑,由于 atomic 后续需要进行 64 位的操作(拿到statep的返回值,进行原子操作),需要 64 位的内存对齐,但是在 32 位的机器上是不能保证 64 位的内存对齐的。
在 64 位环境下,state1 的第一个元素是 waiter 数,第二个元素是 WaitGroup 的计数值,第三个元素是信号量。
在 32 位环境下,如果 state1 不是 64 位对齐的地址,那么 state1 的第一个元素是信号量,后两个元素分别是 waiter 数和计数值,来保证 statep 是对齐的。
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
Add方法
我这里删除了部分 race 检查的代码,主要逻辑首先增加counter值,由于delta可以传负数,所以如果counter值为0,需要唤醒等待者。
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state() // 获取到statep和semap
state := atomic.AddUint64(statep, uint64(delta)<<32) // 将delta左移32位,然后进行原子加操作
v := int32(state >> 32) // 获取增加后的counter值
w := uint32(state) // 获取waiter值
// counter不能小于0
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// waiter值不等0的情况下,delta > 0 && v == int32(delta)表示counter是第一次增加
// 表示Add方法和Wait存在并发调用,也就是复用Waiter的时候需要waiter值变成0才行
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// counter大于0,或者没有等待者,直接返回
if v > 0 || w == 0 {
return
}
// 避免并发调用 add 和 wait
if *statep != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 走到这里的话,表示counter是0,waiter值不是0,表示所有goruntine都完成操作了,需要通知等待者了,然后将state值设置成0
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
Wait方法
Wait的主要逻辑增加waiter值,然后进行休眠,等待被唤醒。
func (wg *WaitGroup) Wait() {
statep, semap := wg.state() // 获取statep和semap
for {
state := atomic.LoadUint64(statep) // 原子获取state
v := int32(state >> 32) // 获取counter
w := uint32(state) // 获取waiter值
// counter为0,不需要阻塞
if v == 0 {
return
}
// CAS操作自增waiter值
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 阻塞休眠
runtime_Semacquire(semap)
// 被唤醒后statep不为0,代表出现异常
if *statep != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
// 退出
return
}
}
}
Done方法
Done方法的逻辑就是调用Add(-1)
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
总结
- 调用
Add
方法建议不要传负数,直接调用Done
方法 Add
方法需要在启动goruntine
前调用,Done
方法需要在goruntine
完成时调用- 调用
Done
的次数超过了WaitGroup
的counter
值,所以需要预先确定好WaitGroup
的计数值,然后调用相同次数的Done
完成相应的任务 WaitGroup
必须在Wait
方法返回之后才能再次使用,主要是Wait
方法和Add
方法可能存在并发,我的建议是最好不要复用,直接创建一个新的- 可以同时有多个
goroutine
等待当前WaitGroup
的counter
值归零,这些goroutine
会被同时唤醒
转载自:https://juejin.cn/post/7094094076974202893