likes
comments
collection
share

6.Go语言实现生产级线程池我们都知道java中的线程池非常好用,但是在转变到Go语言的时候,发现没有现成的线程池功能,

作者站长头像
站长
· 阅读数 47

目录

  1. 背景
  2. 简单线程池实现​
  3. 优化1:任务超时控制
  4. 优化2:任务失败信息收集
  5. 优化3:线程池配置参数化
  6. 优化4:完善panic的捕获和错误处理
  7. 优化5:完善监控线程池状态

背景

我们都知道java中的线程池非常好用,但是在转变到Go语言的时候,发现没有现成的线程池功能,这里就来介绍下如何实现一个生产级的线程池,从一个简单的线程池开始,然后逐步实现一个生产级的线程池。

简单线程池实现​

package main

import (
 "sync"
)

// 定义任务Job,包含一个执行函数和参数
type Job struct {
 Exec func(data interface{}) // 需要执行的函数
 Data interface{}            // 数据
}

// 定义线程池,基于Go的协程实现
type Pool struct {
 jobs      chan Job // 任务队列
 quit      chan bool // 用于停止的信号
 wg        sync.WaitGroup // 等待执行任务完成
 workerNum int // 工作例程数量
}

// 创造一个新的线程池
func NewPool(workerNum int, jobQueueLen int) *Pool {
 return &Pool{
  jobs:      make(chan Job, jobQueueLen),
  quit:      make(chan bool),
  workerNum: workerNum,
 }
}

// 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
   for {
    select {
    // 读取任务并执行
    case job := <-p.jobs:
          job.Exec(job.Data)
     // 收到advise to quit的信号
    case <-p.quit:
     p.wg.Done()
     return
    }
   }
  }()
 }
}

// 停止线程池
func (p *Pool) Stop() {
  for i := 0; i < p.workerNum; i++ {
        p.quit <- true
    }
    close(p.jobs)
    p.wg.Wait()
}

// 提交任务到线程池
func (p *Pool) AddJob(job Job) {
 p.jobs <- job
}

对于上述代码,创建了一个固定数量的goroutine作为worker,任务被提交至任务队列中,worker会接收任务并执行。如果队列未满,主例程可以继续提交任务至队列,而不会被阻塞。我们可以通过Pool.Start()方法启动线程池,通过Pool.Stop()方法停止线程池,也可以通过Pool.AddJob()方法向线程池提交任务

优化1:任务超时控制

我们可以通过利用Go的time包和"context"包来实现任务的超时控制。具体思路是:在提交job的时候,我们为每个job创建一个带有超时控制的context,然后在执行job的时候,通过select结构同时监听context的Done信号和任务完成信号,这样如果任务在规定时间内没有执行完,就会自动返回一个超时错误

package main

import (
 "sync"
)

// 定义任务Job,包含一个执行函数和参数
type Job struct {
  Exec func(context.Context, interface{}) error  // 需要执行的函数,增加了context参数
  Data interface{}            // 任务数据
  Ctx  context.Context   // 为每个job创建的context
}

// 定义线程池,基于Go的协程实现
type Pool struct {
 jobs      chan Job // 任务队列
 quit      chan bool // 用于停止的信号
 wg        sync.WaitGroup // 等待执行任务完成
 workerNum int // 工作例程数量
}

// 创造一个新的线程池
func NewPool(workerNum int, jobQueueLen int, jobTimeout time.Duration) *Pool {
 return &Pool{
  jobs:      make(chan Job, jobQueueLen),
  quit:      make(chan bool),
  wg:        sync.WaitGroup{}, // 初始化为空
  workerNum: workerNum,
  jobTimeout: jobTimeout,  //为每一个任务设置超时时间
 }
}

// 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
  for {
    select {
      case job := <-p.jobs:
        // 同时监听任务完成和超时信号
        done := make(chan error, 1)
        go func() {
          done <- job.Exec(job.Ctx, job.Data)
        }()

        select {
          case <-done:
            // 任务完成
          case <-job.Ctx.Done():
            // 任务超时
        }

      case <-p.quit:
        p.wg.Done()
        return
    }
  }
}()
 }
}

