likes
comments
collection
share

揭秘 Go 并发利器 WaitGroupGo 语言中 WaitGroup 的工作原理、使用方法、最佳实践及注意事项,包括

作者站长头像
站长
· 阅读数 13

在 Go 语言的并发编程世界中,WaitGroup 是一个至关重要的工具,它为开发者提供了一种简单而有效的方式来管理和同步多个协程的执行。本文将深入揭秘 WaitGroup 的实现原理、注意事项、使用示例。

什么是 WaitGroup

WaitGroup 是 Go 标准库中 sync 包提供的一种同步原语,用于等待一组(可能是并发的)操作完成。它的主要作用是让主协程(即调用 WaitGroup 相关方法的协程)能够等待其他协程完成任务后再继续执行,确保所有并发操作都按预期完成。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模拟一些工作
    for i := 0; i < 5; i++ {
        fmt.Printf("Worker %d working... step %d\n", id, i)
    }
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup

    // 设置等待组计数器为 3,表示有三个协程需要等待
    wg.Add(3)

    // 启动三个协程
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)

    // 等待所有协程完成
    wg.Wait()
    fmt.Println("All workers are done.")
}

WaitGroup 的核心方法

WaitGroup对外提供了**Add、Done、Wait **三个方法,这三个方法需要搭配使用。

Add 方法

  • 功能:用于设置 WaitGroup 需要等待的操作数量。这个方法接受一个整数参数 delta,可以是正数或负数,用来增加或减少等待的协程数量。

  • 使用场景:通常在创建协程之前调用 Add 方法来指定需要等待的协程数量。例如,如果要启动 10 个协程并等待它们全部完成,就需要调用 wg.Add(10)。

  • 注意事项:增加计数的 Add 调用应该在协程启动之前完成,否则可能会导致 wait 方法提前结束等待,因为 wait 方法是根据 Add 设置的计数来判断是否所有协程都已完成。

func main() {
    var wg sync.WaitGroup

    // 设置等待组计数器为 3,表示有三个协程需要等待
    wg.Add(3)
    
    // ...
}

Done 方法

  • 功能:每当一个协程完成任务时,需要调用 Done 方法来通知 WaitGroup,该协程的任务已经完成。实际上,Done 方法内部是调用了 Add(-1),将等待的协程数量减一。

  • 使用场景:在协程的任务函数中,当协程完成了自己的工作后,应该立即调用 Done 方法。例如:

Wait 方法

  • 功能:阻塞调用,直到 WaitGroup 的计数值变成 0,即所有被等待的协程都完成了任务。

  • 使用场景:在主协程中,当启动了多个协程并使用 add 方法设置了等待的协程数量后,调用 wait 方法来阻塞主协程,直到所有协程都完成任务。这样可以确保主协程在所有协程完成工作后再继续执行后续的代码。

func main() {
    var wg sync.WaitGroup

    // 设置等待组计数器为 3,表示有三个协程需要等待
    wg.Add(3)

    // 启动三个协程
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)

    // 等待所有协程完成
    wg.Wait()
    fmt.Println("All workers are done.")
}

WaitGroup 的实现原理

WaitGroup结构体

type WaitGroup struct {
    noCopy noCopy

    state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
    sema  uint32
}

WaitGroup 是一个结构体,里面有 state 和 sema 两个核心字段,其中:

  • state 字段是atomic.Uint64类型,它的高32位是子协程的计数器(调用Add的总和),低32位是主携程计数器(调用Wait阻塞的协程数)。
func main() {
    var wg sync.WaitGroup
    wg.Add(3)
    go worker(1, &wg)
    go worker(2, &wg)
    go worker(3, &wg)
    wg.Wait()
}

当 worker 没有执行完时,state 的内存模型如下图所示:

揭秘 Go 并发利器 WaitGroupGo 语言中 WaitGroup 的工作原理、使用方法、最佳实践及注意事项,包括

state 的一些操作:

int32(state >> 32) // 取高 32 位的值
uint32(state) // 取低 32 位的值

wg.state.Add(uint64(delta) << 32) // 高 32 位加减
wg.state.CompareAndSwap(state, state+1) // 低 32 位操作
  • sema 字段是一个信号量,信号量(Semaphore)是一种用于多进程或多线程同步和互斥的机制。信号量是一个整数变量,通常与两个操作相关联:P操作(也称为wait、down、acquire等)和V操作(也称为signal、up、release等)。

