SpringBoot 整合WebSocket+事件发布 实现实时消息推送 抽象类抽取公共逻辑 @Autowired注入为Null解决
前言
最近在工作碰到一个需求,项目背景是一款线索分享系统,主要给公安系统使用。其中一个基本功能是,用户A在系统中提交了一条线索,这条线索根据一定的逻辑,判断需要推送给其它某些用户,例如这里举例,这条消息需要推送给用户B,此时B用户的系统页面中需要进行实时弹窗提示。
这个推送弹窗提示的效果,就可以使用webSocket来实现,webSocket不像传统的http请求,只能客户端去主动请求服务端,它可以实现服务端与客户端之间互发消息,进行实时的沟通。比方说在线的聊天室,就是使用了webSocket实现。
这个功能实现的过程中,还希望这个推送的过程,是一个异步操作,保存完线索信息之后,直接返回前端页面保存成功,推送的过程异步操作,不要影响保存的线程。因为这个项目是一个Springboot单体项目,所以这里就借助于SpringBoot的事件发布机制来完成。
这个项目由于是给公安系统使用的,所以项目是直接部署到公安内网,通过ip访问,所以本文就针对这种方式来演示如何使用webSocket,至于比方说配置了nginx反向代理的场景,这里就不涉及了。
项目依赖
真实项目不便透露,这里做了一个demo项目,SpringBoot版本为2.6.13,pom文件如下:
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-security</artifactId>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-thymeleaf</artifactId>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        
        <dependency>
            <groupId>org.thymeleaf.extras</groupId>
            <artifactId>thymeleaf-extras-springsecurity5</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
            <optional>true</optional>
        </dependency>
        
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.security</groupId>
            <artifactId>spring-security-test</artifactId>
            <scope>test</scope>
        </dependency>
        
        <!--websocket-->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>
        <!--mysql驱动-->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.17</version>
        </dependency>
        <!--mybatis-plus依赖-->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.3.1</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba.fastjson2</groupId>
            <artifactId>fastjson2</artifactId>
            <version>2.0.12</version>
        </dependency>
    </dependencies>
准备工作
表设计
数据表就简单设计了两个,用户表用来存放系统登录用户,日志表用来对消息推送进行记录,同时可以记录堆积消息,防止消息丢失。
用户表:

日志表