// 停止线程池
func (p *Pool) Stop() {
  for i := 0; i < p.workerNum; i++ {
        p.quit <- true
    }
    close(p.jobs)
    p.wg.Wait()
}

// 提交任务到线程池
func (p *Pool) AddJob(jobFunc func(context.Context, interface{}) error, data interface{}) {
  ctx, _ := context.WithTimeout(context.Background(), p.jobTimeout)
  p.jobs <- Job{Exec: jobFunc, Data: data, Ctx: ctx}
}

注意这种方式的一个问题是,虽然我们能得到任务超时的错误并且停止等待任务的完成,但实际上我们无法真正停止正在执行的任务,这可能会导致系统资源的浪费

也就是需要我们的job需要自己去监听这个ctx.Done()信号,并且在收到信号后自己及时结束并清理工作。如果job是一些阻塞的操作,如I/O操作,这个时候他可能就无法感知到ctx.Done()的信号,所以需要额外留意

func myJob(ctx context.Context, data interface{}) error {
    // 执行任务的具体逻辑
    // ...

    // 定期检查Context是否已被取消
    select {
    case <-ctx.Done():
        // 任务已被取消,立即退出
        return ctx.Err()
    default:
        // 继续执行任务
    }
}

优化2:任务失败信息收集

增加一个 error channel,用于接收执行任务的错误

package main

import (
 "context"
 "sync"
 "time"
)

// 定义任务Job,包含一个执行函数和参数
type Job struct {
 Exec func(context.Context, interface{}) error  // 需要执行的函数,增加了context参数
 Data interface{}            // 任务数据
 Ctx  context.Context   // 为每个job创建的context
}

// 定义线程池,基于Go的协程实现
type Pool struct {
 jobs      chan Job // 任务队列
 quit      chan bool // 用于停止的信号
 wg        sync.WaitGroup // 等待执行任务完成
 workerNum int // 工作例程数量
 jobTimeout time.Duration  //为每一个任务设置超时时间
 errors    chan error // 错误接收
}

// 创造一个新的线程池
func NewPool(workerNum int, jobQueueLen int, jobTimeout time.Duration) *Pool {
 return &Pool{
  jobs:      make(chan Job, jobQueueLen),
  quit:      make(chan bool),
  wg:        sync.WaitGroup{}, // 初始化为空
  workerNum: workerNum,
  jobTimeout: jobTimeout,
  errors:    make(chan error,100), // 初始化错误队列
 }
}

// 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
    for {
      select {
        case job := <-p.jobs:
          // 同时监听任务完成和超时信号
          done := make(chan error, 1)
          go func() {
            done <- job.Exec(job.Ctx, job.Data)
          }()

          select {
            case err := <-done:
              // 任务完成,将可能的错误写入错误通道
              if err != nil {
                p.errors <- err
              }

            case <-job.Ctx.Done():
              // 任务超时
          }

        case <-p.quit:
          p.wg.Done()
          return
      }
    }
  }()
 }
}

// 停止线程池
func (p *Pool) Stop() {
  for i := 0; i < p.workerNum; i++ {
        p.quit <- true
    }
    close(p.jobs)
    p.wg.Wait()
}

// 提交任务到线程池
func (p *Pool) AddJob(jobFunc func(context.Context, interface{}) error, data interface{}) {
  ctx, _ := context.WithTimeout(context.Background(), p.jobTimeout)
  p.jobs <- Job{Exec: jobFunc, Data: data, Ctx: ctx}
}

// 获取错误
func (p *Pool) GetErrors() <-chan error {
 return p.errors
}

增加了一个错误接收通道,这是一个公共的错误通道,用于接收所有任务执行中产生的错误。

在任务执行的 Goroutine 中,当任务成功执行完毕后,会把返回的 err 对象写入错误通道。当任务执行出错时,会把这个错误写入错误通道。

这样,任务执行的错误就能被很好地收集并得到处理。如果需要对错误进行相应的处理,只需要提供一个 Goroutine 专门用来不断从错误通道读取错误,并根据错误进行相应的处理即可

