likes
comments
collection
share

并发控制利器:WaitGroup实现原理及使用注意事项

作者站长头像
站长
· 阅读数 7

一、场景介绍

在Go语言中,WaitGroupsync包中用于做任务编排的一个并发原语。

WaitGroup主要解决的问题就是并发——等待的问题。具体来说:当某一个goroutine需要在检查点(checkpoint)等待一组goroutine任务全部完成,如果这一组goroutine任务没有全部都完成,则这个goroutine会阻塞在检查点,直到所有的goroutine都完成任务后,才能够继续执行后续的逻辑。

举一个使用WaitGroup的场景。

当一个goroutine大任务需要并行三个goroutine小任务,并且需要这三个goroutine小任务全部完成后,才能够继续执行这个goroutine大任务。如果需要知道这三个goroutine小任务是否完成,需要通过定时轮训的方式询问这三个小任务是否完成。

轮询的方法存在两个问题:

  • 性能低:在轮询的过程中,很有可能还未轮询到的任务早已经完成,却要等待很长才能被轮询到。
  • 空耗CPU资源:产生很多无谓的轮询,空耗CPU资源。

面对这种场景,WaitGroup并发原语便派上了用场。WaitGroup可以阻塞等待当前的goroutine,等到三个goroutine小任务完成后,再唤醒当前阻塞的goroutine,避免了轮询带来的空耗。

二、基本用法

在Go标准库中,WaitGroup提供了三种方法:

方法说明
Add用于设置WaitGroup的计数值
DoneWaitGroup的计数值减1
Wait调用Wait方法的goroutine会阻塞,直到WaitGroup的计数值变为0后被唤醒
func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()

通过计数器的例子,来具体熟悉WaitGroup的Add、Done、Wait 方法的基本用法。

package main

import (
    "fmt"
    "sync"
    "time"
)

// Counter 线程安全的计数器
type Counter struct {
    mu    sync.Mutex
    count uint64
}

// Incr 计数值加1
func (c *Counter) Incr() {
    c.mu.Lock()
    c.count++
    c.mu.Unlock()
}

// 获取当前计数值
func (c *Counter) Count() uint64 {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.count
}

// 任务执行,sleep 1秒后计数值 + 1
func worker(c *Counter, wg *sync.WaitGroup) {
    defer wg.Done()
    time.Sleep(time.Second)
    c.Incr()
}

func main() {
    var counter Counter
    var wg sync.WaitGroup
    wg.Add(10) // 设置WaitGroup的值为10
    
    // 开启10个goroutine执行Incr任务
    for i := 0; i < 10; i++ {
       go worker(&counter, &wg)
    }
    
    // 设置检查点,等待所有goroutine完成任务
    wg.Wait()
    // 输出当前计数器的值
    fmt.Println(counter.Count())
}

上述代码,定义的一个Counter计数器结构体,开启了10个worker分别对计数器的计数值加一,在10个worker执行完毕后,输出计数器的值。

  • 在第38行代码中,声明了一个WaitGroup变量,初始值为零。
  • 在第39行代码中,为WaitGroup变量设置计数值为10,因为需要编排10个goroutine worker执行计数任务。
  • 在第47行代码中,通过WaitGroupWait()方法,阻塞主goroutine,直到这10个goroutine worker完成计数任务。
  • 在第43行代码中,通过for循环开启10个goroutine worker,将定义的WaitGroup指针当作参数传递进入worker函数。每个woker在完成计数任务后,调用Done方法,把 WaitGroup 的计数值减 1。告知WaitGroup当前的goroutine的任务已完成。
  • 当10个goroutine worker都调用了Done方法后,WaitGroup的计数值降为0,表示WaitGroup通过Add方法中设置的10个计数值任务都已经完成,告知主goroutine这10个计数值任务都已经完成,不再阻塞,继续执行后续逻辑。

上述的例子,通过使用使用 WaitGroup 来进行任务编排,当需要启动多个 goroutine 执行任务,主 goroutine 需要等待子 goroutine 都完成后才继续执行的场景时,WaitGroup是一个非常好用的并发原语。

