likes
comments
collection
share

一文搞懂gRPC自定义拦截器

作者站长头像
站长
· 阅读数 9

在Go语言中,使用gRPC时,拦截器(Interceptor)是一个强大的工具,允许你在RPC调用的生命周期中的关键点插入自定义逻辑。通过拦截器,你可以实现各种功能,如身份验证、授权、日志记录、监控和指标收集等。

gRPC拦截器拦截器的基本概念

拦截器是一种装饰器模式,它可以附加到 gRPC 的服务端或客户端上,并按照定义的顺序执行一系列操作。拦截器可以是一元的(Unary)或流的(Stream),分别对应于单个请求-响应调用和多个消息的双向流。

在gRPC中,拦截器是链式的,可以有一个或多个。当一个RPC调用发生时,拦截器会按照定义的顺序依次执行。每个拦截器都可以决定是继续调用链中的下一个拦截器,还是直接返回错误或响应。

自定义拦截器流程

定义拦截器函数

在Go中,你可以定义一个函数作为拦截器。这个函数需要满足特定的签名要求,对于服务器端拦截器,它应该实现grpc.UnaryServerInterceptorgrpc.StreamServerInterceptor接口;对于客户端拦截器,它应该实现grpc.UnaryClientInterceptorgrpc.StreamClientInterceptor接口。

例如,创建一个简单的服务器端一元拦截器:

import (
    "context"
    "google.golang.org/grpc"
)

func myUnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    // 在调用处理程序之前执行的逻辑

    // 调用下一个拦截器或最终的处理程序
    resp, err = handler(ctx, req)

    // 在调用处理程序之后执行的逻辑

    return resp, err
}

注册拦截器

创建拦截器后,你需要在构建gRPC服务器或客户端时将其注册。

对于服务器:

import "google.golang.org/grpc"

func main() {
    // 创建gRPC服务器实例
    s := grpc.NewServer(
        grpc.UnaryInterceptor(myUnaryServerInterceptor), // 注册一元拦截器
        // grpc.StreamInterceptor(myStreamServerInterceptor), // 注册流拦截器(如果需要)
    )

    // 注册服务
    // ...

    // 启动服务器
    // ...
}

对于客户端:

import (
    "google.golang.org/grpc"
    "yourmodule/yourclient"
)

func main() {
    // 创建gRPC连接
    conn, err := grpc.Dial(
        "localhost:50051",
        grpc.WithInsecure(),
        grpc.WithUnaryInterceptor(myUnaryClientInterceptor), // 注册一元拦截器
        // grpc.WithStreamInterceptor(myStreamClientInterceptor), // 注册流拦截器(如果需要)
    )
    if err != nil {
        // 处理错误
    }
    defer conn.Close()

    // 创建客户端实例并使用连接
    client := yourclient.NewYourServiceClient(conn)

    // 发起RPC调用
    // ...
}

实现gRPC自定义拦截器

以下代码参考:grpc-go

服务端:

package main

import (
   "context"
   "fmt"
   "io"
   "log"
   "net"
   "time"

   "google.golang.org/grpc"
   pb "google.golang.org/grpc/examples/features/proto/echo"
)

type server struct {
   pb.UnimplementedEchoServer
}

func (s *server) UnaryEcho(ctx context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) {
   fmt.Printf("request message : %q\n", in.Message)
   return &pb.EchoResponse{Message: in.Message}, nil
}

func myUnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   m, err := handler(ctx, req)
   if err != nil {
      fmt.Printf("error: %v", err)
      return nil, err
   }
   return m, err
}

type wrappedStream struct {
   grpc.ServerStream
}

func (w *wrappedStream) RecvMsg(m any) error {
   fmt.Printf("Receive a message (Type: %T) at %s \n", m, time.Now().Format(time.RFC3339))
   return w.ServerStream.RecvMsg(m)
}

func (w *wrappedStream) SendMsg(m any) error {
   fmt.Printf("Send a message (Type: %T) at %v \n", m, time.Now().Format(time.RFC3339))
   return w.ServerStream.SendMsg(m)
}

func myStreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   if err := handler(srv, &wrappedStream{ss}); err != nil {
      fmt.Printf("RPC failed with error: %v", err)
      return err
   }
   return nil
}

func (s *server) BidirectionalStreamingEcho(stream pb.Echo_BidirectionalStreamingEchoServer) error {
   for {
      in, err := stream.Recv()
      if err != nil {
         if err == io.EOF {
            return nil
         }
         fmt.Printf("server: error receiving from stream: %v\n", err)
         return err
      }
      fmt.Printf("bidi echoing message %q\n", in.Message)
      _ = stream.Send(&pb.EchoResponse{Message: in.Message})
   }
}