SpringSecurity配置类
项目的登录功能是引入了SpringSecurity来做的,这里为了简单,直接使用了默认的表单登录流程,配置类如下:
@Configuration
public class SecurityConfig extends WebSecurityConfigurerAdapter {
    @Autowired
    private UserServiceImpl userService;
    @Override
    protected void configure(AuthenticationManagerBuilder auth) throws Exception {
        auth.userDetailsService(userService);
    }
}
注入的UserService实例中,实现了UserDetailsService接口,代表这是SpringSecurity中的用户信息数据源。然后实现了loadUserByUsername方法,也就是说登录时,会调用这个方法来根据username获取用户信息。
项目还引入Mybatis-plus,所以这里UserService继承了ServiceImpl类
@Service
public class UserServiceImpl extends ServiceImpl<UserMapper, UserEntity> implements UserService, UserDetailsService {
    @Override
    public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
        UserEntity user = getOne(new QueryWrapper<UserEntity>().eq("username", username));
        if (user == null) throw new UsernameNotFoundException("用户名不存在");
        return user;
    }
}
功能实现
整个流程这里就简单规划一下,有两个用户 user1、user2,user1登录系统后,保存一条线索,保存完成后。假定这条线索需要推送给user2,这里让user2的页面,弹出提醒。
配置类
需要提供一个配置类,里边定义一个ServerEndpointExporter,注册到容器中,如下
@Configuration
public class WebSocketConfig {
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}
定义WebSocket Server
这里需要定义一个WebSocket Server,下面就简称server,这个server可以理解成是一个websocket请求访问的服务端。因为websocket的工作机制,是需要客户端 先发起一个请求,请求被服务端接收到,服务端与客户端建立一个长链接,然后双方接下来可以互发消息。
而且server也可以定义多个,比方说有线索保存的推送,有人员报警的推送等等,就好比是一个个的controller。
首先定义一个类,标注上@Component注解以及@ServerEndpoint注解,后者可以用来指定当前类是一个WebSocket Server的服务端,同时也可以设置它的请求地址:
整个类的具体配置如下:
- @ServerEndpoint设置了当前服务端的请求路径,为/webSocket/clue/{id},这里的{id},其实就是一个路径参数,每一个id就对应一个websocket连接,建议这里的id直接和用户对应,也就是说,每个用户对应一个连接。
- 类中设置了一个成员变量:onlineSessionClientMap,是一个map集合,里面存放所有在线的用户连接。
- 注入了LogService,用来保存推送的消息,同时也可以进行堆积消息的保存。
- onOpen、onClose、onMessage、onError这四个方法,分别标注了对应的注解,代表当用户连接成功时的操作、关闭连接时的操作、收到消息时的操作、连接出错时的操作。
- 最下面还有两个自定义方法,sendToOne也就是对某个id发消息,也就是对某个用户发消息。getWebSocketUrl是提供了一个静态方法,当有地方想要使用这个server时,就调用这个方法,获取webSocket的连接地址,返回给前端,由前端发起webSocket请求。
@Component
@ServerEndpoint("/webSocket/clue/{id}")
public class ClueServer {
    //在线客户端集合
    private static Map<Long, Session> onlineSessionClientMap = new ConcurrentHashMap<>();
    @Autowired
    private LogService logService;
    /**
     * 连接创建成功
     *
     * @param id
     * @param session
     */
    @OnOpen
    public void onOpen(@PathParam("id") Long id, Session session) {
        //保存到在线客户端集合
        onlineSessionClientMap.put(id, session);
        //如果当前用户存在堆积数据 进行推送
        List<LogEntity> list = logService.list(new QueryWrapper<LogEntity>().eq("receive_user_id", id).eq("is_accumulation", 1));
        for (LogEntity log : list) {
            sendAccumulationToOne(id, log.getId(), log.getMessage());
        }
    }
    /**
     * 连接关闭回调
     *
     * @param id
     * @param session
     */
    @OnClose
    public void onClose(@PathParam("id") String id, Session session) {
        //从map集合中移除
        onlineSessionClientMap.remove(id);
    }
    /**
     * 收到消息后的回调
     *
     * @param message
     * @param session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
    }
    /**
     * 发生错误时的回调
     *
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
    }
    /**
     * 向指定的id发送消息
     *
     * @param id
     * @param message
     */
    public void sendToOne(String id, String message) {
        Session session = onlineSessionClientMap.get(id);
        if (session == null) {
            //如果该id不在线,记录消息 等待用户访问时推送
            LogEntity log = new LogEntity();
            log.setReceiveUserId(Long.parseLong(id));
            log.setIsAccumulation(1);
            log.setMessage(message);
            logService.save(log);
            return;
        }
        if (session != null) session.getAsyncRemote().sendText(message);
    }
    /**
     * 向指定用户发送堆积消息
     *
     * @param id      用户id
     * @param dataId  日志记录id
     * @param message
     */
    public void sendAccumulationToOne(Long id, Long dataId, String message) {
        Session session = onlineSessionClientMap.get(id);
        LogEntity log = new LogEntity();
        log.setId(dataId);
        log.setIsAccumulation(0);
        logService.updateById(log);
        if (session != null) session.getAsyncRemote().sendText(message);
    }
    /**
     * 获取websocket地址
     *
     * @param request
     * @return
     */
    public static String getWebSocketUrl(HttpServletRequest request) {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = attributes.getRequest();
        String url = request.getRequestURL().toString();
        int i = 0, j = 0;
        for (int k = 0; k < url.length(); k++) {
            if (url.charAt(k) == '/') {
                i = k;
                j++;
            }
            if (j == 3) {
                break;
            }
        }
        String ipPort = j < 3 ? url.substring(7, url.length()) : url.substring(7, i);
        String value = getClass().getAnnotation(ServerEndpoint.class).value();
        return "ws://" + ipPort + value.replace("{id}", UserUtil.getLoginUser().getId().toString());
    }
}
@Autowired注入为Null
上面的代码中,会出现问题。
在onOpen方法中,调用logService.list方法时,会抛出空指针异常,logService是null,但是实际上容器中是有logService这个bean的。
这是因为@ServerEndpoint标注的类,是多例的。因为可能会创建很多个server。而在spring中,bean默认是单例的,所以注入的logService是单例的。
而单例bean的注入,是在bean的初始化流程中完成的,这里的server类是多例bean,就没有走populateBean方法去注入属性,所以通过@Autowired注入的属性都是null。
所以对于LogService的注入,可以采用下面这种方式,创建一个静态实例,通过set方法对其进行赋值。
    private static LogService logService;
    @Autowired
    public void setLogService(LogService logService) {
        this.logService = logService;
    }
