@Slf4j @Component @ServerEndpoint("/websocket/link/{userId}") public class OldWebSocketService { // 用于存储在线用户的会话,使用ConcurrentHashMap确保线程安全 private static final Map<String, Session> onlineSessions = new ConcurrentHashMap<>(); // 堆代码 duidaima.com @OnOpen public void handleOpen(Session session, @PathParam("userId") String userId) { onlineSessions.put(userId, session); log.info("用户ID为 {} 的用户已连接,当前在线用户数: {}", userId, onlineSessions.size()); broadcastMessage("系统提示:有新用户加入"); } @OnMessage public void handleMessage(String message, Session session, @PathParam("userId") String userId) { log.info("服务端收到用户ID为 {} 的消息: {}", userId, message); JSONObject jsonMessage = JSON.parseObject(message); String targetUserId = jsonMessage.getString("to"); String content = jsonMessage.getString("content"); Session targetSession = onlineSessions.get(targetUserId); if (targetSession != null) { JSONObject responseMessage = new JSONObject(); responseMessage.put("from", userId); responseMessage.put("content", content); sendMessage(responseMessage.toJSONString(), targetSession); log.info("向用户ID为 {} 发送消息: {}", targetUserId, responseMessage.toJSONString()); } else { log.info("未能找到用户ID为 {} 的会话,消息发送失败", targetUserId); } } private void sendMessage(String message, Session targetSession) { try { log.info("服务端向客户端[{}]发送消息: {}", targetSession.getId(), message); targetSession.getBasicRemote().sendText(message); } catch (Exception e) { log.error("服务端向客户端发送消息失败", e); } } private void broadcastMessage(String message) { for (Session session : onlineSessions.values()) { sendMessage(message, session); } } @OnClose public void handleClose(Session session, @PathParam("userId") String userId) { onlineSessions.remove(userId); log.info("用户ID为 {} 的连接已关闭,当前在线用户数: {}", userId, onlineSessions.size()); } }打造安全加固的 WebSocket 体系
5. 连接关闭：连接关闭时，EnhancedWebSocketHandler 的 afterConnectionClosed 方法会被调用，将对应会话从在线用户映射中移除。
@Slf4j @Component public class EnhancedWebSocketHandler implements WebSocketHandler { // 存储用户标识与会话的映射关系,保证线程安全 private static final Map<String, WebSocketSession> userSessionMap = new ConcurrentHashMap<>(); @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { String userKey = (String) session.getAttributes().get("uniqueUserKey"); session.sendMessage(new TextMessage("用户:"+userKey+" 认证成功")); log.info("WebSocket连接已建立,用户唯一标识: {}, 会话ID: {}", userKey, session.getId()); userSessionMap.put(userKey, session); log.info("新用户连接,当前在线用户数: {}", userSessionMap.size()); } @Override public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception { JSONObject json = JSONObject.parseObject((String) message.getPayload()); if(!userSessionMap.containsKey(json.getString("to"))){ session.sendMessage(new TextMessage("接收用户不存在!!!")); return; } String userKey = (String) session.getAttributes().get("uniqueUserKey"); if (!userSessionMap.containsKey(userKey)) { session.sendMessage(new TextMessage("发送用户不存在!!!")); return; } session.sendMessage(new TextMessage("收到 over")); log.info("消息接收成功,内容: {}", message.getPayload()); } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { String userKey = (String) session.getAttributes().get("uniqueUserKey"); if (userSessionMap.containsKey(userKey)) { log.error("WebSocket传输出现错误,用户标识: {}, 错误信息: {}", userKey, exception.getMessage()); } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { String userKey = (String) session.getAttributes().get("uniqueUserKey"); log.info("WebSocket连接已关闭,会话ID: {}, 关闭状态: {}", session.getId(), closeStatus); userSessionMap.remove(userKey); } @Override public boolean supportsPartialMessages() { returntrue; } public void sendMessage(String message, WebSocketSession targetSession) { try { log.info("服务端向客户端[{}]发送消息: {}", targetSession.getId(), message); 
targetSession.sendMessage(new TextMessage(message)); } catch (Exception e) { log.error("服务端向客户端发送消息失败", e); } } public void broadcastMessage(String message) { for (WebSocketSession session : userSessionMap.values()) { sendMessage(message, session); } } }自定义 WebSocket 拦截器
@Slf4j @Component public class SecurityInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, Map<String, Object> attributes) throws Exception { // 获取 HttpServletRequest 对象 HttpServletRequest rs=((ServletServerHttpRequest) request).getServletRequest(); String token = rs.getParameter("Authorization"); log.info("拦截器获取到的令牌: {}", token); if (token == null ||!isValidToken(token)) { log.warn("无效的令牌,拒绝WebSocket连接"); returnfalse; } String userKey = rs.getParameter("UniqueUserKey"); attributes.put("uniqueUserKey", userKey); returntrue; } @Override public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, Exception exception) { // 可在此处添加握手成功后的处理逻辑 } private boolean isValidToken(String token) { // 实际应用中应包含复杂的令牌验证逻辑,如JWT验证 // 此处仅为示例,简单判断令牌是否为"validToken" return"1234".equals(token); } }WebSocket 配置类
@Configuration
@EnableWebSocket
public class WebSocketSecurityConfig implements WebSocketConfigurer {

