|
|
@@ -0,0 +1,189 @@
|
|
|
+package com.ktg.framework.websocket;
|
|
|
+
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import javax.websocket.*;
|
|
|
+import javax.websocket.server.PathParam;
|
|
|
+import javax.websocket.server.ServerEndpoint;
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.concurrent.CopyOnWriteArraySet;
|
|
|
+
|
|
|
+@Component
|
|
|
+@ServerEndpoint("/websocket/android/{code}")
|
|
|
+@Slf4j
|
|
|
+public class WebSocketAndroid {
|
|
|
+ private Session session;
|
|
|
+ private String code;
|
|
|
+ // 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
|
|
|
+ private static int onlineCount = 0;
|
|
|
+ private static CopyOnWriteArraySet<WebSocketAndroid> webSocketSet = new CopyOnWriteArraySet<>();
|
|
|
+ //concurrent包的线程安全set,用来存放每个客户端对应的MyWebSocket对象
|
|
|
+ private static ConcurrentHashMap<String, WebSocketAndroid> webSocketMap2 = new ConcurrentHashMap();
|
|
|
+
|
|
|
+ // 为了保存在线用户信息,在方法中新建一个list存储一下【实际项目依据复杂度,可以存储到数据库或者缓存】
|
|
|
+ private final static List<Session> SESSIONS = Collections.synchronizedList(new ArrayList<>());
|
|
|
+ private Thread heartbeatThread;
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 建立连接
|
|
|
+ *
|
|
|
+ * @param session
|
|
|
+ * @param code
|
|
|
+ */
|
|
|
+ @OnOpen
|
|
|
+ public void onOpen(Session session, @PathParam("code") String code) {
|
|
|
+ this.session = session;
|
|
|
+ this.code = code;
|
|
|
+ webSocketSet.add(this);
|
|
|
+ SESSIONS.add(session);
|
|
|
+ if (webSocketMap2.containsKey(code)) {
|
|
|
+ webSocketMap2.remove(code);
|
|
|
+ webSocketMap2.put(code, this);
|
|
|
+ } else {
|
|
|
+ webSocketMap2.put(code, this);
|
|
|
+ addOnlineCount();
|
|
|
+ }
|
|
|
+ // 心跳机制
|
|
|
+ heartbeatThread = new Thread(() -> {
|
|
|
+ try {
|
|
|
+ sendHeartbeats(session, code);
|
|
|
+ } catch (IOException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ heartbeatThread.start();
|
|
|
+ log.info("[连接ID:{}] 建立连接, 当前连接数:{}", this.code, webSocketMap2.size());
|
|
|
+ log.info("线程序号{}, 当前存活线程数{}", heartbeatThread.getId(), Thread.activeCount());
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ private void sendHeartbeats(Session session, String code) throws IOException {
|
|
|
+ int heartbeatInterval = 60000; // 60 seconds
|
|
|
+ while (true) {
|
|
|
+ try {
|
|
|
+ Thread.sleep(heartbeatInterval);
|
|
|
+ session.getBasicRemote().sendText("heartbeat");
|
|
|
+ } catch (InterruptedException | IOException e) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 断开连接
|
|
|
+ */
|
|
|
+ @OnClose
|
|
|
+ public void onClose() {
|
|
|
+ webSocketSet.remove(this);
|
|
|
+ if (webSocketMap2.containsKey(code)) {
|
|
|
+ webSocketMap2.remove(code);
|
|
|
+ subOnlineCount();
|
|
|
+ }
|
|
|
+ // 关闭当前线程
|
|
|
+ if (heartbeatThread != null && heartbeatThread.isAlive()) {
|
|
|
+ log.info("关闭线程序号{}, 当前存活线程数{}", heartbeatThread.getId(), Thread.activeCount());
|
|
|
+ heartbeatThread.interrupt(); // Interrupt the heartbeat thread
|
|
|
+ try {
|
|
|
+ heartbeatThread.join(); // Wait for the thread to finish
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ log.error("Error joining heartbeat thread for connection ID: {}", code, e);
|
|
|
+ log.info("关闭线程序号{}, 当前存活线程数{}", Thread.currentThread().getId(), Thread.activeCount());
|
|
|
+ Thread.currentThread().interrupt(); // Restore the interruption status
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.info("[连接ID:{}] 断开连接, 当前连接数:{}", code, webSocketMap2.size());
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送错误
|
|
|
+ *
|
|
|
+ * @param session
|
|
|
+ * @param error
|
|
|
+ */
|
|
|
+ @OnError
|
|
|
+ public void onError(Session session, Throwable error) {
|
|
|
+ log.info("[连接ID:{}] 错误原因:{}", this.code, error.getMessage());
|
|
|
+ error.printStackTrace();
|
|
|
+ // 发生错误时,关闭连接
|
|
|
+ // conn.close(500, "连接出错");
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 收到消息
|
|
|
+ *
|
|
|
+ * @param message
|
|
|
+ */
|
|
|
+ @OnMessage
|
|
|
+ public void onMessage(String message) {
|
|
|
+ // log.info("【websocket消息】收到客户端发来的消息:{}", message);
|
|
|
+ log.info("[连接ID:{}] 收到消息:{}", this.code, message);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 发送消息
|
|
|
+ *
|
|
|
+ * @param message
|
|
|
+ * @param code
|
|
|
+ */
|
|
|
+ public static void sendMessage(String code, String message) {
|
|
|
+ WebSocketAndroid webSocketIots = webSocketMap2.get(code);
|
|
|
+ log.info("【websocket消息】推送消息, webSocketServer={}", webSocketIots);
|
|
|
+ if (webSocketIots != null) {
|
|
|
+ log.info("【websocket消息】推送消息, message={}", message);
|
|
|
+ try {
|
|
|
+ webSocketIots.session.getBasicRemote().sendText(message);
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ log.error("[连接ID:{}] 发送消息失败, 消息:{}", code, message, e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 群发消息
|
|
|
+ *
|
|
|
+ * @param message
|
|
|
+ */
|
|
|
+ public static void sendMassMessage(String message) {
|
|
|
+ try {
|
|
|
+ for (Session session : SESSIONS) {
|
|
|
+ if (session.isOpen()) {
|
|
|
+ session.getBasicRemote().sendText(message);
|
|
|
+ log.info("[连接ID:{}] 发送消息:{}", session.getRequestParameterMap().get("code"), message);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取当前连接数
|
|
|
+ *
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ public static synchronized int getOnlineCount() {
|
|
|
+ return onlineCount;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 当前连接数加一
|
|
|
+ */
|
|
|
+ public static synchronized void addOnlineCount() {
|
|
|
+ WebSocketAndroid.onlineCount++;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 当前连接数减一
|
|
|
+ */
|
|
|
+ public static synchronized void subOnlineCount() {
|
|
|
+ WebSocketAndroid.onlineCount--;
|
|
|
+ }
|
|
|
+
|
|
|
+}
|
|
|
+
|