// dmsc/ws/mod.rs
//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
//!
//! This file is part of DMSC.
//! The DMSC project belongs to the Dunimd Team.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! You may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.

//! # WebSocket Support
//!
//! Server configuration, session tracking, the session-handler trait, error
//! types and server statistics, plus re-exports of the server and client
//! implementations when the `websocket` feature is enabled.

20use crate::core::{DMSCResult, DMSCError};
21use async_trait::async_trait;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24use std::net::SocketAddr;
25use futures::stream::SplitStream;
26use tokio::net::TcpListener;
27use tokio::sync::broadcast;
28use std::collections::HashMap;
29use tungstenite::Message;
30
31#[cfg(feature = "pyo3")]
32use pyo3::prelude::*;
33
34#[cfg(feature = "websocket")]
35mod server;
36
37#[cfg(feature = "websocket")]
38mod client;
39
40#[cfg(feature = "websocket")]
41pub use server::DMSCWSServer;
42
43#[cfg(feature = "websocket")]
44pub use client::DMSCWSClient;
45
46#[cfg(feature = "websocket")]
47pub use client::DMSCWSClientConfig;
48
49#[cfg(feature = "websocket")]
50pub use client::DMSCWSClientStats;
51
52#[cfg(all(feature = "websocket", feature = "pyo3"))]
53pub use server::DMSCWSServerPy;
54
55#[cfg(all(feature = "websocket", feature = "pyo3"))]
56pub use client::DMSCWSClientPy;
57
/// Configuration for the WebSocket server.
///
/// `Default` binds to `127.0.0.1:8080`. The interval/timeout fields are plain
/// integers whose unit is decided by the server implementation — presumably
/// seconds (TODO confirm against `server.rs`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSServerConfig {
    /// Host/IP address the server binds to.
    pub addr: String,
    /// TCP port the server listens on.
    pub port: u16,
    /// Intended upper bound on concurrent connections.
    pub max_connections: usize,
    /// Interval between heartbeats.
    pub heartbeat_interval: u64,
    /// Time without a heartbeat before a session is considered dead.
    pub heartbeat_timeout: u64,
    /// Largest accepted message, presumably in bytes (see `WSError::MessageTooLarge`).
    pub max_message_size: usize,
    /// Interval between pings.
    pub ping_interval: u64,
}

// Python bindings. Every field is exposed so the whole configuration can be
// tuned from Python, not just the bind address and port.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSServerConfig {
    /// Creates a configuration populated with the default values.
    #[new]
    fn new() -> Self {
        Self::default()
    }

    #[getter]
    fn get_addr(&self) -> String {
        self.addr.clone()
    }

    #[setter]
    fn set_addr(&mut self, addr: String) {
        self.addr = addr;
    }

    #[getter]
    fn get_port(&self) -> u16 {
        self.port
    }

    #[setter]
    fn set_port(&mut self, port: u16) {
        self.port = port;
    }

    #[getter]
    fn get_max_connections(&self) -> usize {
        self.max_connections
    }

    #[setter]
    fn set_max_connections(&mut self, max_connections: usize) {
        self.max_connections = max_connections;
    }

    #[getter]
    fn get_heartbeat_interval(&self) -> u64 {
        self.heartbeat_interval
    }

    #[setter]
    fn set_heartbeat_interval(&mut self, heartbeat_interval: u64) {
        self.heartbeat_interval = heartbeat_interval;
    }

    #[getter]
    fn get_heartbeat_timeout(&self) -> u64 {
        self.heartbeat_timeout
    }

    #[setter]
    fn set_heartbeat_timeout(&mut self, heartbeat_timeout: u64) {
        self.heartbeat_timeout = heartbeat_timeout;
    }

    #[getter]
    fn get_max_message_size(&self) -> usize {
        self.max_message_size
    }

    #[setter]
    fn set_max_message_size(&mut self, max_message_size: usize) {
        self.max_message_size = max_message_size;
    }

    #[getter]
    fn get_ping_interval(&self) -> u64 {
        self.ping_interval
    }

    #[setter]
    fn set_ping_interval(&mut self, ping_interval: u64) {
        self.ping_interval = ping_interval;
    }
}

