likes
comments
collection
share

让ChatGPT帮我写了个协程池~

作者站长头像
站长
· 阅读数 13

起因

最开始是想学习下通用协程池的写法,想自己实现一下。 然后想到了最近很火的ChatGPT,看看能不能让它帮我写一个,或者写完之后我再拿去优化一下~ 最终生成的代码可以直接划到文章末尾的最终版

动手

简略版

让ChatGPT帮我写了个协程池~

最开始我的提问很简单,就是希望它帮我生成一个Go的协程池, 它写完之后 ,本着不会百分百相信AI的原则,自己check一下 发现这个协程池其实是有几个问题的

  1. 超时控制:子func执行时间过长的话是无法处理或感知到的
  2. 资源耗尽:池中的func没有限制,无法管理资源

简略版分析

接着我想让ChatGPT帮我分析这个协程池的隐患

让ChatGPT帮我写了个协程池~ ChatGPT帮我分析出了四个可能的隐患

  1. 竞争条件:在Submit函数中,我们向任务通道中发送任务,而在worker函数中,我们从任务通道中读取任务。如果多个goroutine同时调用Submit函数来提交任务,那么它们可能会争夺向任务通道中发送任务的权限,从而导致竞争条件。为了避免这种情况,我们需要使用互斥锁或其他同步机制来保护任务通道的并发访问。
  2. 内存泄漏:在Submit函数中,我们创建了一个新的Task对象,并将其添加到任务通道中。如果我们没有在某个时刻清理这些Task对象,那么它们可能会一直占用内存,从而导致内存泄漏。为了避免这种情况,我们需要在worker函数中正确处理所有的任务,以确保它们在执行完成后能够被垃圾回收。 3.. 资源耗尽:在worker函数中,我们创建了一个新的goroutine来执行每个任务,并使用select语句来等待任务完成或超时。如果我们提交了大量的长时间运行的任务,那么这些goroutine可能会占用大量的系统资源,从而导致资源耗尽。为了避免这种情况,我们需要限制协程池的最大并发数,并且合理地设置任务的最长执行时间,以避免过多的长时间运行的任务同时执行。
  3. 任务超时:在worker函数中,我们使用time.After函数来设置任务的最长执行时间,并在select语句中等待任务完成或超时。如果我们设置的最长执行时间太短,那么可能会导致一些正常的任务被错误地标记为超时。如果我们设置的最长执行时间太长,那么可能会导致一些长时间运行的任务占用过多的系统资源,从而影响系统的性能和稳定性。为了避免这种情况,我们需要根据实际情况合理地设置任务的最长执行时间,以确保任务能够在预期的时间内完成。

改进

接着我让它帮我生成一个优化的版本

让ChatGPT帮我写了个协程池~

简单看了一下,它用了简单的生产-消费模型生成了一个支持单Task过期时间设置的协程池,单Task的过期时间使用方式符合预期,但是它引入了其他问题

  1. 内存泄漏:如果Task队列中积压了很多任务,会导致内存一直被占用,无法被释放
  2. 过期时间:由于任务在每个worker中属于串行消费,可能导致前面的func执行耗时过长,后面的func还没被消费到 就已经过期了

因此我让它再帮我改进一下~ 以下的代码其实经过多次调试改进,因为我感觉ChatGPT有时候无法理解上下文

package goroutine_pool

import (
   "context"
   "fmt"
   "sync"
   "time"
)

// Task represents a function to be executed in the worker pool
type Task struct {
   Handler func() error
   Timeout time.Duration
}

// Pool represents a worker pool
type Pool struct {
   size           int        // number of workers
   taskQueue      chan *Task // channel to store tasks
   ctx            context.Context
   cancel         context.CancelFunc
   wg             sync.WaitGroup
   isClosed       bool
   workerCallback func(*Worker)
   mutex          sync.Mutex
}

// Worker represents a worker in the pool
type Worker struct {
   id        int        // worker id
   taskQueue chan *Task // channel to receive tasks from the pool
   ctx       context.Context
   cancel    context.CancelFunc
   pool      *Pool // pointer to the pool
}

// NewPool creates a new worker pool with the given size
func NewPool(size int) *Pool {
   pool := &Pool{
      size:      size,
      taskQueue: make(chan *Task, size*10),
      isClosed:  false,
   }
   pool.ctx, pool.cancel = context.WithCancel(context.Background())
   return pool
}

