1use crate::core::{DMSCResult, DMSCError};
21use async_trait::async_trait;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24use std::net::SocketAddr;
25use futures::stream::SplitStream;
26use tokio::net::TcpListener;
27use tokio::sync::broadcast;
28use std::collections::HashMap;
29use tungstenite::Message;
30
31#[cfg(feature = "pyo3")]
32use pyo3::prelude::*;
33
34#[cfg(feature = "websocket")]
35mod server;
36
37#[cfg(feature = "websocket")]
38mod client;
39
40#[cfg(feature = "websocket")]
41pub use server::DMSCWSServer;
42
43#[cfg(feature = "websocket")]
44pub use client::DMSCWSClient;
45
46#[cfg(feature = "websocket")]
47pub use client::DMSCWSClientConfig;
48
49#[cfg(feature = "websocket")]
50pub use client::DMSCWSClientStats;
51
52#[cfg(all(feature = "websocket", feature = "pyo3"))]
53pub use server::DMSCWSServerPy;
54
55#[cfg(all(feature = "websocket", feature = "pyo3"))]
56pub use client::DMSCWSClientPy;
57
58#[derive(Debug, Clone, PartialEq, Eq)]
59#[cfg_attr(feature = "pyo3", pyclass)]
60pub struct DMSCWSServerConfig {
61 pub addr: String,
62 pub port: u16,
63 pub max_connections: usize,
64 pub heartbeat_interval: u64,
65 pub heartbeat_timeout: u64,
66 pub max_message_size: usize,
67 pub ping_interval: u64,
68}
69
70#[cfg(feature = "pyo3")]
71#[pymethods]
72impl DMSCWSServerConfig {
73 #[new]
74 fn new() -> Self {
75 Self::default()
76 }
77
78 #[getter]
79 fn get_addr(&self) -> String {
80 self.addr.clone()
81 }
82
83 #[setter]
84 fn set_addr(&mut self, addr: String) {
85 self.addr = addr;
86 }
87
88 #[getter]
89 fn get_port(&self) -> u16 {
90 self.port
91 }
92
93 #[setter]
94 fn set_port(&mut self, port: u16) {
95 self.port = port;
96 }
97}
98
99impl Default for DMSCWSServerConfig {
100 fn default() -> Self {
101 Self {
102 addr: "127.0.0.1".to_string(),
103 port: 8080,
104 max_connections: 1000,
105 heartbeat_interval: 30,
106 heartbeat_timeout: 60,
107 max_message_size: 65536,
108 ping_interval: 25,
109 }
110 }
111}
112
113#[derive(Debug, Clone, PartialEq, Eq)]
114#[cfg_attr(feature = "pyo3", pyclass)]
115pub enum DMSCWSEvent {
116 Connected { session_id: String },
117 Disconnected { session_id: String },
118 Message { session_id: String, data: Vec<u8> },
119 Error { session_id: String, message: String },
120}
121
122#[derive(Debug, Clone)]
123#[cfg_attr(feature = "pyo3", pyclass)]
124pub struct DMSCWSSessionInfo {
125 pub session_id: String,
126 pub remote_addr: String,
127 pub connected_at: u64,
128 pub messages_sent: u64,
129 pub messages_received: u64,
130 pub bytes_sent: u64,
131 pub bytes_received: u64,
132 pub is_active: bool,
133 pub last_heartbeat: u64,
134}
135
136#[cfg(feature = "pyo3")]
137#[pymethods]
138impl DMSCWSSessionInfo {
139 #[getter]
140 fn get_session_id(&self) -> String {
141 self.session_id.clone()
142 }
143
144 #[getter]
145 fn get_remote_addr(&self) -> String {
146 self.remote_addr.clone()
147 }
148
149 #[getter]
150 fn get_connected_at(&self) -> u64 {
151 self.connected_at
152 }
153
154 #[getter]
155 fn get_messages_sent(&self) -> u64 {
156 self.messages_sent
157 }
158
159 #[getter]
160 fn get_messages_received(&self) -> u64 {
161 self.messages_received
162 }
163
164 #[getter]
165 fn get_bytes_sent(&self) -> u64 {
166 self.bytes_sent
167 }
168
169 #[getter]
170 fn get_bytes_received(&self) -> u64 {
171 self.bytes_received
172 }
173
174 #[getter]
175 fn get_is_active(&self) -> bool {
176 self.is_active
177 }
178
179 #[getter]
180 fn get_last_heartbeat(&self) -> u64 {
181 self.last_heartbeat
182 }
183}
184
185impl Default for DMSCWSSessionInfo {
186 fn default() -> Self {
187 Self {
188 session_id: String::new(),
189 remote_addr: String::new(),
190 connected_at: 0,
191 messages_sent: 0,
192 messages_received: 0,
193 bytes_sent: 0,
194 bytes_received: 0,
195 is_active: false,
196 last_heartbeat: 0,
197 }
198 }
199}
200
201#[async_trait]
202pub trait DMSCWSSessionHandler: Send + Sync {
203 async fn on_connect(&self, session_id: &str, remote_addr: &str) -> DMSCResult<()>;
204 async fn on_disconnect(&self, session_id: &str) -> DMSCResult<()>;
205 async fn on_message(&self, session_id: &str, data: &[u8]) -> DMSCResult<Vec<u8>>;
206 async fn on_error(&self, session_id: &str, error: &str) -> DMSCResult<()>;
207}
208
209#[derive(Debug, thiserror::Error)]
210pub enum WSError {
211 #[error("Server error: {message}")]
212 Server { message: String },
213 #[error("Session error: {message}")]
214 Session { message: String },
215 #[error("Connection error: {message}")]
216 Connection { message: String },
217 #[error("Message too large: {size} bytes (max: {max_size})")]
218 MessageTooLarge { size: usize, max_size: usize },
219 #[error("Session not found: {session_id}")]
220 SessionNotFound { session_id: String },
221 #[error("Invalid message format")]
222 InvalidFormat,
223}
224
225impl From<WSError> for DMSCError {
226 fn from(error: WSError) -> Self {
227 DMSCError::Other(format!("WebSocket error: {}", error))
228 }
229}
230
231pub struct DMSCWSSession {
232 pub id: String,
233 pub sender: tokio::sync::mpsc::Sender<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
234 pub receiver: SplitStream<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>,
235 pub info: Arc<RwLock<DMSCWSSessionInfo>>,
236}
237
238impl DMSCWSSession {
239 pub fn new(
240 id: String,
241 sender: tokio::sync::mpsc::Sender<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
242 receiver: SplitStream<tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>>,
243 remote_addr: String,
244 ) -> Self {
245 let now = chrono::Utc::now().timestamp() as u64;
246 let session_id = id.clone();
247 Self {
248 id,
249 sender,
250 receiver,
251 info: Arc::new(RwLock::new(DMSCWSSessionInfo {
252 session_id,
253 remote_addr,
254 connected_at: now,
255 messages_sent: 0,
256 messages_received: 0,
257 bytes_sent: 0,
258 bytes_received: 0,
259 is_active: true,
260 last_heartbeat: now,
261 })),
262 }
263 }
264
265 pub async fn send(&self, data: &[u8]) -> DMSCResult<()> {
266 let message = Message::Binary(data.to_vec());
267
268 self.sender.send(Ok(message))
269 .await
270 .map_err(|e| WSError::Session {
271 message: format!("Failed to send message: {}", e)
272 })?;
273
274 let mut info = self.info.write().await;
275 info.messages_sent += 1;
276 info.bytes_sent += data.len() as u64;
277
278 Ok(())
279 }
280
281 pub async fn send_text(&self, text: &str) -> DMSCResult<()> {
282 let message = Message::Text(text.to_string());
283
284 self.sender.send(Ok(message))
285 .await
286 .map_err(|e| WSError::Session {
287 message: format!("Failed to send message: {}", e)
288 })?;
289
290 let mut info = self.info.write().await;
291 info.messages_sent += 1;
292 info.bytes_sent += text.len() as u64;
293
294 Ok(())
295 }
296
297 pub async fn close(&self) -> DMSCResult<()> {
298 self.sender.send(Ok(Message::Close(None)))
299 .await
300 .map_err(|e| WSError::Session {
301 message: format!("Failed to close session: {}", e)
302 })?;
303
304 let mut info = self.info.write().await;
305 info.is_active = false;
306
307 Ok(())
308 }
309
310 pub fn get_info(&self) -> DMSCWSSessionInfo {
311 self.info.try_read()
312 .map(|guard| guard.clone())
313 .unwrap_or_else(|_| DMSCWSSessionInfo::default())
314 }
315}
316
317pub struct DMSCWSSessionManager {
318 sessions: Arc<RwLock<HashMap<String, Arc<DMSCWSSession>>>>,
319 max_connections: usize,
320}
321
322impl Clone for DMSCWSSessionManager {
323 fn clone(&self) -> Self {
324 Self {
325 sessions: self.sessions.clone(),
326 max_connections: self.max_connections,
327 }
328 }
329}
330
331impl DMSCWSSessionManager {
332 pub fn new(max_connections: usize) -> Self {
333 Self {
334 sessions: Arc::new(RwLock::new(HashMap::new())),
335 max_connections,
336 }
337 }
338
339 pub async fn add_session(&self, session: Arc<DMSCWSSession>) -> DMSCResult<()> {
340 let mut sessions = self.sessions.write().await;
341
342 if sessions.len() >= self.max_connections {
343 return Err(WSError::Session {
344 message: format!("Max connections reached: {}", self.max_connections)
345 }.into());
346 }
347
348 sessions.insert(session.id.clone(), session);
349 Ok(())
350 }
351
352 pub async fn remove_session(&self, session_id: &str) {
353 let mut sessions = self.sessions.write().await;
354 sessions.remove(session_id);
355 }
356
357 pub async fn get_session(&self, session_id: &str) -> Option<Arc<DMSCWSSession>> {
358 let sessions = self.sessions.read().await;
359 sessions.get(session_id).cloned()
360 }
361
362 pub async fn broadcast(&self, data: &[u8]) -> DMSCResult<usize> {
363 let sessions = self.sessions.read().await;
364 let mut count = 0;
365
366 for session in sessions.values() {
367 if session.send(data).await.is_ok() {
368 count += 1;
369 }
370 }
371
372 Ok(count)
373 }
374
375 pub async fn get_session_count(&self) -> usize {
376 self.sessions.read().await.len()
377 }
378
379 pub async fn get_all_sessions(&self) -> Vec<DMSCWSSessionInfo> {
380 let sessions = self.sessions.read().await;
381 sessions.values().map(|s| s.get_info()).collect()
382 }
383}
384
385#[cfg(feature = "pyo3")]
386#[pyclass]
387pub struct DMSCWSPythonHandler {
388 on_connect: Arc<Py<PyAny>>,
389 on_disconnect: Arc<Py<PyAny>>,
390 on_message: Arc<Py<PyAny>>,
391 on_error: Arc<Py<PyAny>>,
392}
393
394#[cfg(feature = "pyo3")]
395#[pymethods]
396impl DMSCWSPythonHandler {
397 #[new]
398 fn new(
399 on_connect: Py<PyAny>,
400 on_disconnect: Py<PyAny>,
401 on_message: Py<PyAny>,
402 on_error: Py<PyAny>,
403 ) -> Self {
404 Self {
405 on_connect: Arc::new(on_connect),
406 on_disconnect: Arc::new(on_disconnect),
407 on_message: Arc::new(on_message),
408 on_error: Arc::new(on_error),
409 }
410 }
411}
412
413#[cfg(feature = "pyo3")]
414#[async_trait]
415impl DMSCWSSessionHandler for DMSCWSPythonHandler {
416 async fn on_connect(&self, session_id: &str, remote_addr: &str) -> DMSCResult<()> {
417 let on_connect = Arc::clone(&self.on_connect);
418 let session_id = session_id.to_string();
419 let remote_addr = remote_addr.to_string();
420
421 tokio::task::spawn_blocking(move || {
422 Python::attach(|py| {
423 let handler = on_connect.clone_ref(py);
424 let _ = handler.call(py, (session_id, remote_addr), None);
425 });
426 }).await.ok();
427
428 Ok(())
429 }
430
431 async fn on_disconnect(&self, session_id: &str) -> DMSCResult<()> {
432 let on_disconnect = Arc::clone(&self.on_disconnect);
433 let session_id = session_id.to_string();
434
435 tokio::task::spawn_blocking(move || {
436 Python::attach(|py| {
437 let handler = on_disconnect.clone_ref(py);
438 let _ = handler.call(py, (session_id,), None);
439 });
440 }).await.ok();
441
442 Ok(())
443 }
444
445 async fn on_message(&self, session_id: &str, data: &[u8]) -> DMSCResult<Vec<u8>> {
446 let on_message = Arc::clone(&self.on_message);
447 let session_id = session_id.to_string();
448 let data_vec = data.to_vec();
449
450 let result = tokio::task::spawn_blocking(move || {
451 Python::attach(|py| {
452 let handler = on_message.clone_ref(py);
453 match handler.call(py, (session_id, data_vec), None) {
454 Ok(obj) => obj.extract::<Vec<u8>>(py).ok(),
455 Err(_) => None,
456 }
457 })
458 }).await.ok().flatten();
459
460 Ok(result.unwrap_or_default())
461 }
462
463 async fn on_error(&self, session_id: &str, error: &str) -> DMSCResult<()> {
464 let on_error = Arc::clone(&self.on_error);
465 let session_id = session_id.to_string();
466 let error = error.to_string();
467
468 tokio::task::spawn_blocking(move || {
469 Python::attach(|py| {
470 let handler = on_error.clone_ref(py);
471 let _ = handler.call(py, (session_id, error), None);
472 });
473 }).await.ok();
474
475 Ok(())
476 }
477}
478
479#[cfg(feature = "pyo3")]
480#[pyclass]
481pub struct DMSCWSSessionManagerPy {
482 manager: DMSCWSSessionManager,
483}
484
485#[cfg(feature = "pyo3")]
486#[pymethods]
487impl DMSCWSSessionManagerPy {
488 #[new]
489 fn new(max_connections: usize) -> Self {
490 Self {
491 manager: DMSCWSSessionManager::new(max_connections),
492 }
493 }
494
495 fn get_session_count(&self) -> usize {
496 tokio::runtime::Handle::try_current()
497 .map(|handle| handle.block_on(async { self.manager.get_session_count().await }))
498 .unwrap_or(0)
499 }
500
501 fn get_all_sessions(&self) -> Vec<DMSCWSSessionInfo> {
502 tokio::runtime::Handle::try_current()
503 .map(|handle| handle.block_on(async { self.manager.get_all_sessions().await }))
504 .unwrap_or_default()
505 }
506
507 fn broadcast(&self, data: Vec<u8>) -> usize {
508 tokio::runtime::Handle::try_current()
509 .map(|handle| handle.block_on(async { self.manager.broadcast(&data).await.unwrap_or(0) }))
510 .unwrap_or(0)
511 }
512}
513
514#[derive(Debug, Clone)]
515#[cfg_attr(feature = "pyo3", pyclass)]
516pub struct DMSCWSServerStats {
517 pub total_connections: u64,
518 pub active_connections: u64,
519 pub total_messages_sent: u64,
520 pub total_messages_received: u64,
521 pub total_bytes_sent: u64,
522 pub total_bytes_received: u64,
523 pub connection_errors: u64,
524 pub message_errors: u64,
525}
526
527#[cfg(feature = "pyo3")]
528#[pymethods]
529impl DMSCWSServerStats {
530 #[getter]
531 fn get_total_connections(&self) -> u64 {
532 self.total_connections
533 }
534
535 #[getter]
536 fn get_active_connections(&self) -> u64 {
537 self.active_connections
538 }
539
540 #[getter]
541 fn get_total_messages_sent(&self) -> u64 {
542 self.total_messages_sent
543 }
544
545 #[getter]
546 fn get_total_messages_received(&self) -> u64 {
547 self.total_messages_received
548 }
549
550 #[getter]
551 fn get_total_bytes_sent(&self) -> u64 {
552 self.total_bytes_sent
553 }
554
555 #[getter]
556 fn get_total_bytes_received(&self) -> u64 {
557 self.total_bytes_received
558 }
559
560 #[getter]
561 fn get_connection_errors(&self) -> u64 {
562 self.connection_errors
563 }
564
565 #[getter]
566 fn get_message_errors(&self) -> u64 {
567 self.message_errors
568 }
569}
570
571impl DMSCWSServerStats {
572 pub fn new() -> Self {
573 Self {
574 total_connections: 0,
575 active_connections: 0,
576 total_messages_sent: 0,
577 total_messages_received: 0,
578 total_bytes_sent: 0,
579 total_bytes_received: 0,
580 connection_errors: 0,
581 message_errors: 0,
582 }
583 }
584
585 pub fn record_connection(&mut self) {
586 self.total_connections += 1;
587 self.active_connections += 1;
588 }
589
590 pub fn record_disconnection(&mut self) {
591 if self.active_connections > 0 {
592 self.active_connections -= 1;
593 }
594 }
595
596 pub fn record_message_sent(&mut self, size: usize) {
597 self.total_messages_sent += 1;
598 self.total_bytes_sent += size as u64;
599 }
600
601 pub fn record_message_received(&mut self, size: usize) {
602 self.total_messages_received += 1;
603 self.total_bytes_received += size as u64;
604 }
605
606 pub fn record_connection_error(&mut self) {
607 self.connection_errors += 1;
608 if self.active_connections > 0 {
609 self.active_connections -= 1;
610 }
611 }
612
613 pub fn record_message_error(&mut self) {
614 self.message_errors += 1;
615 }
616}
617
618impl Default for DMSCWSServerStats {
619 fn default() -> Self {
620 Self::new()
621 }
622}