likes
comments
collection
share

我该怎么等你: sync.WaitGroup

作者站长头像
站长
· 阅读数 61

引言

假设你有一大批任务需要处理,你很容易会想到开启多协程并发去处理,等待所有的协程处理完毕,便完成了整个任务的处理。那么怎么做到等待多个协程处理完毕的阻塞语义呢?

要不sleep一会儿?不过好像sleep的时间不好确定

func Run1() {
    for range [10]struct{}{} {
        go func() {
            // do something
            fmt.Println("running ...")
        }()
    }
    time.Sleep(time.Second)
    fmt.Println("done ...")
}

使用channel来实现阻塞的语义 好像确实可行

但是需要额外处理好channel的关闭

func Run2() {

    var (
        count    int64 = 0
        doneChan       = make(chan struct{}, 1)
    )
    for range [10]struct{}{} {
        go func() {
            // do something
            fmt.Println("running ...")
            if new := atomic.AddInt64(&count, 1); new == 10 {
                doneChan <- struct{}{}
                close(doneChan)
            }
        }()
    }
    _ = <-doneChan
    fmt.Println("done ...")
}

好在 sync.WaitGroup 便可以解决上述问题

A WaitGroup waits for a collection of goroutines to finish. The main goroutine calls Add to set the number of goroutines to wait for. Then each of the goroutines runs and calls Done when finished. At the same time, Wait can be used to block until all goroutines have finished.

下面是一个错误的使用示例 先卖一个关子,我们继续往下看

  • Add 声明添加一个处理者
  • Done 声明一个处理者处理完毕
  • Wait 阻塞等待所有的处理者完成

func Run3() {
    var wg sync.WaitGroup
    for range [10]struct{}{} {
        go func() {
            wg.Add(1)
            fmt.Println("running ...")
             wg.Done()
        }()
    }

    wg.Wait()
    fmt.Println("done ...")
}

// 要么会panic 
// 要么发现并没有等待所有的处理者结束便打印了 "done ..."

核心结构

state1(uint64)state2(uint32)共同表示了

  • 正在处理任务的协程数 counter 通过Add()进行+1Done()进行-1
  • 等待上述操作全部完成的协程数 waiter,通过Wait()阻塞等待
  • 信号量 sema 实现阻塞和唤醒语义
type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers only guarantee that 64-bit fields are 32-bit aligned.
    // For this reason on 32 bit architectures we need to check in state()
    // if state1 is aligned or not, and dynamically "swap" the field order if
    // needed.
    state1 uint64
    state2 uint32
}

辅助函数 state 返回state(counter|waiter)和信号量

// state returns pointers to the state and sema fields stored within wg.state*.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        // state1 is 64-bit aligned: nothing to do.
        return &wg.state1, &wg.state2
    } else {
        // state1 is 32-bit aligned but not 64-bit aligned: this means that
        // (&state1)+4 is 64-bit aligned.
        state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
        return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
    }
}

为什么需要不同的处理?

32 位系统下,使用 atomic 对 64 位变量进行原子操作,调用者需要自行保证该变量是64位内存对齐的~

On ARM , x86-32, and 32-bit MIPS , it is the caller’s responsibility to arrange for 64-bit alignment of 64-bit words accessed atomically. The first word in a variable or in an allocated struct, array, or slice can be relied upon to be 64-bit aligned.

辅助函数 state() 用于返回 statep变量(counter|waiter)和semap 信号量

针对state1 是32位内存对齐还是64位内存对齐 需要进行不同的翻译

  • state1已经是64位内存对齐的,这时候这个字段无论是在64位系统还是32位系统下进行原子操作都是安全的。所以可以直接返回state1,高32位代表counter数,低32位代表waiter
  • 如果state1变量不是64位内存对齐的(说明当前处于32位操作系统上,该字段肯定是32位内存对齐的),则可以理解为在原来的offset上加上32就一定是64位内存对齐了。所以把state1的低32位和state2当作一个字段(counter|waiter)返回,state1的高32位当作信号量返回

经过state()方法的抽象,对外只暴露信号量sema和state变量,无需关心不同系统之间的差异

64位内存对齐下

state1state2(高32位)state2(低32位)
counterwaitersema

32位内存对齐下

state1state2(高32位)state2(低32位)
semacounterwaiter

为什么counterwaiter只使用一个变量?

因为counterwaiter 进行增减需要保证并发安全,而为了追求极致性能,没有引入锁,而是基于64位的变量(counter|waiter)的原子操作(AddUint64CompareAndSwapUint64)来保证并发读写安全

实现原理

Add

声明新增或者移除 delta 个 处理者 counter

当delta大于0时 代表新增

delta小于0时 代表移除 本质上Done()方法调用的就是Add(-1)

  • 使用AddUint64counter变量原子的加减

  • 如果发现 counter==0(说明所有的处理者都处理完毕)且存在waiter(有协程调用wait()在阻塞等待),则一次性清空所有的waiter(而不是由所有的wait()调用方操作waiter--,避免了多次的原子操作),然后再调用runtime_Semrelease唤醒所有的waiter

什么情况下 调用Add 会panic呢? panic说明该方法使用姿势不对

  • counter出现负数时(说明counter的添加和移除数量不匹配)
  • 并发调用Add()Wait() 时(所有的Add()操作应该在调用Wait()前)
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    if race.Enabled {
        // ...
    }
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    // counter 数量
    v := int32(state >> 32)
    // waiter 数量
    w := uint32(state)
    if race.Enabled && delta > 0 && v == int32(delta) {
        // ...
    }
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    // w!=0 说明已经有协程调用过 wait()
    // delta大于0 说明 正在添加counter
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
        return
    }
    // This goroutine has set counter to 0 when waiters > 0.
    // Now there can't be concurrent mutations of state:
    // - Adds must not happen concurrently with Wait,
    // - Wait does not increment waiters if it sees counter == 0.
    // Still do a cheap sanity check to detect WaitGroup misuse.
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // Reset waiters count to 0.
    // 当调用Add方法发现最后一个处理者也已经完成时,选择直接清空waiter数量
    // 而不是由所有的wait()调用方操作waiter--,避免了多次的原子操作 AddUint64(statep, -1)
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

Done

本质上就是调用Add(-1)

// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Wait

wait() 用于等待所有的处理者完成

  • 当发现counter==0时,无需阻塞,直接返回
  • 否则,通过CAS指令对waiter进行+1,调用runtime_Semacquire实现阻塞等待
  • 当所有的counter都处理完成时,便会唤醒所有阻塞在该信号量上的协程
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    // 省略部分不重要代码
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        w := uint32(state)
        if v == 0 {
            // Counter is 0, no need to wait.
            // ...
            return
        }
        // Increment waiters count.
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(semap))
            }
            runtime_Semacquire(semap)
           // 说明该waitegroup 被重用了
           // 还没有等待所有的wait()返回就调用了Add()操作
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            // 省略部分不重要代码
            return
        }
    }
}
转载自:https://juejin.cn/post/7352162941313613875
评论
请登录