sema 的初始值是 0,调用 Wait 方法会进行 P 操作,如果此时没有 V 操作,Wait 方法就会阻塞,然后子协程执行完会调用 Done 方法,当 state 高32位为0时,就会进行 V 操作,这时 Wait 方法就会被唤醒继续执行。

Add 方法源码

func (wg *WaitGroup) Add(delta int) {
    // 省略 race 相关代码
    
    // 1. 原子更新 state 的值
    state := wg.state.Add(uint64(delta) << 32)
    
    // 2. 通过位操作获取子协程计数器和主协程计数器的值
    v := int32(state >> 32)
    w := uint32(state)
    
    // 省略 race 相关代码
    
    // 3. v < 0:抛出 panic , v 的值不可以是负值
    if v < 0 {
       panic("sync: negative WaitGroup counter")
    }
    
    // 4. Wait 和 Add() 不能同时被调用,否则会抛出 panic
    // 4.1 w != 0 说明 Wait 方法已经被调用但是还没返回
    // 4.2 delta > 0 && v == int32(delta) 说明调用了 Add() 方法
    if w != 0 && delta > 0 && v == int32(delta) {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    
    // 5. 执行到此处可能有两种情况
    // 5.1 Wait 方法还没被调用(w == 0),此时调用 Add() 或者 Done() 都直接返回,不需要进行 V 操作
    // 5.2 Wait 方法已经被调用(w != 0),此时只能调用 Done(), 若果 v > 0 说明子协程没有全部执行完,可以直接返回,不需要进行 V 操作
    if v > 0 || w == 0 {
       return
    }
    
    // 6. 执行到这里说明 v == 0 && w != 0,所有子协程都已经执行完,
    // v == 0 时调用 Wait() 并不会更改state,再次检查 state 防止有并发调用 Add,
    if wg.state.Load() != state {
       panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    
    // 7. 执行 V 操作,释放所有因调用 Wait 而阻塞的协程
    wg.state.Store(0)
    for ; w != 0; w-- {
       runtime_Semrelease(&wg.sema, false, 0)
    }
}

关键逻辑已经添加注释了,总结一下 Add 主要操作:

  1. 更新 state,通常 Add 是增加操作,Done 是减少操作;

  2. 校验state,防止有 v < 0 或者错误调用 Add 函数的情况;

  3. 如果 state 符合预期,执行 V 操作,释放所有因调用 Wait 而阻塞的协程。

Done 方法源码

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Done() 的底层实现是调用了 Add() 方法,参数是 -1 代表完成一个任务。

Wait 方法源码

func (wg *WaitGroup) Wait() {
    // 省略 race 相关代码
    
    for {
       state := wg.state.Load()
       v := int32(state >> 32)
       w := uint32(state)
       if v == 0 {
          // 1. Counter is 0, no need to wait.
          // 省略 race 相关代码
          return
       }
       
       // 2. Increment waiters count.
       if wg.state.CompareAndSwap(state, state+1) {
          // 省略 race 相关代码
          
          // 3. P 操作,阻塞当前进程
          runtime_Semacquire(&wg.sema)
          if wg.state.Load() != 0 {
             // 4. 当前协程已经被唤醒,此时应该 v == 0,Wait 没有返回前不可以复用 WaitGroup
             panic("sync: WaitGroup is reused before previous Wait has returned")
          }
          // 省略 race 相关代码
          return
       }
    }
}

Wait 的主要逻辑:

  1. 判断 v 的值,如果 v == 0 , Wait 可以直接返回,v == 0 说明没有需要等待的子协程;

  2. 使用 CompareAndSwap 进行 state + 1 操作,如果执行成功进行下面步骤,如果不成功开启新一轮Wait逻辑;

  3. 如果 state + 1 操作成功后,需要进行 P 操作,阻塞当前进程;

  4. 当前协程已经被唤醒,再次校验 state 的值,此时应该 state == 0,Wait 没有返回前不可以复用 WaitGroup。

WaitGroup 的注意事项

正确使用 adddone方法

确保在启动协程之前正确地调用 add 方法来设置等待的协程数量,并且在协程完成任务后及时调用 done 方法。避免忘记调用 done 方法导致程序永远阻塞在 wait 上,或者超量调用 done 方法导致计数器变为负数而引发 panic。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
        defer func(){
                // 调用了两次 Done() 方法
                wg.Done()
                wg.Done()
        }()
        
    
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go worker(1, &wg)

    wg.Wait()
    fmt.Println("All workers are done.")
}

$ go run main.go 
Worker 1 done

panic: sync: negative WaitGroup counter

goroutine 18 [running]:
sync.(*WaitGroup).Add(0xc00007e020, 0xffffffffffffffff)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:74 +0x139
sync.(*WaitGroup).Done(...)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:99
main.worker.func1(0xc00007e020)
        /box/main.go:11 +0x4c
main.worker(0x1, 0xc00007e020)
        /box/main.go:16 +0xf2
created by main.main
        /box/main.go:22 +0x78
panic: sync: WaitGroup is reused before previous Wait has returned

goroutine 1 [running]:
sync.(*WaitGroup).Wait(0xc00007e020)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:132 +0xad
main.main()
        /box/main.go:24 +0x86

Exited with error status 2

合理复用 WaitGroup:

WaitGroup 对象可以在所有协程完成后重用。但是在重用时,要确保之前的 wait 方法已经返回,否则可能会出现不可预期的行为。

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer func(){
        // 调用了两次 Done() 方法
        wg.Done()
        go func(){
            wg.Add(1)
        }()
    }()
        
    
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go worker(1, &wg)

    wg.Wait()
    fmt.Println("All workers are done.")
}