三、实现原理

1、WaitGroup结构体

WaitGroup的数据结构中,包含了一个noCopy的辅助字段,一个state1记录WaitGroup状态的数组。

在Go1.17的sync包中,WaitGroup的结构体如下:

type WaitGroup struct {
    // noCopy辅助字段,用于禁止拷贝,避免复制使用,可以告诉vet检查器违反了复制使用的规则
    noCopy noCopy

    // 64bit值的原子操作需要64bit对齐,但是32bit编译器不支持
    // 所以state1数组中的元素在不同的架构中不一样,具体处理看state方法
    state1 [3]uint32
}
  • noCopy:辅助字段,主要用于辅助vet工具检查是否通过copy赋值这个WaitGroup实例。
  • state1:数组复合字段,包含了WaitGroup的计数值、阻塞在检查点的waiter数量、信号量。

关于state1字段,对于64位整数的原子操作要求整数的地址是64位对齐,因此在64位和32位环境下state1字段的组成是不同的。

在64位环境下state1字段含义如下:

  • state1[0]:waiter的个数;
  • state1[1]:WaitGroup的计数值;
  • state1[2]:信号量;

并发控制利器:WaitGroup实现原理及使用注意事项

在32位环境下,如果state1分配到的地址不是64位对齐的地址state1字段含义如下:

  • state1[0]:信号量;
  • state1[1]:waiter的个数;
  • state1[2]:WaitGroup的计数值;

并发控制利器:WaitGroup实现原理及使用注意事项

为什么state1的元素排列会不同,具体来分析一下:

1、首先需要理解什么是内存对齐

为了能让CPU可以更快的存储、读取到各个字段,Go编译器会将结构体做数据的对齐,即内存对齐。

所谓的内存对齐,指的是内存地址的大小是所存储数据类型大小的整倍数(以字节为单位)以便CPU可以一次将该数据从内存中读取出来,减少了读取次数。

不同硬件平台占用的大小和对齐值都可能是不一样的,每个特定平台上的编译器都有自己的默认"对齐系数",32位系统对齐系数是4,64位系统对齐系数是8

编译器通过在结构体的各个字段之间填充一些空白,来达到对齐的目的。对齐原则具体来说如下:

  • 对齐原则1 结构体变量中成员的偏移量必须是成员数据类型大小的整数倍,例如int32为4,int64为8;
  • 对齐原则2 整个结构体的内存大小必须是其成员数据类型中最大字节的整数倍(结构体的内存占用是1/4/8/16byte…),例如结构体中最大的成员数据类型为int64,则整个结构体的内存占用为8/16/24...的整倍数;

举一个简单的例子:

type badSt struct {
    a int32 // 4
    b int64 // 8
    c bool  // 1
}

type goodSt struct {
    a int32 // 4
    c bool  // 1
    b int64 // 8
}

func main() {
    bad := badSt{}
    good := goodSt{}
    fmt.Printf("size of bad: %d\n", unsafe.Sizeof(bad)) // 24
    fmt.Printf("size of good: %d\n", unsafe.Sizeof(good)) // 16
}

我们知道,结构式的内存布局是占用一块连续的内存来存储结构体。

上述badSt结构体goodSt结构体的成员字段都相同,唯一的不同就是定义的顺序不同,但是初始化这两个结构体并打印它们的大小发现却不相同。

badSt结构体,占用24个字节

并发控制利器:WaitGroup实现原理及使用注意事项

具体分析:

  • 字段a为int32类型,占用4个字节,计算其偏移量,最开始下标为0,0%4=0,当前下标正好整除成员数据类型的大小,占用4个字节;
  • 字段b为int64类型,占用8个字节,计算其偏移量,下标4-7,用8都无法整除,即无法被对齐值整除(对齐原则1,偏移量4-7都无法8的整倍数),所以下标8~15为字段b的存储使用;
  • 字段c为bool类型,占用1个字节,计算其偏移量,下标16可以使用,偏移量16可以对1整除,16%1=0,因此字段c的偏移量为16;
  • badSt结构体的成员数据类型中,最大字节为int64,为8字节,当前结构体占用了17个字节,为了保证是整倍数(对齐原则2),因此在结构体后需要填充7个字节,占满24个字节。