impl Default for DMSCWSServerConfig {
    /// Local-only defaults: `127.0.0.1:8080`, 1000 connections, 64 KiB messages.
    fn default() -> Self {
        Self {
            addr: "127.0.0.1".to_string(),
            port: 8080,
            max_connections: 1000,
            heartbeat_interval: 30,
            heartbeat_timeout: 60,
            max_message_size: 65536,
            ping_interval: 25,
        }
    }
}
112
/// Session lifecycle events; each variant names the session it concerns.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub enum DMSCWSEvent {
    /// The session identified by `session_id` connected.
    Connected {
        session_id: String,
    },
    /// The session identified by `session_id` disconnected.
    Disconnected {
        session_id: String,
    },
    /// A raw message payload associated with a session.
    Message {
        session_id: String,
        data: Vec<u8>,
    },
    /// An error, described by `message`, associated with a session.
    Error {
        session_id: String,
        message: String,
    },
}
121
/// Snapshot of one session's bookkeeping.
///
/// `Default` (derived) yields empty strings, zero counters and `is_active == false`.
/// `connected_at` and `last_heartbeat` hold seconds since the Unix epoch when
/// populated by `DMSCWSSession::new`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSSessionInfo {
    /// Identifier of the session this record describes.
    pub session_id: String,
    /// Remote peer address, as supplied when the session was created.
    pub remote_addr: String,
    /// Connection time (Unix seconds).
    pub connected_at: u64,
    /// Number of frames queued for sending.
    pub messages_sent: u64,
    /// Number of frames received.
    pub messages_received: u64,
    /// Payload bytes queued for sending.
    pub bytes_sent: u64,
    /// Payload bytes received.
    pub bytes_received: u64,
    /// Whether the session is still considered open.
    pub is_active: bool,
    /// Time of the most recent heartbeat (Unix seconds).
    pub last_heartbeat: u64,
}

