likes
comments
collection
share

一文介绍web Terminal的实战

作者站长头像
站长
· 阅读数 21

1.背景介绍

最近项目中做了一个功能,web ssh,就是将一个terminal终端搬到web上,实现通过web页面连接指定的服务器,并进行想要的命令操作,目前像这种类型的开源产品还是比较多的,在调研了多个产品后,我选择了自己造轮子,根本原因在于安全,所有连接到服务器上的用户执行的命令,我都要审计,留存。那就涉及到我需要劫持到用户在termial中键入的命令,并判断是否是白名单命令,然后才发往服务器上进行执行,并将结果返回给终端.

对于安全审计,从两个层面来操作,

    1. 监控,所谓监控就是,在用户进行ssh操作时,管理员可连接到用户使用的ssh会话中,实时查看当前用户进行的操作,关于这块,我们采用的是websocket聊天室方案
    1. 日志回放,开源的产品asciinema提供了完备的方案。

本文涉及到技术细节,将采用开源产品goploy进行说明,这是一款功能比较强大的运维产品,需要详细了解的可以到github上查看该项目代码

2.方案设计

web ssh 的实现依托于websocket,xterm,实现了功能完备的web terminal.xterm会将用户键入的命令通过websocket发送的websocket server,websocket鉴权通过后,发送到指定的server进行ssh连接执行命令,获取结果,然后写入到websocket返回给客户端terminal.

一文介绍web Terminal的实战

3.代码详解

3.1 Terminal核心功能

看一下核心的代码,就能了解整个过程,从前端往后端看

import { Terminal } from 'xterm'
import { FitAddon } from 'xterm-addon-fit'
import { AttachAddon } from 'xterm-addon-attach'
import { ElMessage } from 'element-plus'
import { NamespaceKey, getNamespaceId } from '@/utils/namespace'
export class xterm {
  private serverId: number
  private element: HTMLDivElement
  private websocket!: WebSocket
  private terminal!: Terminal
  constructor(element: HTMLDivElement, serverId: number) {
    this.element = element
    this.serverId = serverId
  }
  public connect(): void {
    const isWindows =
      ['Windows', 'Win16', 'Win32', 'WinCE'].indexOf(navigator.platform) >= 0
    this.terminal = new Terminal({
      fontSize: 14,
      cursorBlink: true,
      windowsMode: isWindows,
      theme: {
        foreground: '#ebeef5',
        background: '#1d2935',
        cursor: '#e6a23c',
        black: '#000000',
        brightBlack: '#555555',
        red: '#ef4f4f',
        brightRed: '#ef4f4f',
        green: '#67c23a',
        brightGreen: '#67c23a',
        yellow: '#e6a23c',
        brightYellow: '#e6a23c',
        blue: '#409eff',
        brightBlue: '#409eff',
        magenta: '#ef4f4f',
        brightMagenta: '#ef4f4f',
        cyan: '#17c0ae',
        brightCyan: '#17c0ae',
        white: '#bbbbbb',
        brightWhite: '#ffffff',
      },
    })
    const fitAddon = new FitAddon()
    this.terminal.open(this.element)
    this.terminal.loadAddon(fitAddon)
    fitAddon.fit()
    this.websocket = new WebSocket(
      `${location.protocol.replace('http', 'ws')}//${
        window.location.host + import.meta.env.VITE_APP_BASE_API
      }/ws/xterm?${NamespaceKey}=${getNamespaceId()}&serverId=${
        this.serverId
      }&rows=${this.terminal.rows}&cols=${this.terminal.cols}`
    )
    this.terminal.loadAddon(new AttachAddon(this.websocket))
    this.websocket.onclose = function (evt) {
      if (evt.reason !== '') {
        ElMessage.error(evt.reason)
      }
    }
  }
  public close(): void {
    this.terminal.dispose()
    this.websocket.close()
  }

  public send(message: string): void {
    this.websocket.send(message)
  }
}

代码有点长,我们简单看一下,先connect方法中,实例化了一个Terminal,这个Terminal提供了很多默认的配置,包括背影,字体等,一般我们根据自己的后台页面风格进行调整即可,同时引入FitAddon来进行窗口自适应。这是一个独立的用于xterm.js的插件,需要独立安装并引入,它允许将终端的尺寸匹配到包含元素。 安装方式:

npm install --save xterm-addon-fit