goodSt结构体,占用16个字节

并发控制利器:WaitGroup实现原理及使用注意事项

具体分析:

  • 字段a为int32类型,占用4个字节,计算其偏移量,最开始下标为0,0%4=0,当前下标正好整除成员数据类型的大小,占用4个字节;
  • 字段c为bool类型,占用1个字节,计算其偏移量,下标4可以对1整除,所以偏移量4则用作字段c的存储;
  • 字段b为int64类型,占用8个字节,计算其偏移量,下标5-7用8都无法整除,即无法被对齐值整除(对齐原则1,偏移量5-7都无法8的整倍数),因此下标5-7进行空白填充,下标8可以被8整除,所以下标8~15为字段b的存储使用;
  • goodSt结构体的成员数据类型中,最大字节为int64,为8字节,当前结构体占用了16个字节,正好是8的整数倍,因此在结构体后无需填充字节,占满16个字节。

2、先决条件:64位整数的原子操作,要求整数的地址是64位对齐

了解了内存对齐后,回到WaitGroup结构体,在分配给WaitGroup的地址可能会出现三种情况:

32bit环境下,WaitGroup的地址是32bit对齐,包括state1分配的地址也为32bit对齐

情况一WaitGroupstate1的地址一定是32bit对齐,且是64bit对齐;

情况二WaitGroupstate1的地址一定是32bit对齐,但不是64bit对齐;

64bit环境下,WaitGroup的地址是64bit对齐,包括state1分配的地址也为64bit对齐

情况三WaitGroupstate1的地址一定是32bit对齐,且一定是64bit对齐;

32bit环境下,结构体的成员变量偏移量地址一般会以4字节对齐,超过4字节的成员变量填充至4的整倍数大小,保证结构体中下个成员变量的偏移量为4的整倍数,从而保证对齐。

64bit环境下,结构体的成员变量偏移量地址一般会以8字节对齐,超过4字节的成员变量填充至8的整倍数大小,保证结构体中下个成员变量的偏移量为8的整倍数,从而保证对齐

针对64位整数的原子操作,要求整数的地址是64位对齐这个先决条件,我们可以分析上述三种情况:

针对情况一和情况三,state1分配到的地址为64bit对齐,满足先决条件,可以进行原子操作,因此在state1在组成上为:

  • state1[0]:waiter的个数;
  • state1[1]:WaitGroup的计数值;
  • state1[2]:信号量;

针对情况而,state1到的地址不为64bit对齐,此时若还是使用 waiter的个数、WaitGroup的计数值、信号量的顺序,此时无法对其进行64位整数的原子操作。因此,通过使用如下顺序:

  • state1[0]:信号量;
  • state1[1]:waiter的个数;
  • state1[2]:WaitGroup的计数值;

可以保证waiter的个数WaitGroup的计数值作为一个64位整数时,其地址是64bit对齐,另其可以进行64位整数的原子操作。

举个简单的例子,在32位操作系统下,如果给WaitGroup分配的地址是4的倍数,即 4 * nn为奇数,若为偶数则WaitGroup已经64bit对齐),此时如果将信号量分配到前4个字节,即state1[0]state1uint32类型切片,uint32为4字节),此时state1[1]的地址必为64bit对齐4 * (n + 1)n + 1为偶数,4 * (n + 1)可以被8整除),保证waiter的个数、WaitGroup的计数值作为一个整体的uint64且该地址为64bit对齐,从而可以进行64位整数的原子操作,这就是为什么需要将信号量提前的原因啦。

在这种情况下,通过将信号量sem分配到status1的前4个字节,可以确保在32 位操作系统上进行原子操作时,64位值的内存对齐要求得到满足。