// Read-only Python view of the record.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSSessionInfo {
    #[getter]
    fn get_session_id(&self) -> String {
        self.session_id.clone()
    }

    #[getter]
    fn get_remote_addr(&self) -> String {
        self.remote_addr.clone()
    }

    #[getter]
    fn get_connected_at(&self) -> u64 {
        self.connected_at
    }

    #[getter]
    fn get_messages_sent(&self) -> u64 {
        self.messages_sent
    }

    #[getter]
    fn get_messages_received(&self) -> u64 {
        self.messages_received
    }

    #[getter]
    fn get_bytes_sent(&self) -> u64 {
        self.bytes_sent
    }

    #[getter]
    fn get_bytes_received(&self) -> u64 {
        self.bytes_received
    }

    #[getter]
    fn get_is_active(&self) -> bool {
        self.is_active
    }

    #[getter]
    fn get_last_heartbeat(&self) -> u64 {
        self.last_heartbeat
    }
}
200
201#[async_trait]
202pub trait DMSCWSSessionHandler: Send + Sync {
203    async fn on_connect(&self, session_id: &str, remote_addr: &str) -> DMSCResult<()>;
204    async fn on_disconnect(&self, session_id: &str) -> DMSCResult<()>;
205    async fn on_message(&self, session_id: &str, data: &[u8]) -> DMSCResult<Vec<u8>>;
206    async fn on_error(&self, session_id: &str, error: &str) -> DMSCResult<()>;
207}
208
209#[derive(Debug, thiserror::Error)]
210pub enum WSError {
211    #[error("Server error: {message}")]
212    Server { message: String },
213    #[error("Session error: {message}")]
214    Session { message: String },
215    #[error("Connection error: {message}")]
216    Connection { message: String },
217    #[error("Message too large: {size} bytes (max: {max_size})")]
218    MessageTooLarge { size: usize, max_size: usize },
219    #[error("Session not found: {session_id}")]
220    SessionNotFound { session_id: String },
221    #[error("Invalid message format")]
222    InvalidFormat,
223}
224
225impl From<WSError> for DMSCError {
226    fn from(error: WSError) -> Self {
227        DMSCError::Other(format!("WebSocket error: {}", error))
228    }
229}
230
231pub struct DMSCWSSession {
232    pub id: String,
233    pub sender: tokio::sync::mpsc::Sender<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
234    pub receiver: SplitStream<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>,
235    pub info: Arc<RwLock<DMSCWSSessionInfo>>,
236}
237
238impl DMSCWSSession {
239    pub fn new(
240        id: String,
241        sender: tokio::sync::mpsc::Sender<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
242        receiver: SplitStream<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>,
243        remote_addr: String,
244    ) -> Self {
245            let now = chrono::Utc::now().timestamp() as u64;
246        let session_id = id.clone();
247        Self {
248            id,
249            sender,
250            receiver,
251            info: Arc::new(RwLock::new(DMSCWSSessionInfo {
252                session_id,
253                remote_addr,
254                connected_at: now,
255                messages_sent: 0,
256                messages_received: 0,
257                bytes_sent: 0,
258                bytes_received: 0,
259                is_active: true,
260                last_heartbeat: now,
261            })),
262        }
263    }
264
265    pub async fn send(&self, data: &[u8]) -> DMSCResult<()> {
266        let message = Message::Binary(data.to_vec());
267        
268        self.sender.send(Ok(message))
269            .await
270            .map_err(|e| WSError::Session {
271                message: format!("Failed to send message: {}", e)
272            })?;
273
274        let mut info = self.info.write().await;
275        info.messages_sent += 1;
276        info.bytes_sent += data.len() as u64;
277
278        Ok(())
279    }
280
281    pub async fn send_text(&self, text: &str) -> DMSCResult<()> {
282        let message = Message::Text(text.to_string());
283        
284        self.sender.send(Ok(message))
285            .await
286            .map_err(|e| WSError::Session {
287                message: format!("Failed to send message: {}", e)
288            })?;
289
290        let mut info = self.info.write().await;
291        info.messages_sent += 1;
292        info.bytes_sent += text.len() as u64;
293
294        Ok(())
295    }
296
297    pub async fn close(&self) -> DMSCResult<()> {
298        self.sender.send(Ok(Message::Close(None)))
299            .await
300            .map_err(|e| WSError::Session {
301                message: format!("Failed to close session: {}", e)
302            })?;
303
304        let mut info = self.info.write().await;
305        info.is_active = false;
306
307        Ok(())
308    }
309
310    pub fn get_info(&self) -> DMSCWSSessionInfo {
311        self.info.try_read()
312            .map(|guard| guard.clone())
313            .unwrap_or_else(|_| DMSCWSSessionInfo::default())
314    }
315}
316
317pub struct DMSCWSSessionManager {
318    sessions: Arc<RwLock<HashMap<String, Arc<DMSCWSSession>>>>,
319    max_connections: usize,
320}
321
322impl Clone for DMSCWSSessionManager {
323    fn clone(&self) -> Self {
324        Self {
325            sessions: self.sessions.clone(),
326            max_connections: self.max_connections,
327        }
328    }
329}
330
331impl DMSCWSSessionManager {
332    pub fn new(max_connections: usize) -> Self {
333        Self {
334            sessions: Arc::new(RwLock::new(HashMap::new())),
335            max_connections,
336        }
337    }
338
339    pub async fn add_session(&self, session: Arc<DMSCWSSession>) -> DMSCResult<()> {
340        let mut sessions = self.sessions.write().await;
341        
342        if sessions.len() >= self.max_connections {
343            return Err(WSError::Session {
344                message: format!("Max connections reached: {}", self.max_connections)
345            }.into());
346        }
347
348        sessions.insert(session.id.clone(), session);
349        Ok(())
350    }
351
352    pub async fn remove_session(&self, session_id: &str) {
353        let mut sessions = self.sessions.write().await;
354        sessions.remove(session_id);
355    }
356
357    pub async fn get_session(&self, session_id: &str) -> Option<Arc<DMSCWSSession>> {
358        let sessions = self.sessions.read().await;
359        sessions.get(session_id).cloned()
360    }
361
362    pub async fn broadcast(&self, data: &[u8]) -> DMSCResult<usize> {
363        let sessions = self.sessions.read().await;
364        let mut count = 0;
365
366        for session in sessions.values() {
367            if session.send(data).await.is_ok() {
368                count += 1;
369            }
370        }
371
372        Ok(count)
373    }
374
375    pub async fn get_session_count(&self) -> usize {
376        self.sessions.read().await.len()
377    }
378
379    pub async fn get_all_sessions(&self) -> Vec<DMSCWSSessionInfo> {
380        let sessions = self.sessions.read().await;
381        sessions.values().map(|s| s.get_info()).collect()
382    }
383}
384
/// Session handler that forwards every event to a Python callable.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCWSPythonHandler {
    /// Callable for new connections.
    on_connect: Arc<Py<PyAny>>,
    /// Callable for disconnections.
    on_disconnect: Arc<Py<PyAny>>,
    /// Callable for incoming payloads.
    on_message: Arc<Py<PyAny>>,
    /// Callable for session errors.
    on_error: Arc<Py<PyAny>>,
}
393
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSPythonHandler {
    /// Wraps four Python callables, one per session event.
    ///
    /// The objects are stored as given; they are not checked for callability here.
    #[new]
    fn new(
        on_connect: Py<PyAny>,
        on_disconnect: Py<PyAny>,
        on_message: Py<PyAny>,
        on_error: Py<PyAny>,
    ) -> Self {
        Self {
            on_connect: Arc::from(on_connect),
            on_disconnect: Arc::from(on_disconnect),
            on_message: Arc::from(on_message),
            on_error: Arc::from(on_error),
        }
    }
}
412
/// Reports a failure raised by a Python callback, or by decoding its result.
///
/// The adapter below treats Python callbacks as best-effort (it never returns
/// `Err`), so failures are written to stderr instead of being silently dropped.
#[cfg(feature = "pyo3")]
fn report_callback_error(callback: &str, session_id: &str, err: &dyn std::fmt::Display) {
    eprintln!(
        "DMSC WebSocket: Python `{}` callback failed for session {}: {}",
        callback, session_id, err
    );
}