优化3:线程池配置参数化

让线程池实现更加通用和可配置,可以根据不同场景灵活配置参数

package main

import (
 "context"
 "sync"
 "time"
)

// 定义任务Job,包含一个执行函数和参数
type Job struct {
 Exec func(context.Context, interface{}) error // 需要执行的函数,增加了context参数
 Data interface{}                              // 任务数据
 Ctx  context.Context                          // 为每个job创建的context
}

// 线程池配置
type PoolConfig struct {
 WorkerNum       int           // 工作协程数量
 JobQueueLen     int           // 任务队列长度
 JobTimeout      time.Duration // 任务超时时间
 ErrorBufferSize int           // 错误通道缓冲区大小
 EnableTimeout   bool          // 是否启用任务超时
 ParentContext   context.Context // 父级上下文
}

// 定义线程池,基于Go的协程实现
type Pool struct {
 jobs          chan Job      // 任务队列
 quit          chan bool     // 用于停止的信号
 wg            sync.WaitGroup // 等待执行任务完成
 workerNum     int           // 工作例程数量
 jobTimeout    time.Duration // 为每一个任务设置超时时间
 errors        chan error    // 错误接收
 enableTimeout bool          // 是否启用任务超时
 parentCtx     context.Context // 父级上下文
}

// 创造一个新的线程池
func NewPool(cfg *PoolConfig) *Pool {
 return &Pool{
  jobs:          make(chan Job, cfg.JobQueueLen),
  quit:          make(chan bool),
  wg:            sync.WaitGroup{}, // 初始化为空
  workerNum:     cfg.WorkerNum,
  jobTimeout:    cfg.JobTimeout,
  errors:        make(chan error, cfg.ErrorBufferSize), // 初始化错误队列
  enableTimeout: cfg.EnableTimeout,
  parentCtx:     cfg.ParentContext,
 }
}

// 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
   for {
    select {
    case job := <-p.jobs:
     // 同时监听任务完成和超时信号
     done := make(chan error, 1)
     go func() {
      done <- job.Exec(job.Ctx, job.Data)
     }()

     if p.enableTimeout {
      select {
      case err := <-done:
       // 任务完成,将可能的错误写入错误通道
       if err != nil {
        p.errors <- err
       }
      case <-job.Ctx.Done():
       // 任务超时
      }
     } else {
      err := <-done
      if err != nil {
       p.errors <- err
      }
     }
    case <-p.quit:
     p.wg.Done()
     return
    }
   }
  }()
 }
}

// 停止线程池
func (p *Pool) Stop() {
 for i := 0; i < p.workerNum; i++ {
        p.quit <- true
    }
    close(p.jobs)
    p.wg.Wait()
}

// 提交任务到线程池
func (p *Pool) AddJob(jobFunc func(context.Context, interface{}) error, data interface{}) {
 ctx, cancel := context.WithTimeout(p.parentCtx, p.jobTimeout)
 defer cancel()
 p.jobs <- Job{Exec: jobFunc, Data: data, Ctx: ctx}
}

// 获取错误
func (p *Pool) GetErrors() <-chan error {
 return p.errors
}
  1. 添加了PoolConfig结构体,包含了线程池的各种配置参数。
  2. NewPool函数现在接收一个PoolConfig作为参数,使用这些配置初始化线程池。
  3. 在Start函数中,根据enableTimeout的值决定是否启用任务超时机制。
  4. 在AddJob函数中,使用parentCtx作为父级上下文创建每个Job的上下文。
  5. 在AddJob函数中,使用defer cancel()来确保在任务完成或超时后,该Job的上下文被取消。

优化4:完善panic的捕获和错误处理

为了提高代码的健壮性和可读性,对一些未知的panic进行处理

package main

import (
 "context"
 "sync"
 "time"
)

// Job 定义任务,包含一个执行函数和参数
type Job struct {
 Exec func(context.Context, interface{}) error // 需要执行的函数,增加了context参数
 Data interface{}                              // 任务数据
 Ctx  context.Context                          // 为每个job创建的context
}