WaitGroup的实现上,waiter的个数WaitGroup的计数值作为一个64位整数来维护其状态,保证其waiter的个数WaitGroup的计数值这两个状态相对一致。

2、state()方法

我们可以来看WaitGroup的state()方法:

// 得到state的地址和信号量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        // 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        // 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作,第一个元素32bit用来做信号量
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

在调用WaitGroupAdd()Done()Wait()方法时,都会先通过调用state()方法来获取当前WaitGroup的状态,可以来具体分析state()方法的逻辑:

  • 在第3行中,通过uintptr(unsafe.Pointer(&wg.state1))%8 == 0来判断当前state1的地址是否是64bit对齐,通过取余8来进行判断,如果为0则说明当前state1的地址已是64bit对齐,返回waiter的个数、WaitGroup的计数值、信号量的顺序。
  • 反之,则说明当前state1的地址没有64bit对齐,则通过采用信号量、waiter的个数、WaitGroup的计数值的顺序,将信号量提前,保证waiter的个数、WaitGroup的计数值为64bit对齐,从而后续能够对其64位整数的原子操作。
  • 我们可以看到state()的返回值,将waiter的个数、WaitGroup的计数值这两部分作为一个uint64类型指针整体返回,信号量单独为uint32类型的指针返回。

3、Add()方法

func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    // 高32bit是计数值v,所以把delta左移32,增加到计数上
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32) // 当前计数值
    w := uint32(state) // waiter count

    if v > 0 || w == 0 {
        return
    }

    // 如果计数值v为0并且waiter的数量w不为0,那么state的值就是waiter的数量
    // 将waiter的数量设置为0,因为计数值v也是0,所以它们俩的组合*statep直接设置为0即可。此时需要并唤醒所有的waiter
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

Add()方法的主要逻辑在于操作state1的计数部分,即waiter个数以及WaitGroup计数值,Add()方法接收一个int类型的delta参数,为计数值增加一个delta值,其内部是通过原子操作把将delta值增加到计数值。

  • 首先通过wg.state()方法获取WaitGroup的状态指针statep和信号量指针semap
  • 将增量delta左移32位,使用原子操作将该值增加到statep指向的uint64类型的变量上,得到的结果为新的状态值state
  • state中提取出高32位作为当前计数器值v(int32类型),提取出低32位作为waiter数量w(uint32类型)
  • 如果当前计数器值v大于0或者waiter数量w等于0,则直接返回,不进行后续操作。
  • 如果当前计数器值v等于0并且waiter数量w不等于0,说明有阻塞在Wait方法上的goroutine需要被唤醒。此时,将状态值statep设置为0,表示计数器值和waiter数量都为0,然后依次唤醒所有的waiter goroutine
  • 具体的唤醒过程是通过调用runtime_Semrelease函数来完成的,该函数用于释放信号量并唤醒等待该信号量的goroutine。循环执行w次,每次都调用runtime_Semrelease函数,将semap指向的信号量释放,并通知一个等待中的goroutine

4、Done()方法

// Done方法实际就是计数器减1
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Done方法内部其实就是通过 Add(-1) 实现的。

5、Wait()方法

func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32) // 当前计数值
        w := uint32(state) // waiter的数量
        if v == 0 {
            // 如果计数值为0, 调用这个方法的goroutine不必再等待,继续执行它后面的逻辑即可
            return
        }
        // 否则把waiter数量加1。期间可能有并发调用Wait的情况,所以最外层使用了一个for循环
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 阻塞休眠等待
            runtime_Semacquire(semap)
            // 被唤醒,不再阻塞,返回
            return
        }
    }
}

Wait方法的主要实现逻辑:通过for循环,不断检查state的值。

  • 如果计数值v为0,则说明所有的任务都已完成,调用Wait方法的goroutine不必再等待,直接返回。
  • 如果计数值v大于0,则说明任务仍未完成,调用该Wait方法的goroutine变为等待者,加入waiter队列,即waiter数量加1,并阻塞休眠自己。