#[cfg(feature = "pyo3")]
#[async_trait]
impl DMSCWSSessionHandler for DMSCWSPythonHandler {
    /// Calls `on_connect(session_id, remote_addr)` on the blocking thread pool.
    ///
    /// A Python exception is reported on stderr; this always returns `Ok(())`.
    async fn on_connect(&self, session_id: &str, remote_addr: &str) -> DMSCResult<()> {
        let callback = Arc::clone(&self.on_connect);
        let session_id = session_id.to_string();
        let remote_addr = remote_addr.to_string();

        // Python code may block, so keep it off the async worker threads. A
        // panic in the blocking task (JoinError) is ignored, as before.
        tokio::task::spawn_blocking(move || {
            Python::attach(|py| {
                if let Err(err) = callback.call(py, (session_id.clone(), remote_addr), None) {
                    report_callback_error("on_connect", &session_id, &err);
                }
            });
        })
        .await
        .ok();

        Ok(())
    }

    /// Calls `on_disconnect(session_id)` on the blocking thread pool.
    ///
    /// A Python exception is reported on stderr; this always returns `Ok(())`.
    async fn on_disconnect(&self, session_id: &str) -> DMSCResult<()> {
        let callback = Arc::clone(&self.on_disconnect);
        let session_id = session_id.to_string();

        tokio::task::spawn_blocking(move || {
            Python::attach(|py| {
                if let Err(err) = callback.call(py, (session_id.clone(),), None) {
                    report_callback_error("on_disconnect", &session_id, &err);
                }
            });
        })
        .await
        .ok();

        Ok(())
    }

    /// Calls `on_message(session_id, data)` and returns its result as bytes.
    ///
    /// A Python `None` result means "no reply" and yields an empty vector.
    /// A Python exception, or a result that cannot be converted to bytes, is
    /// reported on stderr and also yields an empty vector. Always `Ok`.
    async fn on_message(&self, session_id: &str, data: &[u8]) -> DMSCResult<Vec<u8>> {
        let callback = Arc::clone(&self.on_message);
        let session_id = session_id.to_string();
        let payload = data.to_vec();

        let reply = tokio::task::spawn_blocking(move || {
            Python::attach(|py| {
                let returned = match callback.call(py, (session_id.clone(), payload), None) {
                    Ok(obj) => obj,
                    Err(err) => {
                        report_callback_error("on_message", &session_id, &err);
                        return None;
                    }
                };
                // Extracting `Option` maps Python `None` to `Ok(None)`, so only a
                // genuinely unconvertible value is treated as an error.
                match returned.extract::<Option<Vec<u8>>>(py) {
                    Ok(bytes) => bytes,
                    Err(err) => {
                        report_callback_error("on_message", &session_id, &err);
                        None
                    }
                }
            })
        })
        .await
        .ok()
        .flatten();

        Ok(reply.unwrap_or_default())
    }