// PoolConfig 线程池配置
type PoolConfig struct {
 WorkerNum       int           // 工作协程数量
 JobQueueLen     int           // 任务队列长度
 JobTimeout      time.Duration // 任务超时时间
 ErrorBufferSize int           // 错误通道缓冲区大小
 EnableTimeout   bool          // 是否启用任务超时
 ParentContext   context.Context // 父级上下文
}

// Pool 定义线程池,基于Go的协程实现
type Pool struct {
 jobs          chan Job      // 任务队列
 quit          chan bool     // 用于停止的信号
 wg            sync.WaitGroup // 等待执行任务完成
 workerNum     int           // 工作例程数量
 jobTimeout    time.Duration // 为每一个任务设置超时时间
 errors        chan error    // 错误接收
 enableTimeout bool          // 是否启用任务超时
 parentCtx     context.Context // 父级上下文
}

// NewPool 创造一个新的线程池
func NewPool(cfg *PoolConfig) *Pool {
 return &Pool{
  jobs:          make(chan Job, cfg.JobQueueLen),
  quit:          make(chan bool),
  wg:            sync.WaitGroup{}, // 初始化为空
  workerNum:     cfg.WorkerNum,
  jobTimeout:    cfg.JobTimeout,
  errors:        make(chan error, cfg.ErrorBufferSize), // 初始化错误队列
  enableTimeout: cfg.EnableTimeout,
  parentCtx:     cfg.ParentContext,
 }
}

// Start 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
   defer p.wg.Done()
   for {
    select {
    case job := <-p.jobs:
     // 同时监听任务完成和超时信号
     done := make(chan error, 1)
     go func() {
      done <- job.Exec(job.Ctx, job.Data)
     }()

     if p.enableTimeout {
      select {
      case err := <-done:
       // 任务完成,将可能的错误写入错误通道
       if err != nil {
        p.errors <- err
       }
      case <-job.Ctx.Done():
       // 任务超时
      }
     } else {
      err := <-done
      if err != nil {
       p.errors <- err
      }
     }
    case <-p.quit:
     return
    }
   }
  }()
 }
}

// Stop 停止线程池
func (p *Pool) Stop() {
 for i := 0; i < p.workerNum; i++ {
  p.quit <- true
 }
 close(p.jobs)
 close(p.errors)
 p.wg.Wait()
}

// AddJob 提交任务到线程池
func (p *Pool) AddJob(jobFunc func(context.Context, interface{}) error, data interface{}) {
 ctx, cancel := context.WithTimeout(p.parentCtx, p.jobTimeout)
 defer cancel()
 p.jobs <- Job{Exec: jobFunc, Data: data, Ctx: ctx}
}

// GetErrors 获取错误
func (p *Pool) GetErrors() <-chan error {
 return p.errors
}
  1. 在 Start 方法中,对 go 关键字启动的新 goroutine 添加 defer p.wg.Done()。这样可以确保无论该 goroutine 是正常退出还是发生 panic,都能正确地减少 WaitGroup 的计数器。
  2. 在 AddJob 方法中,添加 defer cancel()。这样可以确保即使在执行 p.jobs <- Job{…} 时发生了 panic,也能正常取消 context。
  3. 在 Stop 方法中,添加 close(p.errors) 操作,以确保 GetErrors 方法能正常返回,不会发生阻塞。

优化5:完善监控线程池状态

添加一个监控线程池状态的协程,用于实时监控线程池的状态。

package main

import (
 "context"
 "sync"
 "time"
 "sync/atomic"
)

// Job 定义任务,包含一个执行函数和参数
type Job struct {
 Exec func(context.Context, interface{}) error // 需要执行的函数,增加了context参数
 Data interface{}                              // 任务数据
 Ctx  context.Context                          // 为每个job创建的context
}

// PoolConfig 线程池配置
type PoolConfig struct {
 WorkerNum       int           // 工作协程数量
 JobQueueLen     int           // 任务队列长度
 JobTimeout      time.Duration // 任务超时时间
 ErrorBufferSize int           // 错误通道缓冲区大小
 EnableTimeout   bool          // 是否启用任务超时
 ParentContext   context.Context // 父级上下文
}