func main() {
   lis, err := net.Listen("tcp", "127.0.0.1:8990")
   if err != nil {
      log.Fatalf("failed to listen: %v", err)
   }

   s := grpc.NewServer(grpc.UnaryInterceptor(myUnaryInterceptor),
      grpc.StreamInterceptor(myStreamInterceptor))

   pb.RegisterEchoServer(s, &server{})

   if err := s.Serve(lis); err != nil {
      log.Fatalf("failed to serve: %v", err)
   }
}

客户端:

package main

import (
   "context"
   "flag"
   "fmt"
   "io"
   "log"
   "time"

   "google.golang.org/grpc"
   ecpb "google.golang.org/grpc/examples/features/proto/echo"
)

func unaryInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   start := time.Now()
   err := invoker(ctx, method, req, reply, cc, opts...)
   end := time.Now()
   fmt.Printf("RPC: %s, start time: %s, end time: %s, err: %v \n", method, start.Format("Basic"), end.Format(time.RFC3339), err)
   return err
}


type wrappedStream struct {
   grpc.ClientStream
}

func (w *wrappedStream) RecvMsg(m any) error {
   fmt.Printf("Receive a message (Type: %T) at %v", m, time.Now().Format(time.RFC3339))
   return w.ClientStream.RecvMsg(m)
}

func (w *wrappedStream) SendMsg(m any) error {
   fmt.Printf("Send a message (Type: %T) at %v", m, time.Now().Format(time.RFC3339))
   return w.ClientStream.SendMsg(m)
}

func newWrappedStream(s grpc.ClientStream) grpc.ClientStream {
   return &wrappedStream{s}
}

// streamInterceptor is an example stream interceptor.
func streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   s, err := streamer(ctx, desc, cc, method, opts...)
   if err != nil {
      return nil, err
   }
   return newWrappedStream(s), nil
}

func callUnaryEcho(client ecpb.EchoClient, message string) {
   ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   defer cancel()
   resp, err := client.UnaryEcho(ctx, &ecpb.EchoRequest{Message: message})
   if err != nil {
      log.Fatalf("client.UnaryEcho(_) = _, %v: ", err)
   }
   fmt.Println("UnaryEcho: ", resp.Message)
}

func callBidiStreamingEcho(client ecpb.EchoClient) {
   ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   defer cancel()
   c, err := client.BidirectionalStreamingEcho(ctx)
   if err != nil {
      return
   }
   for i := 0; i < 5; i++ {
      if err := c.Send(&ecpb.EchoRequest{Message: fmt.Sprintf("Request %d", i+1)}); err != nil {
         log.Fatalf("failed to send request due to error: %v", err)
      }
   }
   c.CloseSend()
   for {
      resp, err := c.Recv()
      if err == io.EOF {
         break
      }
      if err != nil {
         log.Fatalf("failed to receive response due to error: %v", err)
      }
      fmt.Println("BidiStreaming Echo: ", resp.Message)
   }
}

func main() {
   conn, err := grpc.Dial("127.0.0.1:8990",grpc.WithInsecure(),
      grpc.WithUnaryInterceptor(unaryInterceptor),
      grpc.WithStreamInterceptor(streamInterceptor))
   if err != nil {
      log.Fatalf("did not connect: %v", err)
   }
   defer conn.Close()

   rgc := ecpb.NewEchoClient(conn)
   callUnaryEcho(rgc, "hello world")
   callBidiStreamingEcho(rgc)
}

拦截器中的常见操作

在拦截器中,你可以执行多种操作,包括但不限于:

  • 读取和修改元数据(headers)
  • 记录请求和响应的详细信息
  • 执行身份验证和授权检查
  • 添加额外的请求或响应处理逻辑,如加密、解密、压缩、解压缩等
  • 监控RPC调用的性能指标,如延迟、吞吐量等

拦截器可以用来执行以下任务:

  • 修改请求和响应数据:拦截器可以在请求被处理程序处理之前或响应被发送回客户端之前修改请求或响应数据。
  • 处理异常情况:当请求或响应出现异常时,拦截器可以用来处理或记录这些异常。
  • 自定义监控和日志:通过在拦截器中记录日志或生成度量标准,你可以更容易地监控系统的性能和行为。
  • 实现AOP(面向切面编程)概念:拦截器允许你定义跨多个服务和方法的通用逻辑,例如日志记录、安全检查等。

4 总结

通过自定义gRPC拦截器,你可以在不修改服务或客户端代码的情况下,增强gRPC应用程序的功能和安全性。拦截器提供了一种灵活且强大的机制来插入自定义逻辑,并允许你在RPC调用的关键点上执行必要的操作。