一文介绍web Terminal的实战
1.背景介绍
最近项目中做了一个功能,web ssh,就是将一个terminal终端搬到web上,实现通过web页面连接指定的服务器,并进行想要的命令操作,目前像这种类型的开源产品还是比较多的,在调研了多个产品后,我选择了自己造轮子,根本原因在于安全,所有连接到服务器上的用户执行的命令,我都要审计,留存。那就涉及到我需要劫持到用户在termial中键入的命令,并判断是否是白名单命令,然后才发往服务器上进行执行,并将结果返回给终端.
对于安全审计,从两个层面来操作,
-
- 监控,所谓监控就是,在用户进行ssh操作时,管理员可连接到用户使用的ssh会话中,实时查看当前用户进行的操作,关于这块,我们采用的是websocket聊天室方案
-
- 日志回放,开源的产品asciinema提供了完备的方案。
本文涉及到技术细节,将采用开源产品goploy进行说明,这是一款功能比较强大的运维产品,需要详细了解的可以到github上查看该项目代码
2.方案设计
web ssh 的实现依托于websocket,xterm,实现了功能完备的web terminal.xterm会将用户键入的命令通过websocket发送的websocket server,websocket鉴权通过后,发送到指定的server进行ssh连接执行命令,获取结果,然后写入到websocket返回给客户端terminal.
3.代码详解
3.1 Terminal核心功能
看一下核心的代码,就能了解整个过程,从前端往后端看
import { Terminal } from 'xterm'
import { FitAddon } from 'xterm-addon-fit'
import { AttachAddon } from 'xterm-addon-attach'
import { ElMessage } from 'element-plus'
import { NamespaceKey, getNamespaceId } from '@/utils/namespace'
export class xterm {
private serverId: number
private element: HTMLDivElement
private websocket!: WebSocket
private terminal!: Terminal
constructor(element: HTMLDivElement, serverId: number) {
this.element = element
this.serverId = serverId
}
public connect(): void {
const isWindows =
['Windows', 'Win16', 'Win32', 'WinCE'].indexOf(navigator.platform) >= 0
this.terminal = new Terminal({
fontSize: 14,
cursorBlink: true,
windowsMode: isWindows,
theme: {
foreground: '#ebeef5',
background: '#1d2935',
cursor: '#e6a23c',
black: '#000000',
brightBlack: '#555555',
red: '#ef4f4f',
brightRed: '#ef4f4f',
green: '#67c23a',
brightGreen: '#67c23a',
yellow: '#e6a23c',
brightYellow: '#e6a23c',
blue: '#409eff',
brightBlue: '#409eff',
magenta: '#ef4f4f',
brightMagenta: '#ef4f4f',
cyan: '#17c0ae',
brightCyan: '#17c0ae',
white: '#bbbbbb',
brightWhite: '#ffffff',
},
})
const fitAddon = new FitAddon()
this.terminal.open(this.element)
this.terminal.loadAddon(fitAddon)
fitAddon.fit()
this.websocket = new WebSocket(
`${location.protocol.replace('http', 'ws')}//${
window.location.host + import.meta.env.VITE_APP_BASE_API
}/ws/xterm?${NamespaceKey}=${getNamespaceId()}&serverId=${
this.serverId
}&rows=${this.terminal.rows}&cols=${this.terminal.cols}`
)
this.terminal.loadAddon(new AttachAddon(this.websocket))
this.websocket.onclose = function (evt) {
if (evt.reason !== '') {
ElMessage.error(evt.reason)
}
}
}
public close(): void {
this.terminal.dispose()
this.websocket.close()
}
public send(message: string): void {
this.websocket.send(message)
}
}
代码有点长,我们简单看一下,先connect方法中,实例化了一个Terminal,这个Terminal提供了很多默认的配置,包括背影,字体等,一般我们根据自己的后台页面风格进行调整即可,同时引入FitAddon来进行窗口自适应。这是一个独立的用于xterm.js的插件,需要独立安装并引入,它允许将终端的尺寸匹配到包含元素。 安装方式:
npm install --save xterm-addon-fit
在这之后,创建了一个Websocket连接,然后绑到terminal上,这里也是一个xterm.js的插件,xterm-addon-attach,具体安装:
npm install --save xterm-addon-attach
具体的Server端的逻辑,我们稍后讲解。
说完核心的方法后,我们来看一下,页面上如何进行创建xterm对象的,如何发送命令的
const x = new xterm(
terminalRefs.value[currentTerminalUUID.value] as HTMLDivElement,
server.id
)
x.connect()
创建xterm对象,并连接websocket。
...
<el-row class="footer">
<el-input
v-model="command"
:disabled="terminalList.length === 0"
placeholder="Click here send to all windows"
class="terminal-cmd"
@keyup.enter="enterCommand"
/>
</el-row>
...
function enterCommand() {
terminalList.value.forEach((terminal) => {
terminal.xterm?.send(command.value + '\n')
})
command.value = ''
}
键入的命令通过send方法发送到websocket server。 以上就是web terminal的核心功能,当然,在页面关闭时的资源关闭,回收也很重要,这里我们仅关注功能点,回收的部分,可以参考源码中的实现即可。
前端代码看完之后,我们继续服务端代码的学习,
websocket server部分:
const (
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 10240
)
const (
TypeProject = 1
TypeMonitor = 3
)
// Client stores a client information
type Client struct {
Conn *websocket.Conn
UserInfo model.User
}
// Data is message struct
type Data struct {
Type int
UserIDs []int64
Message Message
}
type Message interface {
CanSendTo(client *Client) error
}
// Hub is a client struct
type Hub struct {
// Registered clients.
clients map[*Client]bool
// Inbound messages from the clients.
Data chan *Data
// Register requests from the clients.
Register chan *Client
// Unregister requests from clients.
Unregister chan *Client
// ping pong ticker
ticker chan *Client
}
func init() {
go hub.run()
}
func (hub *Hub) Handler() []server.Route {
return []server.Route{
server.NewRoute("/ws/connect", http.MethodGet, hub.connect),
server.NewRoute("/ws/xterm", http.MethodGet, hub.xterm),
server.NewRoute("/ws/sftp", http.MethodGet, hub.sftp),
}
}
var hub = &Hub{
Data: make(chan *Data),
clients: make(map[*Client]bool),
Register: make(chan *Client),
Unregister: make(chan *Client),
ticker: make(chan *Client),
}
func GetHub() *Hub {
return hub
}
func Send(d Data) {
GetHub().Data <- &d
}
func (hub *Hub) connect(gp *server.Goploy) server.Response {
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
if config.Toml.CORS.Enabled {
if config.Toml.CORS.Origins == "*" {
return true
} else if strings.Contains(config.Toml.CORS.Origins, r.Header.Get("origin")) {
return true
}
}
if strings.Contains(r.Header.Get("origin"), strings.Split(r.Host, ":")[0]) {
return true
}
return false
},
}
c, err := upgrader.Upgrade(gp.ResponseWriter, gp.Request, nil)
if err != nil {
log.Error(err.Error())
return response.JSON{Code: response.Error, Message: err.Error()}
}
c.SetReadLimit(maxMessageSize)
c.SetReadDeadline(time.Now().Add(pongWait))
c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(pongWait)); return nil })
client := &Client{
Conn: c,
UserInfo: gp.UserInfo,
}
hub.Register <- client
ticker := time.NewTicker(pingPeriod)
stop := make(chan bool, 1)
go func() {
for {
select {
case <-ticker.C:
hub.ticker <- client
case <-stop:
return
}
}
}()
// you must read message to trigger pong handler
for {
_, _, err = c.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Error(err.Error())
}
break
}
}
defer func() {
hub.Unregister <- client
c.Close()
ticker.Stop()
stop <- true
}()
return response.Empty{}
}
// Run goroutine run the sync hub
func (hub *Hub) run() {
for {
select {
case client := <-hub.Register:
hub.clients[client] = true
case client := <-hub.Unregister:
if _, ok := hub.clients[client]; ok {
delete(hub.clients, client)
client.Conn.Close()
}
case data := <-hub.Data:
for client := range hub.clients {
if data.Message.CanSendTo(client) != nil {
continue
}
// check userIDs list
for _, userID := range data.UserIDs {
if client.UserInfo.ID != userID {
continue
}
}
if err := client.Conn.WriteJSON(
struct {
Type int `json:"type"`
Message interface{} `json:"message"`
}{
Type: data.Type,
Message: data.Message,
}); websocket.IsCloseError(err) {
hub.Unregister <- client
}
}
case client := <-hub.ticker:
if _, ok := hub.clients[client]; ok {
if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
hub.Unregister <- client
}
}
}
}
}
这部分代码并不复杂,先看init方法,这个是首先运行的代码,实现了一个hub.包括wesocket连接和用户信息管理,当有新的连接创建时,添加到hub里,当有连接断开时,从hub中移除连接。还有就是数据中转,将需要发送的数据,发送到指定的websocket连接。
然后就是/ws/xterm
的handler。
// write data to WebSocket
// the data comes from ssh server.
type xtermBufferWriter struct {
buffer bytes.Buffer
mu sync.Mutex
}
// implement Write interface to write bytes from ssh server into bytes.Buffer.
func (w *xtermBufferWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.buffer.Write(p)
}
func (hub *Hub) xterm(gp *server.Goploy) server.Response {
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
if config.Toml.CORS.Enabled {
if config.Toml.CORS.Origins == "*" {
return true
} else if strings.Contains(config.Toml.CORS.Origins, r.Header.Get("origin")) {
return true
}
}
if strings.Contains(r.Header.Get("origin"), strings.Split(r.Host, ":")[0]) {
return true
}
return false
},
}
c, err := upgrader.Upgrade(gp.ResponseWriter, gp.Request, nil)
if err != nil {
return response.JSON{Code: response.Error, Message: err.Error()}
}
defer c.Close()
c.SetReadLimit(maxMessageSize)
c.SetReadDeadline(time.Now().Add(pongWait))
c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(pongWait)); return nil })
rows, err := strconv.Atoi(gp.URLQuery.Get("rows"))
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
cols, err := strconv.Atoi(gp.URLQuery.Get("cols"))
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
serverID, err := strconv.ParseInt(gp.URLQuery.Get("serverId"), 10, 64)
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
srv, err := (model.Server{ID: serverID}).GetData()
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
client, err := srv.ToSSHConfig().Dial()
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
defer client.Close()
// create session
session, err := client.NewSession()
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
defer session.Close()
sessionStdin, err := session.StdinPipe()
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
comboWriter := new(xtermBufferWriter)
//ssh.stdout and stderr will write output into comboWriter
session.Stdout = comboWriter
session.Stderr = comboWriter
// Request pseudo terminal
if err := session.RequestPty("xterm", rows, cols, ssh.TerminalModes{}); err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
// Start remote shell
if err := session.Shell(); err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
// terminal log
tlID, err := model.TerminalLog{
NamespaceID: gp.Namespace.ID,
UserID: gp.UserInfo.ID,
ServerID: serverID,
RemoteAddr: gp.Request.RemoteAddr,
UserAgent: gp.Request.UserAgent(),
StartTime: time.Now().Format("20060102150405"),
}.AddRow()
if err != nil {
_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()))
return response.Empty{}
}
var recorder *pkg.Recorder
recorder, err = pkg.NewRecorder(config.GetTerminalLogPath(tlID), "xterm", rows, cols)
if err != nil {
log.Error(err.Error())
} else {
defer recorder.Close()
}
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
flushMessageTick := time.NewTicker(time.Millisecond * time.Duration(50))
defer flushMessageTick.Stop()
stop := make(chan bool, 1)
defer func() {
stop <- true
}()
go func() {
for {
select {
case <-ticker.C:
if err := c.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
c.Close()
return
}
case <-flushMessageTick.C:
if comboWriter.buffer.Len() != 0 {
err := c.WriteMessage(websocket.BinaryMessage, comboWriter.buffer.Bytes())
if err != nil {
c.Close()
return
}
if recorder != nil {
if err := recorder.WriteData(comboWriter.buffer.String()); err != nil {
log.Error(err.Error())
}
}
comboWriter.buffer.Reset()
}
case <-stop:
return
}
}
}()
for {
messageType, p, err := c.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Error(err.Error())
}
break
}
if messageType != websocket.PongMessage {
if _, err := sessionStdin.Write(p); err != nil {
log.Error(err.Error())
break
}
}
}
if err := (model.TerminalLog{
ID: tlID,
EndTime: time.Now().Format("20060102150405"),
}.EditRow()); err != nil {
log.Error(err.Error())
}
return response.Empty{}
}
这里也并不复杂,我们捡重点的说,升级完websocket后,获取cols,和rows属性,这里是设置terminal的高度和宽度的,防止过小的窗口无法和前端的窗口匹配,无法查看完整的信息。然后就是通过serverId参数,获取server的信息,并创建ssh 连接,创建连接后,开启shell. 然后再一个goroutine中根据设置的时间的ticker来刷新输出缓存到websocket连接,缓存里就是ssh命令执行的结果。 在接下来的for循环中,读取websocket链接中用户键入的命令,并写入到ssh的连接中,来让命令在服务器上执行,执行的结果将写入到缓存中。让上面的goroutine来间隔刷新出去给websocket.
以上就是完整的web terminal的流程。大家有不懂的地方可以留言,我们一起学习。
3.2 日志回放
上面还有一段代码,没有说就是
var recorder *pkg.Recorder
recorder, err = pkg.NewRecorder(config.GetTerminalLogPath(tlID), "xterm", rows, cols)
if err != nil {
log.Error(err.Error())
} else {
defer recorder.Close()
}
...
if recorder != nil {
if err := recorder.WriteData(comboWriter.buffer.String()); err != nil {
log.Error(err.Error())
}
}
这段代码就是实现了如何进行日志回放的核心代码,我们来看一下如何实现的。
type Env struct {
Shell string `json:"SHELL"`
Term string `json:"TERM"`
}
type Header struct {
Title string `json:"title"`
Version int `json:"version"`
Height int `json:"height"`
Width int `json:"width"`
Env Env `json:"env"`
Timestamp int `json:"Timestamp"`
}
type Recorder struct {
File *os.File
Timestamp int
}
func (recorder *Recorder) Close() {
if recorder.File != nil {
_ = recorder.File.Close()
}
}
func (recorder *Recorder) WriteHeader(header *Header) (err error) {
var p []byte
if p, err = json.Marshal(header); err != nil {
return
}
if _, err := recorder.File.Write(p); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
recorder.Timestamp = header.Timestamp
return
}
func (recorder *Recorder) WriteData(data string) (err error) {
now := int(time.Now().UnixNano())
delta := float64(now-recorder.Timestamp*1000*1000*1000) / 1000 / 1000 / 1000
row := make([]interface{}, 0)
row = append(row, delta)
row = append(row, "o")
row = append(row, data)
var s []byte
if s, err = json.Marshal(row); err != nil {
return
}
if _, err := recorder.File.Write(s); err != nil {
return err
}
if _, err := recorder.File.Write([]byte("\n")); err != nil {
return err
}
return
}
func NewRecorder(recordingPath, term string, h int, w int) (recorder *Recorder, err error) {
recorder = &Recorder{}
if _, err := os.Stat(path.Dir(recordingPath)); err != nil {
if err := os.MkdirAll(path.Dir(recordingPath), os.ModePerm); err != nil {
return recorder, err
}
}
file, err := os.Create(recordingPath)
if err != nil {
return nil, err
}
recorder.File = file
header := &Header{
Title: "",
Version: 2,
Height: h,
Width: w,
Env: Env{Shell: "/bin/bash", Term: term},
Timestamp: int(time.Now().Unix()),
}
if err := recorder.WriteHeader(header); err != nil {
return nil, err
}
return recorder, nil
}
惊不惊喜,意不意外,这里的核心就是文件的写入,只是这里的文件是指定的内容格式,terminal中所有的执行过程,都会记录到该文件中。这个文件有固定的头部信息,文件格式为.cast
类型文件。
这里就引入了该功能使用的工具asciinema
.我们来了解一下这个工具
您可能知道SSH,屏幕或脚本命令。实际上,Asciinema的灵感来自脚本(和Scriptreplay)命令。您可能不知道的是它们都使用相同的UNIX系统功能:伪末端。
伪终端是一对伪驱动器,其中一个(从属)模拟了真实的文本终端设备,另一个终端设备(Master)提供了终端模拟器过程控制从属的手段。
这是终端仿真器与用户和外壳互动的方式:
终端模拟器过程的作用是与用户互动。将文本输入输入到主伪设备中,以供外壳使用(已连接到从属伪设备),并从主伪设备读取文本输出并将其显示给用户。
换句话说,伪末端使程序能够充当用户,显示器和外壳之间的中间人。它允许透明捕获用户输入(键盘)和终端输出(显示)。屏幕命令将其用于捕获特殊键盘快捷键,例如 ctrl-a ,并更改输出以显示窗口号/名称和其他消息。
Asciinema记录器通过利用伪终端来捕获所有输入到终端并将其保存在内存中(以及定时信息)的输出来完成其作业。捕获的输出包括原始的,不变的形式的所有文本和无形的逃生/控制序列。当录制会话完成时,它将输出(以assiicast格式)上传到asciinema.org。这就是“录制”部分。 Asciinema项目由几个互补作品构建:
- 基于命令行的终端会话记录器Asciinema,
- 网站上带有API的网站acciinema.org,
- JavaScript播放器
当您在终端中运行Asciinema Rec时,录制开始时,捕获发出shell命令时打印到终端的所有输出。当录制完成(通过击中CTRL-D或打字出口)时,将捕获的输出上传到asciinema.org网站,并准备在网络上播放。
那如何在项目中进行使用呢,我们看一下具体代码:
func (Log) GetTerminalRecord(gp *server.Goploy) server.Response {
type ReqData struct {
RecordID int64 `schema:"recordId" validate:"gt=0"`
}
var reqData ReqData
if err := gp.Decode(&reqData); err != nil {
return response.JSON{Code: response.IllegalParam, Message: err.Error()}
}
terminalLog, err := model.TerminalLog{ID: reqData.RecordID}.GetData()
if err != nil {
return response.JSON{Code: response.Error, Message: err.Error()}
}
if gp.UserInfo.SuperManager != model.SuperManager && terminalLog.NamespaceID != gp.Namespace.ID {
return response.JSON{Code: response.Error, Message: "You have no access to enter this record"}
}
return response.File{Filename: config.GetTerminalLogPath(reqData.RecordID)}
}
文件读取后,返回给前端
type File struct {
Filename string
}
func (f File) Write(w http.ResponseWriter, _ *http.Request) error {
file, err := os.Open(f.Filename)
if err != nil {
return err
}
fileStat, err := file.Stat()
if err != nil {
return err
}
w.Header().Set("Content-Disposition", "attachment; filename="+fileStat.Name())
w.Header().Set("Content-Type", "application/x-asciicast")
w.Header().Set("Content-Length", strconv.FormatInt(fileStat.Size(), 10))
_, err = io.Copy(w, file)
if err != nil {
return err
}
return nil
}
前端创建AsciinemaPlayer
对象进行可播放的文件查看。
function handleRecord(data: TerminalLogData) {
recordViewer.value = true
const castUrl = `${location.origin}${
import.meta.env.VITE_APP_BASE_API
}/log/getTerminalRecord?${NamespaceKey}=${getNamespaceId()}&recordId=${
data.id
}`
nextTick(() => {
AsciinemaPlayer.create(castUrl, record.value, {
fit: false,
fontSize: '14px',
})
})
}
3.3 会话监控
具体说就是,如果用户在终端进行操作,其操作的输入和输入在监控端都能实时看到,监控端仅仅提供消息的查看能力,无法插入消息。其实看到这,你可能也就意识到了,这其实就像一个聊天室,只不过监控者属于潜水用户,看其他人的消息,废话不多说,我们来看看如何实现的。
中转站
其实,整个设计中,最核心的就是消息中转站。
package main
type message struct {
data []byte
room string
}
type subscription struct {
conn *connection
room string
}
// hub maintains the set of active connections and broadcasts messages to the
// connections.
type hub struct {
// Registered connections.
rooms map[string]map[*connection]bool
// Inbound messages from the connections.
broadcast chan message
// Register requests from the connections.
register chan subscription
// Unregister requests from connections.
unregister chan subscription
}
var h = hub{
broadcast: make(chan message),
register: make(chan subscription),
unregister: make(chan subscription),
rooms: make(map[string]map[*connection]bool),
}
func (h *hub) run() {
for {
select {
case s := <-h.register:
connections := h.rooms[s.room]
if connections == nil {
connections = make(map[*connection]bool)
h.rooms[s.room] = connections
}
h.rooms[s.room][s.conn] = true
case s := <-h.unregister:
connections := h.rooms[s.room]
if connections != nil {
if _, ok := connections[s.conn]; ok {
delete(connections, s.conn)
close(s.conn.send)
if len(connections) == 0 {
delete(h.rooms, s.room)
}
}
}
case m := <-h.broadcast:
connections := h.rooms[m.room]
for c := range connections {
select {
case c.send <- m.data:
default:
close(c.send)
delete(connections, c)
if len(connections) == 0 {
delete(h.rooms, m.room)
}
}
}
}
}
}
所有的用户连接到服务器都自动创建一个房间,默认情况下房间内只有服务器和用户,当监控者进来后,就可以查看信息流了,hub中提供了注册的channel和取消注册的channel每个加入hub中的用户都将进行注册操作,退出时进行取消注册操作。在run方法中,实时监控每个channel的数据变更,如果有广播的消息,我们就将消息发送给房间内的所有其他用户。这里有个细节,如果发送消息这无需接受消息,那么发送者的conn可以设置成nil,然后再这里进行nil的判断,就无需将广播的消息发送给消息发送者了。
再来看看websocket连接的逻辑处理,
package main
import (
"fmt"
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 512
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
// connection is an middleman between the websocket connection and the hub.
type connection struct {
// The websocket connection.
ws *websocket.Conn
// Buffered channel of outbound messages.
send chan []byte
}
// readPump pumps messages from the websocket connection to the hub.
func (s subscription) readPump() {
c := s.conn
defer func() {
h.unregister <- s
c.ws.Close()
}()
c.ws.SetReadLimit(maxMessageSize)
c.ws.SetReadDeadline(time.Now().Add(pongWait))
c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
for {
_, msg, err := c.ws.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
log.Printf("error: %v", err)
}
break
}
m := message{msg, s.room}
h.broadcast <- m
}
}
// write writes a message with the given message type and payload.
func (c *connection) write(mt int, payload []byte) error {
c.ws.SetWriteDeadline(time.Now().Add(writeWait))
return c.ws.WriteMessage(mt, payload)
}
// writePump pumps messages from the hub to the websocket connection.
func (s *subscription) writePump() {
c := s.conn
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.ws.Close()
}()
for {
select {
case message, ok := <-c.send:
if !ok {
c.write(websocket.CloseMessage, []byte{})
return
}
if err := c.write(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
if err := c.write(websocket.PingMessage, []byte{}); err != nil {
return
}
}
}
}
// serveWs handles websocket requests from the peer.
func serveWs(w http.ResponseWriter, r *http.Request, roomId string) {
fmt.Print(roomId)
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err.Error())
return
}
c := &connection{send: make(chan []byte, 256), ws: ws}
s := subscription{c, roomId}
h.register <- s
go s.writePump()
go s.readPump()
}
监控者用户在连接到websocket server的时候,也会将自己注册到要监控的房间内,这个时候,通过readPump和writePump就可以进行消息交互了,只不过,如果监控者无需发送消息的话,我们可以去掉readPump.仅保留writePump协程即可。writePump协程会在指定的时间周期内进行读取是否有接收到的消息,然后将消息返回给前端。
以上就是整个功能的核心逻辑,看到出来,是不是就是一个阉割版的聊天室功能,从这里我们也可以了解一下,整个websocket聊天室功能的逻辑,然后根据自己的业务场景,进行改造。
参考链接
转载自:https://juejin.cn/post/7296111218720833571