在这之后,创建了一个Websocket连接,然后绑到terminal上,这里也是一个xterm.js的插件,xterm-addon-attach,具体安装:

npm install --save xterm-addon-attach

具体的Server端的逻辑,我们稍后讲解。

说完核心的方法后,我们来看一下,页面上如何进行创建xterm对象的,如何发送命令的

 const x = new xterm(
      terminalRefs.value[currentTerminalUUID.value] as HTMLDivElement,
      server.id
    )
    x.connect()

创建xterm对象,并连接websocket。

...
<el-row class="footer">
      <el-input
        v-model="command"
        :disabled="terminalList.length === 0"
        placeholder="Click here send to all windows"
        class="terminal-cmd"
        @keyup.enter="enterCommand"
      />
    </el-row>
    ...
function enterCommand() {
  terminalList.value.forEach((terminal) => {
    terminal.xterm?.send(command.value + '\n')
  })
  command.value = ''
}

键入的命令通过send方法发送到websocket server。 以上就是web terminal的核心功能,当然,在页面关闭时的资源关闭,回收也很重要,这里我们仅关注功能点,回收的部分,可以参考源码中的实现即可。

前端代码看完之后,我们继续服务端代码的学习,

websocket server部分:


const (
	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10

	// Maximum message size allowed from peer.
	maxMessageSize = 10240
)

const (
	TypeProject = 1
	TypeMonitor = 3
)

// Client stores a client information
type Client struct {
	Conn     *websocket.Conn
	UserInfo model.User
}

// Data is message struct
type Data struct {
	Type    int
	UserIDs []int64
	Message Message
}

type Message interface {
	CanSendTo(client *Client) error
}

// Hub is a client struct
type Hub struct {
	// Registered clients.
	clients map[*Client]bool

	// Inbound messages from the clients.
	Data chan *Data

	// Register requests from the clients.
	Register chan *Client

	// Unregister requests from clients.
	Unregister chan *Client
	// ping pong ticker
	ticker chan *Client
}

func init() {
	go hub.run()
}

func (hub *Hub) Handler() []server.Route {
	return []server.Route{
		server.NewRoute("/ws/connect", http.MethodGet, hub.connect),
		server.NewRoute("/ws/xterm", http.MethodGet, hub.xterm),
		server.NewRoute("/ws/sftp", http.MethodGet, hub.sftp),
	}
}

var hub = &Hub{
	Data:       make(chan *Data),
	clients:    make(map[*Client]bool),
	Register:   make(chan *Client),
	Unregister: make(chan *Client),
	ticker:     make(chan *Client),
}

func GetHub() *Hub {
	return hub
}

func Send(d Data) {
	GetHub().Data <- &d
}

func (hub *Hub) connect(gp *server.Goploy) server.Response {
	upgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			if config.Toml.CORS.Enabled {
				if config.Toml.CORS.Origins == "*" {
					return true
				} else if strings.Contains(config.Toml.CORS.Origins, r.Header.Get("origin")) {
					return true
				}
			}
			if strings.Contains(r.Header.Get("origin"), strings.Split(r.Host, ":")[0]) {
				return true
			}
			return false
		},
	}
	c, err := upgrader.Upgrade(gp.ResponseWriter, gp.Request, nil)
	if err != nil {
		log.Error(err.Error())
		return response.JSON{Code: response.Error, Message: err.Error()}
	}
	c.SetReadLimit(maxMessageSize)
	c.SetReadDeadline(time.Now().Add(pongWait))
	c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(pongWait)); return nil })
	client := &Client{
		Conn:     c,
		UserInfo: gp.UserInfo,
	}
	hub.Register <- client

	ticker := time.NewTicker(pingPeriod)
	stop := make(chan bool, 1)
	go func() {
		for {
			select {
			case <-ticker.C:
				hub.ticker <- client
			case <-stop:
				return
			}
		}
	}()
	// you must read message to trigger pong handler
	for {
		_, _, err = c.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
				log.Error(err.Error())
			}
			break
		}
	}

	defer func() {
		hub.Unregister <- client
		c.Close()
		ticker.Stop()
		stop <- true
	}()

	return response.Empty{}
}