// NewWorker creates a new worker and starts it
func NewWorker(id int, pool *Pool) *Worker {
   ctx, cancel := context.WithCancel(pool.ctx)
   worker := &Worker{
      id:        id,
      taskQueue: make(chan *Task),
      ctx:       ctx,
      cancel:    cancel,
      pool:      pool,
   }
   go worker.Start()
   return worker
}

// Start starts the worker
func (w *Worker) Start() {
   defer w.pool.wg.Done()
   for {
      select {
      case task := <-w.taskQueue:
         if task == nil {
            // channel closed
            return
         }
         if task.Timeout > 0 {
            _, cancel := context.WithTimeout(w.ctx, task.Timeout)
            defer cancel()
            w.pool.workerCallback(w)
            err := task.Handler()
            if err != nil {
               fmt.Printf("worker %d: error executing task: %v\n", w.id, err)
            }
         } else {
            w.pool.workerCallback(w)
            err := task.Handler()
            if err != nil {
               fmt.Printf("worker %d: error executing task: %v\n", w.id, err)
            }
         }
      case <-w.ctx.Done():
         // worker is shutting down
         return
      }
   }
}

// Submit adds a new task to the task queue
func (p *Pool) Submit(handler func() error, timeout time.Duration) error {
   if p.isClosed {
      return fmt.Errorf("pool is closed")
   }
   select {
   case p.taskQueue <- &Task{Handler: handler, Timeout: timeout}:
      return nil
   default:
      return fmt.Errorf("task queue is full")
   }
}

// Close closes the worker pool
func (p *Pool) Close() {
   p.mutex.Lock()
   defer p.mutex.Unlock()
   if p.isClosed {
      return
   }
   p.isClosed = true
   close(p.taskQueue)
   p.cancel()
   p.wg.Wait()
}

// SetWorkerCallback sets the callback function to be called before and after a worker executes a task
func (p *Pool) SetWorkerCallback(callback func(*Worker)) {
   p.workerCallback = callback
}

func main() {
   pool := NewPool(5)
   defer pool.Close()

   pool.SetWorkerCallback(func(w *Worker) {
      fmt.Printf("worker %d: start executing task\n", w.id)
   })

   for i := 0; i < 10; i++ {
      j := i
      err := pool.Submit(func() error {
         fmt.Printf("task %d: start\n", j)
         time.Sleep(time.Second)
         fmt.Printf("task %d: end\n", j)
         return nil
      }, time.Second*2)
      if err != nil {
         fmt.Printf("error submitting task: %v\n", err)
      }
   }
}

改进版隐患

我自己看了下代码,感觉现在这个版本已经足够了~但是不知道还有没有什么细节的隐患,因此再问了ChatGPT,同时希望它能帮我生成单测代码来验证这个Pool的正确性

让ChatGPT帮我写了个协程池~

ChatGPT告诉我这个pool有三个隐患:

  1. Task的Timeout是以阻塞方式处理的,如果task执行时间过长,可能会影响整个协程池的性能。
  2. Submit函数里的select中没有default分支,当task队列满了之后,调用方会被阻塞,而没有任何反馈。
  3. Submit函数并没有返回任务执行结果,无法判断任务是否执行成功。

但是我自己分析了一下: (1) 我觉得这个可以接受 (2) Submit函数里面其实有default分支的 (3) 这个是没办法避免的事情,或者说可以解决但是投入成本太大~当前方案就已经可以接受了

最终版

我希望线程池最好还是能够带上一定的可监控性,所以最后让ChatGPT帮我完善了一下之前的协程池~

package goroutine_pool

import (
   "context"
   "fmt"
   "sync"
   "time"
)

// 任务结构体
type Task struct {
   Handler func(ctx context.Context) error // 任务执行的函数
   Timeout time.Duration                   // 任务执行超时时间
}

// 协程池结构体
type Pool struct {
   capacity int            // 协程池容量
   workers  []*Worker      // 协程池中的协程
   taskChan chan *Task     // 任务通道
   wg       sync.WaitGroup // 等待所有协程执行完毕
}

// 协程结构体
type Worker struct {
   id        int           // 协程id
   taskChan  chan *Task    // 任务通道
   stopChan  chan struct{} // 停止协程的通道
   running   bool          // 协程是否在运行中
   terminate bool          // 是否需要终止协程
}

// 创建协程池
func NewPool(capacity int) *Pool {
   return &Pool{
      capacity: capacity,
      workers:  make([]*Worker, capacity),
      taskChan: make(chan *Task),
   }
}