$ go run main.go
Worker 1 done

panic: sync: WaitGroup is reused before previous Wait has returned

goroutine 1 [running]:
sync.(*WaitGroup).Wait(0xc000016060)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:132 +0xad
main.main()
        /box/main.go:27 +0x86

Exited with error status 2

不要复制 WaitGroup:

WaitGroup 实例是不期望被复制的,如果复制后需要当做不同的实例看待,如果错误的使用了复制后的实例,可能造成协程泄漏:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    go func(){
        // 复制一个 wg 
        wg := wg
        defer wg.Done()
        
        fmt.Printf("Worker done\n")
    }()

    wg.Wait()
    fmt.Println("All workers are done.")
}

& go run main.go 
Worker done

fatal error: all goroutines are asleep - deadlock!

goroutine 1 [semacquire]:
sync.runtime_Semacquire(0xc000016068)
        /usr/local/go-1.13.5/src/runtime/sema.go:56 +0x42
sync.(*WaitGroup).Wait(0xc000016060)
        /usr/local/go-1.13.5/src/sync/waitgroup.go:130 +0x64
main.main()
        /box/main.go:20 +0x7d

Exited with error status 2

WaitGroup 的示例

并行计算

假设我们需要计算一个大型数组中每个元素的平方值,可以将数组分成多个部分,每个部分由一个协程来处理。使用 WaitGroup 可以确保所有协程都完成计算后再汇总结果。

package main

import (
    "fmt"
    "sync"
)

func square(wg *sync.WaitGroup, slice []int, result chan<- int) {
    defer wg.done()
    for _, v := range slice {
        result <- v * v
    }
}

func main() {
    var wg sync.WaitGroup
    slice := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    result := make(chan int, len(slice))

    // 将数组分成 3 个部分,每个部分由一个协程处理
    wg.Add(3)
    go square(&wg, slice[:3], result)
    go square(&wg, slice[3:6], result)
    go square(&wg, slice[6:], result)

    // 等待所有协程完成
    wg.Wait()
    close(result)

    // 汇总结果
    var total int
    for v := range result {
        total += v
    }
    fmt.Println(total)
}

并发请求

在开发中,当需要同时处理多个 RPC 请求时,可以使用 WaitGroup 来确保所有请求都处理完成后再返回响应。

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup

    wg.Add(3)
    go func() {
       defer wg.Done()
       fmt.Println("call rpc1 ...")
    }()
    go func() {
       defer wg.Done()
       fmt.Println("call rpc2 ...")
    }()
    go func() {
       defer wg.Done()
       fmt.Println("call rpc3 ...")
    }()

    // 等待所有请求处理完成
    wg.Wait()
    fmt.Println("所有请求处理完成")
}

总结

WaitGroup 是 Go 语言中非常强大的并发控制工具,它能够帮助开发者轻松地管理和同步多个协程的执行,确保并发操作的正确执行顺序。通过正确地使用 WaitGroup,开发者可以编写出高效、可靠的并发程序,充分发挥 Go 语言的并发优势。在实际应用中,我们需要深入理解 WaitGroup 的工作原理和使用方法,避免常见的错误,以确保程序的正确性和性能。

转载自:https://juejin.cn/post/7415912074994171944
评论
请登录