// Run goroutine run the sync hub
func (hub *Hub) run() {
	for {
		select {
		case client := <-hub.Register:
			hub.clients[client] = true
		case client := <-hub.Unregister:
			if _, ok := hub.clients[client]; ok {
				delete(hub.clients, client)
				client.Conn.Close()
			}
		case data := <-hub.Data:
			for client := range hub.clients {
				if data.Message.CanSendTo(client) != nil {
					continue
				}
				// check userIDs list
				for _, userID := range data.UserIDs {
					if client.UserInfo.ID != userID {
						continue
					}
				}
				if err := client.Conn.WriteJSON(
					struct {
						Type    int         `json:"type"`
						Message interface{} `json:"message"`
					}{
						Type:    data.Type,
						Message: data.Message,
					}); websocket.IsCloseError(err) {
					hub.Unregister <- client
				}
			}
		case client := <-hub.ticker:
			if _, ok := hub.clients[client]; ok {
				if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
					hub.Unregister <- client
				}
			}
		}
	}
}

这部分代码并不复杂,先看init方法,这个是首先运行的代码,实现了一个hub.包括wesocket连接和用户信息管理,当有新的连接创建时,添加到hub里,当有连接断开时,从hub中移除连接。还有就是数据中转,将需要发送的数据,发送到指定的websocket连接。

然后就是/ws/xterm的handler。


// write data to WebSocket
// the data comes from ssh server.
type xtermBufferWriter struct {
	buffer bytes.Buffer
	mu     sync.Mutex
}

// implement Write interface to write bytes from ssh server into bytes.Buffer.
func (w *xtermBufferWriter) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.buffer.Write(p)
}

func (hub *Hub) xterm(gp *server.Goploy) server.Response {
	upgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			if config.Toml.CORS.Enabled {
				if config.Toml.CORS.Origins == "*" {
					return true
				} else if strings.Contains(config.Toml.CORS.Origins, r.Header.Get("origin")) {
					return true
				}
			}
			if strings.Contains(r.Header.Get("origin"), strings.Split(r.Host, ":")[0]) {
				return true
			}
			return false
		},
	}
	c, err := upgrader.Upgrade(gp.ResponseWriter, gp.Request, nil)
	if err != nil {
		return response.JSON{Code: response.Error, Message: err.Error()}
	}
	defer c.Close()
	c.SetReadLimit(maxMessageSize)
	c.SetReadDeadline(time.Now().Add(pongWait))
	c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(pongWait)); return nil })

	rows, err := strconv.Atoi(gp.URLQuery.Get("rows"))
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	cols, err := strconv.Atoi(gp.URLQuery.Get("cols"))
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	serverID, err := strconv.ParseInt(gp.URLQuery.Get("serverId"), 10, 64)
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}

	srv, err := (model.Server{ID: serverID}).GetData()
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	client, err := srv.ToSSHConfig().Dial()
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	defer client.Close()
	// create session
	session, err := client.NewSession()
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	defer session.Close()
	sessionStdin, err := session.StdinPipe()
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	comboWriter := new(xtermBufferWriter)
	//ssh.stdout and stderr will write output into comboWriter
	session.Stdout = comboWriter
	session.Stderr = comboWriter
	// Request pseudo terminal
	if err := session.RequestPty("xterm", rows, cols, ssh.TerminalModes{}); err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}
	// Start remote shell
	if err := session.Shell(); err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}

	// terminal log
	tlID, err := model.TerminalLog{
		NamespaceID: gp.Namespace.ID,
		UserID:      gp.UserInfo.ID,
		ServerID:    serverID,
		RemoteAddr:  gp.Request.RemoteAddr,
		UserAgent:   gp.Request.UserAgent(),
		StartTime:   time.Now().Format("20060102150405"),
	}.AddRow()
	if err != nil {
		_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
		return response.Empty{}
	}

	var recorder *pkg.Recorder
	recorder, err = pkg.NewRecorder(config.GetTerminalLogPath(tlID), "xterm", rows, cols)
	if err != nil {
		log.Error(err.Error())
	} else {
		defer recorder.Close()
	}

	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()
	flushMessageTick := time.NewTicker(time.Millisecond * time.Duration(50))
	defer flushMessageTick.Stop()
	stop := make(chan bool, 1)
	defer func() {
		stop <- true
	}()
	go func() {
		for {
			select {
			case <-ticker.C:
				if err := c.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
					c.Close()
					return
				}
			case <-flushMessageTick.C:
				if comboWriter.buffer.Len() != 0 {
					err := c.WriteMessage(websocket.BinaryMessage, comboWriter.buffer.Bytes())
					if err != nil {
						c.Close()
						return
					}
					if recorder != nil {
						if err := recorder.WriteData(comboWriter.buffer.String()); err != nil {
							log.Error(err.Error())
						}
					}
					comboWriter.buffer.Reset()
				}
			case <-stop:
				return
			}
		}
	}()

	for {
		messageType, p, err := c.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
				log.Error(err.Error())
			}
			break
		}
		if messageType != websocket.PongMessage {
			if _, err := sessionStdin.Write(p); err != nil {
				log.Error(err.Error())
				break
			}
		}
	}

	if err := (model.TerminalLog{
		ID:      tlID,
		EndTime: time.Now().Format("20060102150405"),
	}.EditRow()); err != nil {
		log.Error(err.Error())
	}

	return response.Empty{}
}