四、常见错误

在上述分析WaitGroupAddDoneWait方法的实现时,为了更好的分析方法的实现逻辑,剔除了异常检查的代码。但是,这些异常检查逻辑非常有用。在实际开发的过程中,WaitGroup也会有误用的场景。

1、计数器设置为负值

WaitGroup的计数值必须大于等于0。在修改WaitGroup的计数值时,WaitGroup会先进行检查,如果计数值被设置为负数,则会导致panic

一般情况下,有两种情况会导致WaitGroup的计数值被设置为负数。

情况一:调用Add方法时参数传递一个负数

当调用Add方法传入一个负数,若当前WaitGroup的计数值加上这个负数后计数值还是一个大于等于0的数时,此时没有问题,但如果小于0了,则程序会出现panic

举个简单的例子:

package main

import (
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(10)
    wg.Add(-10) // -10作为参数调用Add方法,此时计数值被设置为0
    wg.Add(-1)  // -1作为参数调用Add方法,如果加上该-1则计数值会变为负数,会触发panic
}

上述代码的执行结果如下,可以看到在第11行代码触发了panic

并发控制利器:WaitGroup实现原理及使用注意事项

情况二:调用Done方法的次数超过了WaitGroup的计数值

在初次使用WaitGroup时,一般会事先调用Add方法为其设置计数值,然后在调用相同次数的Done来完成相应的任务。比方说我们在声明WaitGroup变量后,紧接着调用Add(10)为其设置计数值为10,然后在goroutine中调用相应次数的Done

如果Done方法调用的次数和计数值不一致,可能会导致死锁(Done方法调用次数少于计数值,导致调用Wait方法的goroutine阻塞,进而造成死锁)或者产生panicDone调用次数比计数值多)。

举个简单的例子:

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    wg.Done()
    wg.Done()
}

2、Add方法的使用时机不正确

在使用WaitGroup时,一定需要遵循的原则:需要等到所有的Add方法调用完成之后,再调用Wait方法进行阻塞。

若没有遵循上述的原则,则可能会导致panic或产生非预期的结果。

我们可以通过构造这样一个场景,来看看不满足这个原则的代码会发生什么问题?

main goroutine中启动四个goroutine,每个goroutine内部调用Add(1)后随即调用Done()main goroutine调用Wait方法等待任务完成。

func operateWg(mill time.Duration, wg *sync.WaitGroup) {
    duration := mill * time.Millisecond
    time.Sleep(duration) // 故意sleep一段时间

    wg.Add(1)
    fmt.Println("后台执行, duration:", duration)
    wg.Done()
}

func main() {
    var wg sync.WaitGroup
    go operateWg(100, &wg) // 启动第一个goroutine
    go operateWg(110, &wg) // 启动第二个goroutine
    go operateWg(120, &wg) // 启动第三个goroutine
    go operateWg(130, &wg) // 启动第四个goroutine

    wg.Wait()
    fmt.Println("Done")
}

上述代码场景中,原本的想法是在四个goroutine都执行完毕后,在main goroutine输出Done信息,但上述代码的错误之处在于,将WaitGroup.Add方法的调用放在了子 gorotuine 中。当main goroutine调用Wait方法时,四个goroutine一开始都处于休眠状态,所以可能会导致WaitGroupAdd方法还没被调用,此时WaitGroup的计数值为0,所以main goroutine并没有等待四个goroutine都执行完毕后才继续执行后续,而是直接立刻执行后续的逻辑。

导致上述错误的原因便是没有遵循先调用完成所有的Add方法再进行Wait

如果要解决上述场景的问题,一种方法是预先设置计数值

func operateWg(mill time.Duration, wg *sync.WaitGroup) {
    duration := mill * time.Millisecond
    time.Sleep(duration) // 故意sleep一段时间

    fmt.Println("后台执行, duration:", duration)
    wg.Done()
}