// 创建一个协程
func newWorker(id int, taskChan chan *Task) *Worker {
   return &Worker{
      id:        id,
      taskChan:  taskChan,
      stopChan:  make(chan struct{}),
      running:   true,
      terminate: false,
   }
}

// 启动协程
func (w *Worker) start() {
   go func() {
      for {
         select {
         case task := <-w.taskChan:
            // 如果任务为空,则退出协程
            if task == nil {
               w.running = false
               w.terminate = true
               return
            }
            ctx, cancel := context.WithTimeout(context.Background(), task.Timeout)
            defer cancel()
            err := task.Handler(ctx)
            if err != nil {
               // 处理任务执行失败的情况
            }
         case <-w.stopChan:
            w.running = false
            w.terminate = true
            return
         }
      }
   }()
}

// 停止协程
func (w *Worker) stop() {
   w.stopChan <- struct{}{}
}

// 启动协程池
func (p *Pool) Start() {
   for i := 0; i < p.capacity; i++ {
      p.workers[i] = newWorker(i, p.taskChan)
      p.workers[i].start()
   }
}

// 停止协程池
func (p *Pool) Stop() {
   for i := 0; i < p.capacity; i++ {
      p.workers[i].stop()
   }
   for i := 0; i < p.capacity; i++ {
      // 等待所有协程退出
      if p.workers[i].running {
         p.workers[i].stop()
      }
   }
   close(p.taskChan)
   p.wg.Wait()
}

// 提交任务到协程池
func (p *Pool) Submit(task *Task) {
   p.wg.Add(1)
   p.taskChan <- task
}

// 监控协程池
func (p *Pool) Monitor() {
   for {
      // 统计任务队列长度
      queueLen := len(p.taskChan)
      // 统计协程池中正在运行的协程数和需要退出的协程数
      var runningCount, terminateCount int
      for i := 0; i < p.capacity; i++ {
         if p.workers[i].running {
            runningCount++
         }
         if p.workers[i].terminate {
            terminateCount++
         }
      }
      // 打印监控信息
      fmt.Printf("Queue length: %d, Running workers: %d, Terminating workers: %d\n", queueLen, runningCount, terminateCount)
      time.Sleep(time.Second)
   }
}

测试代码

自己简单跑了下功能测试和benchmark ,下面直接上代码

package goroutine_pool

import (
   "context"
   "fmt"
   "log"
   "net/http"
   "sync"
   "testing"
   "time"
)

var (
   pool = NewPool(10)
   ctx  = context.Background()
)

func init() {
   pool.Start()
   log.Println(http.ListenAndServe("localhost:8080", nil))

}
func TestPool(t *testing.T) {
   testSubmit(ctx)
}

func testSubmit(ctx context.Context) {
   wg := sync.WaitGroup{}
   for i := 0; i < 100; i++ {
      wg.Add(1)
      go func(ii int) {
         defer wg.Done()
         task := &Task{
            Handler: func(ctx context.Context) error {
               err := handlePrintFunc(ctx, ii)
               return err
            },
            Timeout: time.Second,
         }
         pool.Submit(task)
      }(i)
   }
   wg.Wait()
   pool.Stop()

   go hang()
   select {}
}

func handlePrintFunc(ctx context.Context, i int) (err error) {
   fmt.Println(fmt.Sprintf("task:[%d] in", i))
   time.Sleep(2 * time.Second)
   fmt.Println(fmt.Sprintf("task:[%d] out", i))
   return nil
}

func hang() {
   time.Sleep(time.Hour)
}


func BenchmarkPool(b *testing.B) {
   go hang()
   b.ResetTimer()
   for i := 0; i < b.N; i++ {
      task := &Task{
         Handler: func(ctx context.Context) error {
            time.Sleep(time.Millisecond)
            return nil
         },
         Timeout: time.Second,
      }
      pool.Submit(task)
   }
   pool.Stop()
}

// copilot生成压测代码
func BenchmarkPoolV2(b *testing.B) {
   go hang()
   b.ResetTimer()
   for i := 0; i < b.N; i++ {
      task := &Task{
         Handler: func(ctx context.Context) error {
            time.Sleep(time.Millisecond)
            return nil
         },
         Timeout: time.Second,
      }
      pool.Submit(task)
   }
   pool.Stop()
}

总结

在调试的过程中,发现ChapGPT应该是借鉴了不少开源项目的代码风格,整体来说代码写的还是比较优雅,就是有时候生成的代码编译不通过,需要再次生成。 生成的代码能用,但是基本需要手动再改改 -.-