这里也并不复杂,我们捡重点的说,升级完websocket后,获取cols,和rows属性,这里是设置terminal的高度和宽度的,防止过小的窗口无法和前端的窗口匹配,无法查看完整的信息。然后就是通过serverId参数,获取server的信息,并创建ssh 连接,创建连接后,开启shell. 然后再一个goroutine中根据设置的时间的ticker来刷新输出缓存到websocket连接,缓存里就是ssh命令执行的结果。 在接下来的for循环中,读取websocket链接中用户键入的命令,并写入到ssh的连接中,来让命令在服务器上执行,执行的结果将写入到缓存中。让上面的goroutine来间隔刷新出去给websocket.

以上就是完整的web terminal的流程。大家有不懂的地方可以留言,我们一起学习。

3.2 日志回放

上面还有一段代码,没有说就是

	var recorder *pkg.Recorder
	recorder, err = pkg.NewRecorder(config.GetTerminalLogPath(tlID), "xterm", rows, cols)
	if err != nil {
		log.Error(err.Error())
	} else {
		defer recorder.Close()
	}
        
        ...
        if recorder != nil {
		if err := recorder.WriteData(comboWriter.buffer.String()); err != nil {
			log.Error(err.Error())
		}
	}

这段代码就是实现了如何进行日志回放的核心代码,我们来看一下如何实现的。


type Env struct {
	Shell string `json:"SHELL"`
	Term  string `json:"TERM"`
}

type Header struct {
	Title     string `json:"title"`
	Version   int    `json:"version"`
	Height    int    `json:"height"`
	Width     int    `json:"width"`
	Env       Env    `json:"env"`
	Timestamp int    `json:"Timestamp"`
}

type Recorder struct {
	File      *os.File
	Timestamp int
}

func (recorder *Recorder) Close() {
	if recorder.File != nil {
		_ = recorder.File.Close()
	}
}

func (recorder *Recorder) WriteHeader(header *Header) (err error) {
	var p []byte

	if p, err = json.Marshal(header); err != nil {
		return
	}

	if _, err := recorder.File.Write(p); err != nil {
		return err
	}
	if _, err := recorder.File.Write([]byte("\n")); err != nil {
		return err
	}

	recorder.Timestamp = header.Timestamp

	return
}

func (recorder *Recorder) WriteData(data string) (err error) {
	now := int(time.Now().UnixNano())

	delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000

	row := make([]interface{}, 0)
	row = append(row, delta)
	row = append(row, "o")
	row = append(row, data)

	var s []byte
	if s, err = json.Marshal(row); err != nil {
		return
	}
	if _, err := recorder.File.Write(s); err != nil {
		return err
	}
	if _, err := recorder.File.Write([]byte("\n")); err != nil {
		return err
	}
	return
}

func NewRecorder(recordingPath, term string, h int, w int) (recorder *Recorder, err error) {
	recorder = &Recorder{}

	if _, err := os.Stat(path.Dir(recordingPath)); err != nil {
		if err := os.MkdirAll(path.Dir(recordingPath), os.ModePerm); err != nil {
			return recorder, err
		}
	}

	file, err := os.Create(recordingPath)
	if err != nil {
		return nil, err
	}

	recorder.File = file

	header := &Header{
		Title:     "",
		Version:   2,
		Height:    h,
		Width:     w,
		Env:       Env{Shell: "/bin/bash", Term: term},
		Timestamp: int(time.Now().Unix()),
	}

	if err := recorder.WriteHeader(header); err != nil {
		return nil, err
	}

	return recorder, nil
}

惊不惊喜,意不意外,这里的核心就是文件的写入,只是这里的文件是指定的内容格式,terminal中所有的执行过程,都会记录到该文件中。这个文件有固定的头部信息,文件格式为.cast类型文件。