func main() {
    var wg sync.WaitGroup
    wg.Add(4) // 预先设定WaitGroup的计数值

    go operateWg(100, &wg) // 启动第一个goroutine
    go operateWg(110, &wg) // 启动第二个goroutine
    go operateWg(120, &wg) // 启动第三个goroutine
    go operateWg(130, &wg) // 启动第四个goroutine

    wg.Wait() // 等待所有goroutine执行完毕
    fmt.Println("Done")
}

// 执行结果
后台执行, duration: 100ms
后台执行, duration: 110ms
后台执行, duration: 120ms
后台执行, duration: 130ms
Done 

另一种方式则是在子goroutine启动之前,先调用Add方法增加计数值:

func operateWg(mill time.Duration, wg *sync.WaitGroup) {
    wg.Add(1)

    go func() {
       duration := mill * time.Millisecond
       time.Sleep(duration) // 故意sleep一段时间
       fmt.Println("后台执行, duration:", duration)
       wg.Done()
    }()
}

func main() {
    var wg sync.WaitGroup

    operateWg(100, &wg) // 调用方法,将计数值+1,并启动goroutine
    operateWg(110, &wg) // 调用方法,将计数值+1,并启动goroutine
    operateWg(120, &wg) // 调用方法,将计数值+1,并启动goroutine
    operateWg(130, &wg) // 调用方法,将计数值+1,并启动goroutine

    wg.Wait() // 等待所有goroutine执行完毕
    fmt.Println("Done")
}

// 执行结果
后台执行, duration: 100ms
后台执行, duration: 110ms
后台执行, duration: 120ms
后台执行, duration: 130ms
Done

3、前一个Wait还没结束就重用 WaitGroup

"前一个Wait还没结束就重用 WaitGroup",可以借用田径比赛的例子来说明,一般来说,例如100米的田径比赛,都会把选手们分为多个组,一组接着一组进行比赛,当一组选手比赛完之后,才进行下一组的比赛,为了确保每组的比赛时间不会冲突。

WaitGroup等一组比赛的所有选手都跑完, 5分钟过后才开始下一组比赛。下一组比赛还可以使用这个WaitGroup来控制,因为WaitGroup是可以重用的。

只要WaitGroup的计数值恢复到零值的状态,那么它就可以被看作是新创建的WaitGroup,被重复使用。但如果在WaitGroup的计数值还没有恢复到零值就重用的话,会导致程序panic

举个例子:

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        time.Sleep(time.Millisecond)
        wg.Done() // 计数器减1
        wg.Add(1) // 计数值加1
    }()
    wg.Wait() // 主goroutine等待,有可能和第7行并发执行
}

上述代码中,初始化WaitGroup的计数值为1,随后启动一个goroutine,该goroutine先调用Done方法,接着就调用Add方法,此时Add方法有可能和main goroutine并发执行。

在第6行代码中,虽然让WaitGroup的计数值恢复到0,但在第9行中,main goroutine调用了Wait方法正在等待,如果等待Waitmain goroutine,在刚被唤醒就和Add调用方法(第7行)有并发执行冲突,就会出现panic

因此,如果需要重用WaitGroup,必须要等到上一轮的WaitGroup计数值恢复到0后,才能重用WaitGroup执行下一轮的 Add/Wait,如果在 Wait 还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic

五、总结

WaitGroup在使用上,实际并没有这么的复杂,只需要在使用的过程中,注意几点就可以很好的避免错误使用 WaitGroup的情况.

  • 不重用WaitGroup,尽量采取新建WaitGroup的方式,避免重用带来的意外错误。
  • 保证所有的Add方法调用在Wait方法之前。
  • 不传递负数给Add方法,通过Done来给计数值-1。
  • 同一个WaitGroup,保证Add的计数值和Done方法调用的数量是一样的。
  • 在任务执行完毕后,不遗漏Done方法的调用。

参考文章

mp.weixin.qq.com/s/64eWxeB0x… time.geekbang.org/column/arti…

转载自:https://juejin.cn/post/7322811509535440930
评论
请登录