// PoolStat 定义线程池的状态
type PoolStat struct {
 NumJobs     int64 // 已添加的任务总数
 NumFinished int64 // 已完成的任务总数
 NumActive   int64 // 正在执行的任务数量
}

// Pool 定义线程池,基于Go的协程实现
type Pool struct {
 jobs          chan Job      
 quit          chan bool     
 wg            sync.WaitGroup
 workerNum     int
 jobTimeout    time.Duration
 errors        chan error   
 enableTimeout bool         
 parentCtx     context.Context
 numJobs       int64         
 numFinished   int64         
}

// NewPool 创造一个新的线程池
func NewPool(cfg *PoolConfig) *Pool {
 return &Pool{
  jobs:          make(chan Job, cfg.JobQueueLen),
  quit:          make(chan bool),
  wg:            sync.WaitGroup{}, // 初始化为空
  workerNum:     cfg.WorkerNum,
  jobTimeout:    cfg.JobTimeout,
  errors:        make(chan error, cfg.ErrorBufferSize), // 初始化错误队列
  enableTimeout: cfg.EnableTimeout,
  parentCtx:     cfg.ParentContext,
  numJobs:       0,
  numFinished:   0,
 }
}

// Start 开启线程池
func (p *Pool) Start() {
 p.wg.Add(p.workerNum)
 for i := 0; i < p.workerNum; i++ {
  go func() {
   defer p.wg.Done()
   for {
    select {
    case job := <-p.jobs:
     done := make(chan error, 1)
     go func() {
      done <- job.Exec(job.Ctx, job.Data) // 执行任务并将结果发送给done通道
     }()
     if p.enableTimeout {
      select {
      case err := <-done:
       // 任务完成,记录可能的错误并增加完成任务的数量
       if err != nil {
          p.errors <- err
       }
       atomic.AddInt64(&p.numFinished, 1)
      case <-job.Ctx.Done():
       // 任务超时,此处可以添加适当的处理代码
      }
     } else {
      // 若关闭了超时检查,则直接等待任务完成
      err := <-done
      if err != nil {
         p.errors <- err
      }
      atomic.AddInt64(&p.numFinished, 1)
     }
    case <-p.quit:
     // 收到停止信号,结束处理
     return
    }
   }
  }()
 }
}

// Stop 停止线程池
func (p *Pool) Stop() {
 for i := 0; i < p.workerNum; i++ {
  p.quit <- true
 }
 close(p.jobs)
 close(p.errors)
 p.wg.Wait()
}

// AddJob 提交任务到线程池
func (p *Pool) AddJob(jobFunc func(context.Context, interface{}) error, data interface{}) {
 ctx, cancel := context.WithTimeout(p.parentCtx, p.jobTimeout)
 defer cancel()
 p.jobs <- Job{Exec: jobFunc, Data: data, Ctx: ctx} // 添加一个任务到任务通道
 atomic.AddInt64(&p.numJobs, 1) // 增加任务数量
}

// GetErrors 获取错误
func (p *Pool) GetErrors() <-chan error {
 return p.errors
}

// Stat 获取线程池状态
func (p *Pool) Stat() PoolStat {
 return PoolStat{
  NumJobs:     atomic.LoadInt64(&p.numJobs), // 获取任务总数
  NumFinished: atomic.LoadInt64(&p.numFinished), //获取完成任务数
  NumActive:   atomic.LoadInt64(&p.numJobs) - atomic.LoadInt64(&p.numFinished), //计算正在执行的任务数
 }
}

上述示例实现了一个非常基本的监控,仅包括任务的添加,完成和正在进行的数量,并未涵盖Java线程池中的一些更复杂的信息,例如线程状态,线程池大小,拒绝的任务等。如果需要更深入的监控功能,可能需要使用更复杂的设计模式或库,或者采取定制化的解决方案

转载自:https://juejin.cn/post/7423359941854691379
评论
请登录