这里就引入了该功能使用的工具asciinema.我们来了解一下这个工具

您可能知道SSH,屏幕或脚本命令。实际上,Asciinema的灵感来自脚本(和Scriptreplay)命令。您可能不知道的是它们都使用相同的UNIX系统功能:伪末端。

伪终端是一对伪驱动器,其中一个(从属)模拟了真实的文本终端设备,另一个终端设备(Master)提供了终端模拟器过程控制从属的手段。

这是终端仿真器与用户和外壳互动的方式:

终端模拟器过程的作用是与用户互动。将文本输入输入到主伪设备中,以供外壳使用(已连接到从属伪设备),并从主伪设备读取文本输出并将其显示给用户。

换句话说,伪末端使程序能够充当用户,显示器和外壳之间的中间人。它允许透明捕获用户输入(键盘)和终端输出(显示)。屏幕命令将其用于捕获特殊键盘快捷键,例如 ctrl-a ,并更改输出以显示窗口号/名称和其他消息。

Asciinema记录器通过利用伪终端来捕获所有输入到终端并将其保存在内存中(以及定时信息)的输出来完成其作业。捕获的输出包括原始的,不变的形式的所有文本和无形的逃生/控制序列。当录制会话完成时,它将输出(以assiicast格式)上传到asciinema.org。这就是“录制”部分。 Asciinema项目由几个互补作品构建:

  • 基于命令行的终端会话记录器Asciinema,
  • 网站上带有API的网站acciinema.org,
  • JavaScript播放器

当您在终端中运行Asciinema Rec时,录制开始时,捕获发出shell命令时打印到终端的所有输出。当录制完成(通过击中CTRL-D或打字出口)时,将捕获的输出上传到asciinema.org网站,并准备在网络上播放。

那如何在项目中进行使用呢,我们看一下具体代码:

func (Log) GetTerminalRecord(gp *server.Goploy) server.Response {
	type ReqData struct {
		RecordID int64 `schema:"recordId" validate:"gt=0"`
	}
	var reqData ReqData
	if err := gp.Decode(&reqData); err != nil {
		return response.JSON{Code: response.IllegalParam, Message: err.Error()}
	}
	terminalLog, err := model.TerminalLog{ID: reqData.RecordID}.GetData()
	if err != nil {
		return response.JSON{Code: response.Error, Message: err.Error()}
	}
	if gp.UserInfo.SuperManager != model.SuperManager && terminalLog.NamespaceID != gp.Namespace.ID {
		return response.JSON{Code: response.Error, Message: "You have no access to enter this record"}
	}
	return response.File{Filename: config.GetTerminalLogPath(reqData.RecordID)}
}

文件读取后,返回给前端


type File struct {
	Filename string
}

func (f File) Write(w http.ResponseWriter, _ *http.Request) error {
	file, err := os.Open(f.Filename)
	if err != nil {
		return err
	}

	fileStat, err := file.Stat()
	if err != nil {
		return err
	}

	w.Header().Set("Content-Disposition", "attachment; filename="+fileStat.Name())
	w.Header().Set("Content-Type", "application/x-asciicast")
	w.Header().Set("Content-Length", strconv.FormatInt(fileStat.Size(), 10))

	_, err = io.Copy(w, file)
	if err != nil {
		return err
	}
	return nil
}

前端创建AsciinemaPlayer对象进行可播放的文件查看。

function handleRecord(data: TerminalLogData) {
  recordViewer.value = true
  const castUrl = `${location.origin}${
    import.meta.env.VITE_APP_BASE_API
  }/log/getTerminalRecord?${NamespaceKey}=${getNamespaceId()}&recordId=${
    data.id
  }`
  nextTick(() => {
    AsciinemaPlayer.create(castUrl, record.value, {
      fit: false,
      fontSize: '14px',
    })
  })
}

3.3 会话监控

具体说就是,如果用户在终端进行操作,其操作的输入和输入在监控端都能实时看到,监控端仅仅提供消息的查看能力,无法插入消息。其实看到这,你可能也就意识到了,这其实就像一个聊天室,只不过监控者属于潜水用户,看其他人的消息,废话不多说,我们来看看如何实现的。

中转站

其实,整个设计中,最核心的就是消息中转站。

package main

type message struct {
	data []byte
	room string
}