前端发起webSocket请求
可以参考下面的代码,具体就是页面初始化时,获取到对应的url,然后在js中发起请求,等待接收消息。接收到消息之后,就做某些操作,比如弹窗或者显示消息等等。
controller
@RequestMapping("/")
@Controller
public class IndexController {
    @GetMapping
    public String index(HttpServletRequest request, Model model) {
        model.addAttribute("webSocketUrl", WebSocketServer.getWebSocketUrl(request));
        return "index";
    }
}
index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
</head>
<body>
<h1>首页</h1>
<input type="hidden" id="webSocketUrl" th:value="${webSocketUrl}"/>
<div id="tip" style="border:1px solid #F00;width:30%;display: none"/>
<script type="text/javascript">
    window.onload = function () {
        //创建websocket连接
        let url = document.getElementById("webSocketUrl").value;
        let socket = new WebSocket(url);
        //打开事件
        socket.onopen = function () {
            console.log("webSocket正在连接");
        };
        //获得消息事件
        socket.onmessage = function (msg) {
            let tip = document.getElementById("tip");
            tip.textContent = msg.data;
            tip.style.display = "";
        };
        //关闭事件
        socket.onclose = function () {
            console.log("webSocket已关闭");
        };
        //发生了错误事件
        socket.onerror = function () {
            console.log("webSocket发生错误")
        }
    }