    // Collaborators supplied by Spring through constructor injection.
    private final EnhancedWebSocketHandler enhancedWebSocketHandler;
    private final SecurityInterceptor securityInterceptor;

    /**
     * @param enhancedWebSocketHandler handler serving the authenticated endpoint
     * @param securityInterceptor      handshake interceptor that checks credentials before the upgrade
     */
    public WebSocketSecurityConfig(EnhancedWebSocketHandler enhancedWebSocketHandler,
                                   SecurityInterceptor securityInterceptor) {
        this.securityInterceptor = securityInterceptor;
        this.enhancedWebSocketHandler = enhancedWebSocketHandler;
    }

    /**
     * Exposes the handler at /secure-websocket behind the security interceptor.
     * NOTE(review): "*" accepts handshakes from any Origin; restrict this to the known front-end
     * origins for a hardened deployment.
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry
                .addHandler(enhancedWebSocketHandler, "/secure-websocket")
                .addInterceptors(securityInterceptor)
                .setAllowedOrigins("*");
    }

    /**
     * Registers @ServerEndpoint-annotated beans (such as OldWebSocketService) with the container's
     * standard WebSocket runtime; the Spring handler registered above does not depend on it.
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}
示例页面
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>堆代码-duidaima.com</title>
    <link href="https://cdn.jsdelivr.net/npm/sweetalert2@11.7.2/dist/sweetalert2.min.css" rel="stylesheet" >
    <style>
        body { font-family: Arial, sans-serif; }
        #authSection { margin-bottom: 10px; }
        #tokenInput, #userKeyInput { width: 200px; padding: 8px; margin-right: 10px; }
        #authButton { padding: 8px 16px; }
        #messageInput { width: 300px; padding: 8px; margin-right: 10px; }
        #targetUserInput { width: 200px; padding: 8px; margin-right: 10px; }
        #sendButton { padding: 8px 16px; }
        #messageList { list-style-type: none; padding: 0; }
        #messageList li { margin: 8px 0; border: 1px solid #ccc; padding: 8px; border-radius: 4px; }
    </style>
</head>
<body>
<h2>WebSocket 认证交互页面</h2>
<div id="authSection">
    <label for="tokenInput">输入认证 Token:</label>
    <input type="text" id="tokenInput" placeholder="请输入认证 Token">
    <label for="userKeyInput">输入用户唯一标识:</label>
    <input type="text" id="userKeyInput" placeholder="请输入用户唯一标识">
    <button id="authButton">认证并连接</button>
</div>
<input type="text" id="messageInput" placeholder="请输入要发送的消息">
<input type="text" id="targetUserInput" placeholder="请输入接收消息的用户标识">
<button id="sendButton">发送消息</button>
<ul id="messageList"></ul>
<script>
    // Demo client for /secure-websocket: authenticates through query parameters, lists every
    // received frame, and sends {to, content} JSON messages.
    let socket;

    document.getElementById('authButton').addEventListener('click', function () {
        const token = document.getElementById('tokenInput').value.trim();
        const userKey = document.getElementById('userKeyInput').value.trim();
        if (token === '' || userKey === '') {
            console.error('Token 或用户唯一标识不能为空');
            return;
        }
        // Re-authenticating must not leave the previous connection open and still receiving messages.
        if (socket && socket.readyState !== WebSocket.CLOSED) {
            socket.close();
        }
        // Browsers cannot set headers on a WebSocket handshake, so credentials travel in the query string.
        // Encode them so characters such as '&', '=', '#' or '+' cannot corrupt or inject parameters.
        const socketUrl = 'ws://localhost:8080/secure-websocket?Authorization=' + encodeURIComponent(token)
            + '&UniqueUserKey=' + encodeURIComponent(userKey);
        socket = new WebSocket(socketUrl);
        socket.onopen = function () {
            console.log('WebSocket 连接已打开');
        };
        socket.onmessage = function (event) {
            const messageItem = document.createElement('li');
            // textContent (not innerHTML) so server-relayed text is never interpreted as markup.
            messageItem.textContent = event.data;
            document.getElementById('messageList').appendChild(messageItem);
        };
        socket.onclose = function () {
            console.log('WebSocket 连接已关闭');
        };
        socket.onerror = function (error) {
            console.error('WebSocket 发生错误:', error);
        };
    });

    document.getElementById('sendButton').addEventListener('click', function () {
        if (!socket || socket.readyState !== WebSocket.OPEN) {
            console.error('WebSocket 连接未建立或已关闭');
            return;
        }
        const message = document.getElementById('messageInput').value;
        // Trimmed to match how user keys are trimmed when connecting.
        const targetUser = document.getElementById('targetUserInput').value.trim();
        if (message.trim() === '' || targetUser === '') {
            return;
        }
        const messageObj = { to: targetUser, content: message };
        socket.send(JSON.stringify(messageObj));
        document.getElementById('messageInput').value = '';
        document.getElementById('targetUserInput').value = '';
    });
</script>
</body>
</html>
测试