Browse Source

优化ws 支持同一个code多端接收消息

车车 1 month ago
parent
commit
2716e9eb88

+ 80 - 70
ktg-framework/src/main/java/com/ktg/framework/websocket/WebSocketAndroid.java

@@ -10,30 +10,33 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArraySet;
 
+
+/**
+ * 支持同一个code多端接收消息
+ */
 @Component
 @ServerEndpoint("/websocket/android/{code}")
 @Slf4j
 public class WebSocketAndroid {
     private Session session;
     private String code;
-    // 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
+    // 静态变量,记录当前在线连接数(线程安全)
     private static int onlineCount = 0;
+    // 存放所有客户端连接(线程安全)
     private static CopyOnWriteArraySet<WebSocketAndroid> webSocketSet = new CopyOnWriteArraySet<>();
-    //concurrent包的线程安全set,用来存放每个客户端对应的MyWebSocket对象
-    private static ConcurrentHashMap<String, WebSocketAndroid> webSocketMap2 = new ConcurrentHashMap();
+    // 核心修改:code与连接集合的映射(一个code对应多个连接)
+    private static ConcurrentHashMap<String, Set<WebSocketAndroid>> webSocketMap2 = new ConcurrentHashMap<>();
 
-    // 为了保存在线用户信息,在方法中新建一个list存储一下【实际项目依据复杂度,可以存储到数据库或者缓存】
+    // 存储所有会话(备用)
     private final static List<Session> SESSIONS = Collections.synchronizedList(new ArrayList<>());
     private Thread heartbeatThread;
 
     /**
      * 建立连接
-     *
-     * @param session
-     * @param code
      */
     @OnOpen
     public void onOpen(Session session, @PathParam("code") String code) {
@@ -41,34 +44,46 @@ public class WebSocketAndroid {
         this.code = code;
         webSocketSet.add(this);
         SESSIONS.add(session);
-        if (webSocketMap2.containsKey(code)) {
-            webSocketMap2.remove(code);
-            webSocketMap2.put(code, this);
-        } else {
-            webSocketMap2.put(code, this);
-            addOnlineCount();
-        }
-        // 心跳机制
+
+        // 关键:将当前连接加入code对应的集合(不存在则创建)
+        webSocketMap2.computeIfAbsent(code, k -> ConcurrentHashMap.newKeySet())
+                .add(this);
+
+        addOnlineCount(); // 在线数+1
+
+        // 启动心跳线程
         heartbeatThread = new Thread(() -> {
             try {
                 sendHeartbeats(session, code);
             } catch (IOException e) {
-                e.printStackTrace();
+                log.error("心跳线程异常", e);
             }
         });
         heartbeatThread.start();
-        log.info("[连接ID:{}] 建立连接, 当前连接数:{}", this.code, webSocketMap2.size());
-        log.info("线程序号{}, 当前存活线程数{}", heartbeatThread.getId(), Thread.activeCount());
 
+        log.info("[连接ID:{}] 建立连接, 当前总连接数:{}", this.code, getOnlineCount());
+        log.info("心跳线程序号{}, 当前存活线程数{}", heartbeatThread.getId(), Thread.activeCount());
     }
 
+    /**
+     * 发送心跳包
+     */
     private void sendHeartbeats(Session session, String code) throws IOException {
-        int heartbeatInterval = 60000; // 60 seconds
-        while (true) {
+        int heartbeatInterval = 60000; // 60秒间隔
+        while (!Thread.currentThread().isInterrupted()) { // 检查线程中断状态
             try {
                 Thread.sleep(heartbeatInterval);
-                session.getBasicRemote().sendText("heartbeat");
-            } catch (InterruptedException | IOException e) {
+                if (session.isOpen()) {
+                    session.getBasicRemote().sendText("heartbeat");
+                } else {
+                    break; // 连接已关闭,退出循环
+                }
+            } catch (InterruptedException e) {
+                // 线程被中断,退出循环
+                Thread.currentThread().interrupt();
+                break;
+            } catch (IOException e) {
+                log.error("[连接ID:{}] 心跳发送失败", code, e);
                 break;
             }
         }
@@ -80,110 +95,105 @@ public class WebSocketAndroid {
     @OnClose
     public void onClose() {
         webSocketSet.remove(this);
-        if (webSocketMap2.containsKey(code)) {
-            webSocketMap2.remove(code);
-            subOnlineCount();
+        SESSIONS.remove(this.session);
+
+        // 从code对应的集合中移除当前连接
+        Set<WebSocketAndroid> connections = webSocketMap2.get(code);
+        if (connections != null) {
+            connections.remove(this);
+            // 若集合为空,从map中删除该code
+            if (connections.isEmpty()) {
+                webSocketMap2.remove(code);
+            }
         }
-        // 关闭当前线程
+
+        subOnlineCount(); // 在线数-1
+
+        // 关闭心跳线程
         if (heartbeatThread != null && heartbeatThread.isAlive()) {
-            log.info("关闭线程序号{}, 当前存活线程数{}", heartbeatThread.getId(), Thread.activeCount());
-            heartbeatThread.interrupt(); // Interrupt the heartbeat thread
+            log.info("关闭[连接ID:{}]的心跳线程序号{}, 当前存活线程数{}",
+                    code, heartbeatThread.getId(), Thread.activeCount());
+            heartbeatThread.interrupt(); // 中断线程
             try {
-                heartbeatThread.join(); // Wait for the thread to finish
+                heartbeatThread.join(1000); // 等待1秒,确保线程退出
             } catch (InterruptedException e) {
-                log.error("Error joining heartbeat thread for connection ID: {}", code, e);
-                log.info("关闭线程序号{}, 当前存活线程数{}", Thread.currentThread().getId(), Thread.activeCount());
-                Thread.currentThread().interrupt(); // Restore the interruption status
+                log.error("[连接ID:{}] 心跳线程关闭异常", code, e);
+                Thread.currentThread().interrupt();
             }
         }
-        log.info("[连接ID:{}] 断开连接, 当前连接数:{}", code, webSocketMap2.size());
+
+        log.info("[连接ID:{}] 断开连接, 当前总连接数:{}", code, getOnlineCount());
     }
 
     /**
      * 发送错误
-     *
-     * @param session
-     * @param error
      */
     @OnError
     public void onError(Session session, Throwable error) {
-        log.info("[连接ID:{}] 错误原因:{}", this.code, error.getMessage());
-        error.printStackTrace();
-        // 发生错误时,关闭连接
-        // conn.close(500, "连接出错");
+        log.error("[连接ID:{}] 发生错误", this.code, error);
     }
 
     /**
      * 收到消息
-     *
-     * @param message
      */
     @OnMessage
     public void onMessage(String message) {
-        // log.info("【websocket消息】收到客户端发来的消息:{}", message);
         log.info("[连接ID:{}] 收到消息:{}", this.code, message);
     }
 
     /**
-     * 发送消息
-     *
-     * @param message
-     * @param code
+     * 给指定code的所有连接发送消息
      */
     public static void sendMessage(String code, String message) {
-        WebSocketAndroid webSocketIots = webSocketMap2.get(code);
-        log.info("【websocket消息】推送消息, webSocketServer={}", webSocketIots);
-        if (webSocketIots != null) {
-            log.info("【websocket消息】推送消息, message={}", message);
+        Set<WebSocketAndroid> connections = webSocketMap2.get(code);
+        if (connections == null || connections.isEmpty()) {
+            log.warn("【websocket】code:{} 无在线连接,无法发送消息", code);
+            return;
+        }
+
+        // 遍历所有连接发送消息
+        for (WebSocketAndroid webSocket : connections) {
             try {
-                webSocketIots.session.getBasicRemote().sendText(message);
+                if (webSocket.session.isOpen()) {
+                    webSocket.session.getBasicRemote().sendText(message);
+                    log.info("【websocket】给code:{}的连接发送消息成功:{}", code, message);
+                } else {
+                    log.warn("【websocket】code:{}的连接已关闭,跳过发送", code);
+                }
             } catch (Exception e) {
-                e.printStackTrace();
-                log.error("[连接ID:{}] 发送消息失败, 消息:{}", code, message, e);
+                log.error("【websocket】给code:{}的连接发送消息失败", code, e);
             }
         }
     }
 
     /**
-     * 群发消息
-     *
-     * @param message
+     * 群发消息(所有连接)
      */
     public static void sendMassMessage(String message) {
         try {
             for (Session session : SESSIONS) {
                 if (session.isOpen()) {
                     session.getBasicRemote().sendText(message);
-                    log.info("[连接ID:{}] 发送消息:{}", session.getRequestParameterMap().get("code"), message);
+                    log.info("【websocket】群发消息:{} 到会话:{}", message, session.getId());
                 }
             }
         } catch (Exception e) {
-            e.printStackTrace();
+            log.error("【websocket】群发消息失败", e);
         }
     }
 
     /**
-     * 获取当前连接数
-     *
-     * @return
+     * 在线人数统计
      */
     public static synchronized int getOnlineCount() {
         return onlineCount;
     }
 
-    /**
-     * 当前连接数加一
-     */
     public static synchronized void addOnlineCount() {
         WebSocketAndroid.onlineCount++;
     }
 
-    /**
-     * 当前连接数减一
-     */
     public static synchronized void subOnlineCount() {
         WebSocketAndroid.onlineCount--;
     }
-
 }
-