</script>
</body>
</html>
保存线索
接下来就剩在保存线索数据的时候,调用server来进行推送了。
这里为了不让推送的过程影响到保存线程,使用@Async+事件发布机制,完成一个异步操作。
开启异步支持
首先为了开启异步支持,需要在主启动类或者配置类上,标注@EnableAsync注解
定义事件类
public class ClueEvent extends ApplicationEvent {
    private String message;
    public ClueEvent(Object source, String message) {
        super(source);
        this.message = message;
    }
    public String getMessage() {
        return message;
    }
    public void setMessage(String message) {
        this.message = message;
    }
}
定义监听器
@Component
public class ClueEventListener {
    @Autowired
    private WebSocketServer webSocketServer;
    /**
     * 监听器处理消息
     *
     * @param event
     */
    @Async
    @EventListener
    public void onEventListener(ClueEvent event) {
        //这里的id就写死成user2的id
        webSocketServer.sendToOne("2", event.getMessage());
    }
}
保存时发布事件
@Controller
@RequestMapping("/clue")
public class ClueController {
    @Autowired
    private ClueService clueService;
    @Autowired
    private ApplicationEventPublisher applicationEventPublisher;
    @PostMapping("/save")
    @ResponseBody
    public void save(ClueEntity clueEntity) {
        clueService.save(clueEntity);
        //发布推送事件
        applicationEventPublisher.publishEvent(new ClueEvent(this, "您有一条新线索待接收"));
    }
}
结构优化:抽取公共抽象类
截止到这里,就可以实现一个实时的消息推送了。
但是,上面的代码中,还是可以看出一些问题,比如说,如果系统中有很多个webSocket需求,那就需要写很多个server类,这些类中,有相当多的代码都是重复的,这里考虑可以为这些server类抽取一个抽象类,封装一些公共逻辑。
抽象类
如下是公共类的代码;有几个重要的点需要强调以下:
- onlineSessionClientMap这个map集合中,使用了- serverName做为key,- Map<String, List<Session>>结构作为value,这是因为:这个集合是一个- 静态变量,会被它的所有子类- 共享,并且在websocket中,server类是- 多例的实例,所以所有的server类会共享这个Map集合,这里使用server类的类名做一个区分
- 类中还提供了一个方法: public Function getStringMessage() {},这个是根据实际的项目需求来做的。因为有的需求可能向页面推送一个字符串,有的可能向页面推送一个对象。他们要保存到日志表中的消息,可能不一样。这里提供一个getStringMessage(),用来转换出存到数据库的正确消息形式。默认逻辑是直接将消息toString(),也可以在子类中重写这个方法,提供一个lambda表达式,自定义转换逻辑。
@Slf4j
public abstract class AbstractWebSocketServer {
    private static LogService logService;
    @Autowired
    public void setLogService(LogService logService) {
        this.logService = logService;
    }
    //存放所有在线的客户端 key为类名称 区分不同服务端
    private static Map<String, Map<String, List<Session>>> onlineSessionClientMap = new ConcurrentHashMap<>();
    /**
     * 创建连接成功
     *
     * @param id
     * @param session
     */
    @OnOpen
    public void onOpen(@PathParam("id") String id, Session session) {
        String serverName = getClass().getSimpleName();
        List<Session> list = new ArrayList<>();
        list.add(session);
        synchronized (onlineSessionClientMap) {
            if (onlineSessionClientMap.containsKey(serverName)) {
                Map<String, List<Session>> map = onlineSessionClientMap.get(serverName);
                if (map.containsKey(id)) {
                    list.addAll(map.get(id));
                }
                map.put(id, list);
                onlineSessionClientMap.put(serverName, map);
            } else {
                Map<String, List<Session>> map = new ConcurrentHashMap();
                map.put(id, list);
                onlineSessionClientMap.put(serverName, map);
            }
        }
        //查询堆积消息 并发送
        List<LogEntity> msgList = logService.list(new QueryWrapper<LogEntity>().eq("receive_user_id", id).eq("is_accumulation", 1));
        if (msgList.size() > 0) {
            for (LogEntity entity : msgList) {
                sendAccumulationToOne(Long.valueOf(id), entity.getId(), entity.getMessage());
            }
        }
    }
    /**
     * 关闭连接
     *
     * @param id
     * @param session
     */
    @OnClose
    public void onClose(@PathParam("id") String id, Session session) {
        String serverName = getClass().getSimpleName();
        synchronized (onlineSessionClientMap) {
            if (onlineSessionClientMap.containsKey(serverName)) {
                Map<String, List<Session>> map = onlineSessionClientMap.get(serverName);
                List<Session> sessions = map.get(id);
                sessions.remove(session);
                if (sessions.size() == 0) {
                    map.remove(id);
                } else {
                    map.put(id, sessions);
                }
                onlineSessionClientMap.put(serverName, map);
            }
        }
    }
    /**
     * 收到消息
     *
     * @param message
     * @param session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
    }
    /**
     * 发生错误
     *
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("WebSocket发生错误:" + error.getMessage());
    }
    /**
     * 指定id发送消息
     *
     * @param id
     * @param message
     */
    public void sendToOne(String id, Object message) {
        sendToOne(id, message, getStringMessage());
    }
    /**
     * 指定id发送消息
     *
     * @param id       连接id
     * @param message  具体的消息:此处封装为Object 实际可能是一个字符串 也可能是一个对象
     * @param function lambda表达式:用于获取需要保存到数据的预警明细信息。如果传null 默认将message.toString进行保存
     */
    public void sendToOne(String id, Object message, Function function) {
        String serverName = getClass().getSimpleName();
        Map<String, List<Session>> map = onlineSessionClientMap.get(serverName);
        List<Session> sessions = map.get(id);
        //保存消息
        LogEntity entity = new LogEntity();
        entity.setReceiveUserId(Long.valueOf(id));
        entity.setMessage(function.apply(message).toString());
        entity.setIsAccumulation(sessions != null ? 0 : 1);
        logService.save(entity);
        //用户在线,直接推送消息
        if (sessions != null) {
            for (Session session : sessions) {
                if (session != null) session.getAsyncRemote().sendText(JSONObject.toJSONString(message));
            }
        }
    }
    /**
     * 推送堆积消息
     *
     * @param id
     * @param dataId
     * @param message
     */
    private void sendAccumulationToOne(Long id, Long dataId, Object message) {
        String serverName = getClass().getSimpleName();
        Map<String, List<Session>> map = onlineSessionClientMap.get(serverName);
        List<Session> sessions = map.get(id);
        //修改状态
        logService.update(new UpdateWrapper<LogEntity>().eq("id", dataId).set("is_accumulation", 0));
        //用户在线,直接推送消息
        for (Session session : sessions) {
            if (session != null) session.getAsyncRemote().sendText(JSONObject.toJSONString(message));
        }
    }
    /**
     * 提供一个函数式接口,用来将message对象转换为能够保存到数据库的字符串 默认返回null
     *
     * @return
     */
    public Function getStringMessage() {
        return o -> o.toString();
    }
    /**
     * 获取webSocket连接地址
     *
     * @return
     */
    public String getWebSocketUrl() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = attributes.getRequest();
        String url = request.getRequestURL().toString();
        int i = 0, j = 0;
        for (int k = 0; k < url.length(); k++) {
            if (url.charAt(k) == '/') {
                i = k;
                j++;
            }
            if (j == 3) {
                break;
            }
        }
        String ipPort = j < 3 ? url.substring(7, url.length()) : url.substring(7, i);
        String value = getClass().getAnnotation(ServerEndpoint.class).value();
        return "ws://" + ipPort + value.replace("{id}", UserUtil.getLoginUser().getId().toString());
    }
}
子类
假设项目中,目前有两个实时推送消息的需求:一个是线索实时推送,一个是告警实时推送。
如下,定义两个server类之后,继承AbstractWebSocketServer,标上注解,这个类就直接可以作为一个server类来使用了,大大节省了开发成本。
@Component
@ServerEndpoint("/websocket/alarm/{id}")
public class AlarmServer extends AbstractWebSocketServer {
}
@Component
@ServerEndpoint("/websocket/clue/{id}")
public class ClueServer extends AbstractWebSocketServer {
}
模拟推送
如下,在controller中模拟推送,其他的事件类等等就不贴出来了。
@RequestMapping("/")
@Controller
public class IndexController {
    @Autowired
    private ClueServer clueServer;
    @Autowired
    private AlarmServer alarmServer;
    @Autowired
    private ApplicationEventPublisher applicationEventPublisher;
    @GetMapping
    public String index(Model model) {
        model.addAttribute("clueUrl", clueServer.getWebSocketUrl());
        model.addAttribute("alarmUrl", alarmServer.getWebSocketUrl());
        return "index";
    }
    @GetMapping("/sendClue")
    @ResponseBody
    public void sendClue() {
        //模拟发送线索推送
        applicationEventPublisher.publishEvent(new ClueEvent(this, "您有一条线索待确认!"));
    }
    @GetMapping("/sendAlarm")
    @ResponseBody
    public void sendAlarm() {
        //模拟发送告警推送
        applicationEventPublisher.publishEvent(new AlarmEvent(this, "您有一条告警待确认!"));
    }
}
总结
完成上述操作之后,就能够实现在保存线索时,实时的推送消息到指定的用户了。
转载自:https://juejin.cn/post/7357698668890275876




