dmsc/ws/
server.rs

1//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18//! # WebSocket Server Implementation
19
20use super::*;
21use uuid::Uuid;
22use tokio::sync::mpsc;
23use tokio::time::Duration;
24use futures::StreamExt;
25use tungstenite::Message;
26
27#[cfg(feature = "pyo3")]
28#[allow(unused_imports)]
29use pyo3::prelude::*;
30
/// WebSocket server: binds a TCP listener, upgrades accepted connections to
/// WebSocket sessions and dispatches their messages to an optional
/// [`DMSCWSSessionHandler`].
///
/// All shared state lives behind `Arc`s, so a clone shares stats, sessions,
/// the running flag, the event-sender slot and the handler with the original.
/// The shutdown sender is the exception: clones get `None`, so only the
/// instance on which `start` was awaited holds it.
pub struct DMSCWSServer {
    /// Bind address/port plus heartbeat interval and timeout (seconds).
    config: DMSCWSServerConfig,
    /// Connection and message counters, shared with the background tasks.
    stats: Arc<RwLock<DMSCWSServerStats>>,
    /// Registry of live sessions.
    session_manager: Arc<DMSCWSSessionManager>,
    /// Broadcast sender installed by `start`; `None` until then.
    /// NOTE(review): nothing in this file sends events on it — confirm usage.
    event_tx: Arc<RwLock<Option<broadcast::Sender<DMSCWSEvent>>>>,
    /// Signals the accept loop to stop; set by `start`, taken by `stop`.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// `true` between `start` and `stop`.
    running: Arc<RwLock<bool>>,
    /// Optional message/disconnect handler; replaceable via `set_handler`.
    handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
}
40
/// Python wrapper around [`DMSCWSServer`].
///
/// The wrapper owns a dedicated tokio runtime and drives every async call
/// through it. The server's background tasks are spawned on that runtime, so
/// they live as long as this object does.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCWSServerPy {
    inner: DMSCWSServer,
    /// Runtime that hosts `inner`'s background tasks.
    ///
    /// It must outlive `start()`. Dropping a tokio `Runtime` shuts down the
    /// tasks spawned on it, so a per-call runtime would kill the accept loop
    /// as soon as `start()` returned.
    runtime: tokio::runtime::Runtime,
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSServerPy {
    /// Creates a stopped server together with the runtime that will host it.
    ///
    /// Raises `RuntimeError` if the tokio runtime cannot be created.
    #[new]
    fn new(config: DMSCWSServerConfig) -> PyResult<Self> {
        let runtime = tokio::runtime::Runtime::new()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(Self {
            inner: DMSCWSServer::new(config),
            runtime,
        })
    }

    /// Returns a snapshot of the server statistics.
    fn get_stats(&self) -> DMSCWSServerStats {
        self.inner.get_stats()
    }

    /// Reports whether the server is between `start` and `stop`.
    fn is_running(&self) -> bool {
        self.runtime.block_on(self.inner.is_running())
    }

    /// Binds the listener and starts serving on the owned runtime.
    ///
    /// Raises `RuntimeError` if the address is invalid or binding fails.
    fn start(&mut self) -> PyResult<()> {
        self.runtime
            .block_on(self.inner.start())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Closes all sessions and signals the accept loop to stop.
    ///
    /// Raises `RuntimeError` if the underlying `stop` returns an error.
    fn stop(&mut self) -> PyResult<()> {
        self.runtime
            .block_on(self.inner.stop())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Sends `data` to every session and returns how many were reached.
    ///
    /// Returns 0 if the broadcast fails.
    fn broadcast(&self, data: Vec<u8>) -> usize {
        self.runtime
            .block_on(self.inner.broadcast(&data))
            .unwrap_or(0)
    }

    /// Returns the number of sessions currently registered.
    fn get_active_session_count(&self) -> usize {
        self.runtime.block_on(self.inner.get_active_session_count())
    }
}
97
98impl DMSCWSServer {
99    pub fn new(config: DMSCWSServerConfig) -> Self {
100        Self {
101            config,
102            stats: Arc::new(RwLock::new(DMSCWSServerStats::new())),
103            session_manager: Arc::new(DMSCWSSessionManager::new(1000)),
104            event_tx: Arc::new(RwLock::new(None)),
105            shutdown_tx: None,
106            running: Arc::new(RwLock::new(false)),
107            handler: Arc::new(RwLock::new(None)),
108        }
109    }
110
111    pub async fn set_handler<H: DMSCWSSessionHandler + 'static>(&self, handler: H) {
112        *self.handler.write().await = Some(Arc::new(handler));
113    }
114
115    pub async fn start(&mut self) -> DMSCResult<()> {
116        let addr: SocketAddr = format!("{}:{}", self.config.addr, self.config.port)
117            .parse()
118            .map_err(|e| WSError::Server {
119                message: format!("Invalid address: {}", e)
120            })?;
121
122        let listener = TcpListener::bind(&addr)
123            .await
124            .map_err(|e| WSError::Server {
125                message: format!("Failed to bind: {}", e)
126            })?;
127
128        let (event_tx, _) = broadcast::channel(100);
129        *self.event_tx.write().await = Some(event_tx);
130
131        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
132        self.shutdown_tx = Some(shutdown_tx);
133
134        let running = self.running.clone();
135        let stats = self.stats.clone();
136        let session_manager = self.session_manager.clone();
137        let handler = self.handler.clone();
138        let config = self.config.clone();
139
140        *running.write().await = true;
141
142        tokio::spawn(async move {
143            Self::accept_connections(
144                listener,
145                session_manager,
146                stats,
147                handler,
148                config,
149                shutdown_rx,
150                running,
151            ).await;
152        });
153
154        tracing::info!("WebSocket server started on {}", addr);
155        Ok(())
156    }
157
158    async fn accept_connections(
159        listener: TcpListener,
160        session_manager: Arc<DMSCWSSessionManager>,
161        stats: Arc<RwLock<DMSCWSServerStats>>,
162        handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
163        config: DMSCWSServerConfig,
164        mut shutdown_rx: mpsc::Receiver<()>,
165        running: Arc<RwLock<bool>>,
166    ) {
167        let mut shutdown = false;
168
169        while !shutdown {
170            let result = listener.accept().await;
171            
172            if shutdown {
173                break;
174            }
175
176            match result {
177                Ok((stream, remote_addr)) => {
178                    let session_id = Uuid::new_v4().to_string();
179                    let remote_addr_str = remote_addr.to_string();
180                    
181                    tracing::info!("New WebSocket connection: {} (session: {})", remote_addr_str, session_id);
182
183                    match tokio_tungstenite::accept_async(stream).await {
184                        Ok(ws_stream) => {
185                            let (_sender, receiver) = ws_stream.split();
186                            let (tx, rx) = mpsc::channel(100);
187
188                            let session = Arc::new(DMSCWSSession::new(
189                                session_id.clone(),
190                                tx,
191                                receiver,
192                                remote_addr_str.clone(),
193                            ));
194
195                            if session_manager.add_session(session.clone()).await.is_ok() {
196                                stats.write().await.record_connection();
197
198                                let handler_clone = handler.clone();
199                                let session_manager_clone = session_manager.clone();
200                                let stats_clone = stats.clone();
201
202                                tokio::spawn(async move {
203                                    Self::handle_session(
204                                        session.clone(),
205                                        rx,
206                                        handler_clone,
207                                        session_manager_clone,
208                                        stats_clone,
209                                    ).await;
210                                });
211                            } else {
212                                tracing::trace!("Failed to add session: {}", session_id);
213                            }
214                        }
215                        Err(e) => {
216                            tracing::error!("WebSocket upgrade failed: {}", e);
217                            stats.write().await.record_connection_error();
218                        }
219                    }
220                }
221                Err(e) => {
222                    tracing::error!("Failed to accept connection: {}", e);
223                    stats.write().await.record_connection_error();
224                }
225            }
226
227            tokio::time::sleep(Duration::from_secs(config.heartbeat_interval)).await;
228            
229            if !*running.read().await {
230                break;
231            }
232            
233            let _timeout = Duration::from_secs(config.heartbeat_timeout);
234            let sessions = session_manager.get_all_sessions().await;
235            for session_info in sessions {
236                let last_heartbeat_time = chrono::DateTime::from_timestamp(session_info.last_heartbeat as i64, 0)
237                    .unwrap_or_else(|| chrono::Utc::now());
238                let elapsed = last_heartbeat_time.signed_duration_since(chrono::Utc::now());
239                let elapsed_secs = elapsed.num_seconds() as u64;
240                
241                if elapsed_secs > config.heartbeat_timeout {
242                    if let Some(session) = session_manager.get_session(&session_info.session_id).await {
243                        let _ = session.close().await;
244                    }
245                }
246            }
247            
248            if shutdown_rx.try_recv().is_ok() {
249                shutdown = true;
250            }
251        }
252
253        tracing::info!("WebSocket server stopped");
254    }
255
256    async fn handle_session(
257        session: Arc<DMSCWSSession>,
258        mut rx: mpsc::Receiver<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
259        handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
260        session_manager: Arc<DMSCWSSessionManager>,
261        stats: Arc<RwLock<DMSCWSServerStats>>,
262    ) {
263        let session_id = session.id.clone();
264
265        while let Some(message_result) = rx.recv().await {
266            match message_result {
267                Ok(message) => {
268                    match message {
269                        Message::Binary(data) => {
270                            stats.write().await.record_message_received(data.len());
271
272                            let handler_read = handler.read().await;
273                            if let Some(handler) = &*handler_read {
274                                if let Ok(response) = handler.on_message(&session_id, &data).await {
275                                    if session.send(&response).await.is_err() {
276                                        break;
277                                    }
278                                    stats.write().await.record_message_sent(response.len());
279                                }
280                            } else {
281                                if session.send(&data).await.is_err() {
282                                    break;
283                                }
284                                stats.write().await.record_message_sent(data.len());
285                            }
286                        }
287                        Message::Text(text) => {
288                            let data = text.into_bytes();
289                            stats.write().await.record_message_received(data.len());
290                            
291                            let handler_read = handler.read().await;
292                            if let Some(handler) = &*handler_read {
293                                if let Ok(response) = handler.on_message(&session_id, &data).await {
294                                    if session.send(&response).await.is_err() {
295                                        break;
296                                    }
297                                }
298                            }
299                        }
300                        Message::Ping(ping_data) => {
301                            if session.send(&ping_data).await.is_err() {
302                                break;
303                            }
304                        }
305                        Message::Pong(_) => {}
306                        Message::Close(_) => {
307                            break;
308                        }
309                        _ => {}
310                    }
311                }
312                Err(e) => {
313                    tracing::error!("WebSocket error for session {}: {}", session_id, e);
314                    stats.write().await.record_message_error();
315                    break;
316                }
317            }
318        }
319
320        session_manager.remove_session(&session_id).await;
321        stats.write().await.record_disconnection();
322
323        let handler_read = handler.read().await;
324        if let Some(handler) = &*handler_read {
325            let _ = handler.on_disconnect(&session_id).await;
326        }
327    }
328
329    pub async fn stop(&mut self) -> DMSCResult<()> {
330        *self.running.write().await = false;
331
332        let sessions = self.session_manager.get_all_sessions().await;
333        for session_info in sessions {
334            if let Some(session) = self.session_manager.get_session(&session_info.session_id).await {
335                let _ = session.close().await;
336            }
337        }
338
339        if let Some(tx) = self.shutdown_tx.take() {
340            tx.send(()).await.map_err(|e| WSError::Server {
341                message: format!("Shutdown error: {}", e)
342            })?;
343        }
344
345        tracing::info!("WebSocket server stopped");
346        Ok(())
347    }
348
349    pub fn get_stats(&self) -> DMSCWSServerStats {
350        self.stats.try_read()
351            .map(|guard| guard.clone())
352            .unwrap_or_else(|_| DMSCWSServerStats::new())
353    }
354
355    pub async fn get_session_info(&self, session_id: &str) -> Option<DMSCWSSessionInfo> {
356        self.session_manager.get_session(session_id).await
357            .map(|s| s.get_info())
358    }
359
360    pub async fn broadcast(&self, data: &[u8]) -> DMSCResult<usize> {
361        let count = self.session_manager.broadcast(data).await?;
362        self.stats.write().await.record_message_sent(data.len() * count);
363        Ok(count)
364    }
365
366    pub async fn is_running(&self) -> bool {
367        *self.running.read().await
368    }
369
370    pub async fn get_active_session_count(&self) -> usize {
371        self.session_manager.get_session_count().await
372    }
373}
374
375impl Clone for DMSCWSServer {
376    fn clone(&self) -> Self {
377        Self {
378            config: self.config.clone(),
379            stats: self.stats.clone(),
380            session_manager: self.session_manager.clone(),
381            event_tx: self.event_tx.clone(),
382            shutdown_tx: None,
383            running: self.running.clone(),
384            handler: self.handler.clone(),
385        }
386    }
387}
388
389impl Default for DMSCWSServer {
390    fn default() -> Self {
391        Self::new(DMSCWSServerConfig::default())
392    }
393}