    /// Calls `on_error(session_id, error)` on the blocking thread pool.
    ///
    /// A Python exception is reported on stderr; this always returns `Ok(())`.
    async fn on_error(&self, session_id: &str, error: &str) -> DMSCResult<()> {
        let callback = Arc::clone(&self.on_error);
        let session_id = session_id.to_string();
        let error = error.to_string();

        tokio::task::spawn_blocking(move || {
            Python::attach(|py| {
                if let Err(err) = callback.call(py, (session_id.clone(), error), None) {
                    report_callback_error("on_error", &session_id, &err);
                }
            });
        })
        .await
        .ok();

        Ok(())
    }
}
478
/// Python-facing wrapper that exposes a `DMSCWSSessionManager` through
/// synchronous methods.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCWSSessionManagerPy {
    /// The wrapped manager; every method forwards to it.
    manager: DMSCWSSessionManager,
}
484
/// Drives `future` to completion for a synchronous, Python-facing method.
///
/// If the calling thread is inside a Tokio runtime context, the future runs on
/// that runtime's handle, as before. Tokio panics if `Handle::block_on` is
/// called from within an async task. Otherwise — the normal case for a plain
/// Python thread — a temporary current-thread runtime is built for this call,
/// rather than giving up.
///
/// No I/O or timer driver is enabled on that runtime. The manager futures in
/// this file only use `tokio::sync` primitives (`RwLock`, `mpsc`).
///
/// Returns `None` only if the temporary runtime cannot be built.
#[cfg(feature = "pyo3")]
fn block_on_for_python<F: std::future::Future>(future: F) -> Option<F::Output> {
    match tokio::runtime::Handle::try_current() {
        Ok(handle) => Some(handle.block_on(future)),
        Err(_) => tokio::runtime::Builder::new_current_thread()
            .build()
            .ok()
            .map(|runtime| runtime.block_on(future)),
    }
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSSessionManagerPy {
    /// Creates a manager that admits at most `max_connections` sessions.
    #[new]
    fn new(max_connections: usize) -> Self {
        Self {
            manager: DMSCWSSessionManager::new(max_connections),
        }
    }

    /// Number of registered sessions (0 if no runtime could be obtained).
    fn get_session_count(&self) -> usize {
        block_on_for_python(self.manager.get_session_count()).unwrap_or(0)
    }

    /// Bookkeeping records of all sessions (empty if no runtime could be obtained).
    fn get_all_sessions(&self) -> Vec<DMSCWSSessionInfo> {
        block_on_for_python(self.manager.get_all_sessions()).unwrap_or_default()
    }

    /// Broadcasts `data` as a binary frame; returns how many sessions it was queued for.
    fn broadcast(&self, data: Vec<u8>) -> usize {
        block_on_for_python(self.manager.broadcast(&data))
            .and_then(|sent| sent.ok())
            .unwrap_or(0)
    }
}
513
/// Aggregate counters for a WebSocket server.
///
/// `Default` (derived) and `new` both start every counter at zero.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSServerStats {
    /// Connections ever recorded.
    pub total_connections: u64,
    /// Connections currently open; never goes below zero.
    pub active_connections: u64,
    /// Messages recorded as sent.
    pub total_messages_sent: u64,
    /// Messages recorded as received.
    pub total_messages_received: u64,
    /// Bytes recorded as sent.
    pub total_bytes_sent: u64,
    /// Bytes recorded as received.
    pub total_bytes_received: u64,
    /// Connection errors recorded.
    pub connection_errors: u64,
    /// Message errors recorded.
    pub message_errors: u64,
}

// Read-only Python view of the counters.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSServerStats {
    #[getter]
    fn get_total_connections(&self) -> u64 {
        self.total_connections
    }

    #[getter]
    fn get_active_connections(&self) -> u64 {
        self.active_connections
    }

    #[getter]
    fn get_total_messages_sent(&self) -> u64 {
        self.total_messages_sent
    }

    #[getter]
    fn get_total_messages_received(&self) -> u64 {
        self.total_messages_received
    }

    #[getter]
    fn get_total_bytes_sent(&self) -> u64 {
        self.total_bytes_sent
    }

    #[getter]
    fn get_total_bytes_received(&self) -> u64 {
        self.total_bytes_received
    }

    #[getter]
    fn get_connection_errors(&self) -> u64 {
        self.connection_errors
    }

    #[getter]
    fn get_message_errors(&self) -> u64 {
        self.message_errors
    }
}

impl DMSCWSServerStats {
    /// Creates a stats record with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a new connection (total and active both increase).
    pub fn record_connection(&mut self) {
        self.total_connections += 1;
        self.active_connections += 1;
    }

    /// Records a disconnection; the active count saturates at zero.
    pub fn record_disconnection(&mut self) {
        self.active_connections = self.active_connections.saturating_sub(1);
    }

    /// Records one sent message of `size` bytes.
    pub fn record_message_sent(&mut self, size: usize) {
        self.total_messages_sent += 1;
        self.total_bytes_sent += size as u64;
    }

    /// Records one received message of `size` bytes.
    pub fn record_message_received(&mut self, size: usize) {
        self.total_messages_received += 1;
        self.total_bytes_received += size as u64;
    }

    /// Records a connection error.
    ///
    /// Also drops one active connection, saturating at zero.
    pub fn record_connection_error(&mut self) {
        self.connection_errors += 1;
        self.active_connections = self.active_connections.saturating_sub(1);
    }

    /// Records a message error.
    pub fn record_message_error(&mut self) {
        self.message_errors += 1;
    }
}