type subscription struct {
	conn *connection
	room string
}

// hub maintains the set of active connections and broadcasts messages to the
// connections.
type hub struct {
	// Registered connections.
	rooms map[string]map[*connection]bool

	// Inbound messages from the connections.
	broadcast chan message

	// Register requests from the connections.
	register chan subscription

	// Unregister requests from connections.
	unregister chan subscription
}

var h = hub{
	broadcast:  make(chan message),
	register:   make(chan subscription),
	unregister: make(chan subscription),
	rooms:      make(map[string]map[*connection]bool),
}

func (h *hub) run() {
	for {
		select {
		case s := <-h.register:
			connections := h.rooms[s.room]
			if connections == nil {
				connections = make(map[*connection]bool)
				h.rooms[s.room] = connections
			}
			h.rooms[s.room][s.conn] = true
		case s := <-h.unregister:
			connections := h.rooms[s.room]
			if connections != nil {
				if _, ok := connections[s.conn]; ok {
					delete(connections, s.conn)
					close(s.conn.send)
					if len(connections) == 0 {
						delete(h.rooms, s.room)
					}
				}
			}
		case m := <-h.broadcast:
			connections := h.rooms[m.room]
			for c := range connections {
				select {
				case c.send <- m.data:
				default:
					close(c.send)
					delete(connections, c)
					if len(connections) == 0 {
						delete(h.rooms, m.room)
					}
				}
			}
		}
	}
}

所有的用户连接到服务器都自动创建一个房间,默认情况下房间内只有服务器和用户,当监控者进来后,就可以查看信息流了,hub中提供了注册的channel和取消注册的channel每个加入hub中的用户都将进行注册操作,退出时进行取消注册操作。在run方法中,实时监控每个channel的数据变更,如果有广播的消息,我们就将消息发送给房间内的所有其他用户。这里有个细节,如果发送消息这无需接受消息,那么发送者的conn可以设置成nil,然后再这里进行nil的判断,就无需将广播的消息发送给消息发送者了。

再来看看websocket连接的逻辑处理,

package main

import (
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
)

const (
	// Time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10

	// Maximum message size allowed from peer.
	maxMessageSize = 512
)

var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}

// connection is an middleman between the websocket connection and the hub.
type connection struct {
	// The websocket connection.
	ws *websocket.Conn

	// Buffered channel of outbound messages.
	send chan []byte
}

// readPump pumps messages from the websocket connection to the hub.
func (s subscription) readPump() {
	c := s.conn
	defer func() {
		h.unregister <- s
		c.ws.Close()
	}()
	c.ws.SetReadLimit(maxMessageSize)
	c.ws.SetReadDeadline(time.Now().Add(pongWait))
	c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
	for {
		_, msg, err := c.ws.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
				log.Printf("error: %v", err)
			}
			break
		}
		m := message{msg, s.room}
		h.broadcast <- m
	}
}

// write writes a message with the given message type and payload.
func (c *connection) write(mt int, payload []byte) error {
	c.ws.SetWriteDeadline(time.Now().Add(writeWait))
	return c.ws.WriteMessage(mt, payload)
}

// writePump pumps messages from the hub to the websocket connection.
func (s *subscription) writePump() {
	c := s.conn
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.ws.Close()
	}()
	for {
		select {
		case message, ok := <-c.send:
			if !ok {
				c.write(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.write(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			if err := c.write(websocket.PingMessage, []byte{}); err != nil {
				return
			}
		}
	}
}

// serveWs handles websocket requests from the peer.
func serveWs(w http.ResponseWriter, r *http.Request, roomId string) {
	fmt.Print(roomId)
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println(err.Error())
		return
	}
	c := &connection{send: make(chan []byte, 256), ws: ws}
	s := subscription{c, roomId}
	h.register <- s
	go s.writePump()
	go s.readPump()
}

监控者用户在连接到websocket server的时候,也会将自己注册到要监控的房间内,这个时候,通过readPump和writePump就可以进行消息交互了,只不过,如果监控者无需发送消息的话,我们可以去掉readPump.仅保留writePump协程即可。writePump协程会在指定的时间周期内进行读取是否有接收到的消息,然后将消息返回给前端。

以上就是整个功能的核心逻辑,看到出来,是不是就是一个阉割版的聊天室功能,从这里我们也可以了解一下,整个websocket聊天室功能的逻辑,然后根据自己的业务场景,进行改造。

参考链接