1use super::*;
21use uuid::Uuid;
22use tokio::sync::mpsc;
23use tokio::time::Duration;
24use futures::StreamExt;
25use tungstenite::Message;
26
27#[cfg(feature = "pyo3")]
28#[allow(unused_imports)]
29use pyo3::prelude::*;
30
/// WebSocket server state: configuration, shared statistics, the session registry and the
/// control handles used to run and stop the background accept task.
pub struct DMSCWSServer {
    /// Listen address/port and heartbeat settings read by `start` and the accept loop.
    config: DMSCWSServerConfig,
    /// Connection/message counters shared with the accept loop and every session task.
    stats: Arc<RwLock<DMSCWSServerStats>>,
    /// Registry of live sessions; shared with the accept loop and session tasks.
    session_manager: Arc<DMSCWSSessionManager>,
    /// Broadcast sender for server events; `None` until `start` installs one.
    event_tx: Arc<RwLock<Option<broadcast::Sender<DMSCWSEvent>>>>,
    /// Sender used by `stop` to signal the accept loop; `None` before `start` and in clones.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// Set to `true` by `start` and to `false` by `stop`; polled by the accept loop.
    running: Arc<RwLock<bool>>,
    /// Optional application callback invoked for inbound messages and disconnects.
    handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
}
40
/// Python-facing wrapper around [`DMSCWSServer`].
///
/// The wrapper owns the Tokio runtime that drives the server. The runtime must outlive
/// `start`: the accept loop and every session task are spawned onto it, and dropping a
/// runtime cancels all tasks it hosts.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCWSServerPy {
    /// The wrapped server.
    inner: DMSCWSServer,
    /// Runtime hosting the server's background tasks; created lazily by `start`.
    runtime: Option<tokio::runtime::Runtime>,
}

#[cfg(feature = "pyo3")]
impl DMSCWSServerPy {
    /// Takes the stored runtime out of `self`, creating a fresh one if none exists yet.
    ///
    /// The caller decides whether to put the runtime back afterwards.
    fn checkout_runtime(&mut self) -> PyResult<tokio::runtime::Runtime> {
        match self.runtime.take() {
            Some(rt) => Ok(rt),
            None => tokio::runtime::Runtime::new()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())),
        }
    }
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSServerPy {
    /// Creates a stopped server from `config`. No runtime is created until `start`.
    #[new]
    fn new(config: DMSCWSServerConfig) -> Self {
        Self {
            inner: DMSCWSServer::new(config),
            runtime: None,
        }
    }

    /// Returns a snapshot of the server statistics.
    fn get_stats(&self) -> DMSCWSServerStats {
        self.inner.get_stats()
    }

    /// Reports whether the server is running.
    ///
    /// Returns `false` when `start` has never been called on this wrapper.
    fn is_running(&self) -> bool {
        match self.runtime.as_ref() {
            Some(rt) => rt.block_on(self.inner.is_running()),
            None => false,
        }
    }

    /// Starts the server on the wrapper-owned runtime.
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or the server fails to start.
    fn start(&mut self) -> PyResult<()> {
        let rt = self.checkout_runtime()?;
        let outcome = rt.block_on(self.inner.start());
        // Keep the runtime even on failure: dropping it would cancel any spawned task, and
        // a later `start` can reuse it.
        self.runtime = Some(rt);
        outcome.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Stops the server.
    ///
    /// Raises `RuntimeError` if a runtime cannot be created or the server reports an error.
    fn stop(&mut self) -> PyResult<()> {
        let had_runtime = self.runtime.is_some();
        let rt = self.checkout_runtime()?;
        let outcome = rt.block_on(self.inner.stop());
        if had_runtime {
            // Session tasks may still be winding down on this runtime; keep it alive.
            self.runtime = Some(rt);
        }
        outcome.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Sends `data` to every session and returns how many sessions received it.
    ///
    /// Returns `0` when the server was never started or the broadcast fails.
    fn broadcast(&self, data: Vec<u8>) -> usize {
        match self.runtime.as_ref() {
            Some(rt) => rt.block_on(async { self.inner.broadcast(&data).await.unwrap_or(0) }),
            None => 0,
        }
    }

    /// Returns the number of sessions currently registered; `0` if never started.
    fn get_active_session_count(&self) -> usize {
        match self.runtime.as_ref() {
            Some(rt) => rt.block_on(self.inner.get_active_session_count()),
            None => 0,
        }
    }
}
97
98impl DMSCWSServer {
99 pub fn new(config: DMSCWSServerConfig) -> Self {
100 Self {
101 config,
102 stats: Arc::new(RwLock::new(DMSCWSServerStats::new())),
103 session_manager: Arc::new(DMSCWSSessionManager::new(1000)),
104 event_tx: Arc::new(RwLock::new(None)),
105 shutdown_tx: None,
106 running: Arc::new(RwLock::new(false)),
107 handler: Arc::new(RwLock::new(None)),
108 }
109 }
110
111 pub async fn set_handler<H: DMSCWSSessionHandler + 'static>(&self, handler: H) {
112 *self.handler.write().await = Some(Arc::new(handler));
113 }
114
115 pub async fn start(&mut self) -> DMSCResult<()> {
116 let addr: SocketAddr = format!("{}:{}", self.config.addr, self.config.port)
117 .parse()
118 .map_err(|e| WSError::Server {
119 message: format!("Invalid address: {}", e)
120 })?;
121
122 let listener = TcpListener::bind(&addr)
123 .await
124 .map_err(|e| WSError::Server {
125 message: format!("Failed to bind: {}", e)
126 })?;
127
128 let (event_tx, _) = broadcast::channel(100);
129 *self.event_tx.write().await = Some(event_tx);
130
131 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
132 self.shutdown_tx = Some(shutdown_tx);
133
134 let running = self.running.clone();
135 let stats = self.stats.clone();
136 let session_manager = self.session_manager.clone();
137 let handler = self.handler.clone();
138 let config = self.config.clone();
139
140 *running.write().await = true;
141
142 tokio::spawn(async move {
143 Self::accept_connections(
144 listener,
145 session_manager,
146 stats,
147 handler,
148 config,
149 shutdown_rx,
150 running,
151 ).await;
152 });
153
154 tracing::info!("WebSocket server started on {}", addr);
155 Ok(())
156 }
157
158 async fn accept_connections(
159 listener: TcpListener,
160 session_manager: Arc<DMSCWSSessionManager>,
161 stats: Arc<RwLock<DMSCWSServerStats>>,
162 handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
163 config: DMSCWSServerConfig,
164 mut shutdown_rx: mpsc::Receiver<()>,
165 running: Arc<RwLock<bool>>,
166 ) {
167 let mut shutdown = false;
168
169 while !shutdown {
170 let result = listener.accept().await;
171
172 if shutdown {
173 break;
174 }
175
176 match result {
177 Ok((stream, remote_addr)) => {
178 let session_id = Uuid::new_v4().to_string();
179 let remote_addr_str = remote_addr.to_string();
180
181 tracing::info!("New WebSocket connection: {} (session: {})", remote_addr_str, session_id);
182
183 match tokio_tungstenite::accept_async(stream).await {
184 Ok(ws_stream) => {
185 let (_sender, receiver) = ws_stream.split();
186 let (tx, rx) = mpsc::channel(100);
187
188 let session = Arc::new(DMSCWSSession::new(
189 session_id.clone(),
190 tx,
191 receiver,
192 remote_addr_str.clone(),
193 ));
194
195 if session_manager.add_session(session.clone()).await.is_ok() {
196 stats.write().await.record_connection();
197
198 let handler_clone = handler.clone();
199 let session_manager_clone = session_manager.clone();
200 let stats_clone = stats.clone();
201
202 tokio::spawn(async move {
203 Self::handle_session(
204 session.clone(),
205 rx,
206 handler_clone,
207 session_manager_clone,
208 stats_clone,
209 ).await;
210 });
211 } else {
212 tracing::trace!("Failed to add session: {}", session_id);
213 }
214 }
215 Err(e) => {
216 tracing::error!("WebSocket upgrade failed: {}", e);
217 stats.write().await.record_connection_error();
218 }
219 }
220 }
221 Err(e) => {
222 tracing::error!("Failed to accept connection: {}", e);
223 stats.write().await.record_connection_error();
224 }
225 }
226
227 tokio::time::sleep(Duration::from_secs(config.heartbeat_interval)).await;
228
229 if !*running.read().await {
230 break;
231 }
232
233 let _timeout = Duration::from_secs(config.heartbeat_timeout);
234 let sessions = session_manager.get_all_sessions().await;
235 for session_info in sessions {
236 let last_heartbeat_time = chrono::DateTime::from_timestamp(session_info.last_heartbeat as i64, 0)
237 .unwrap_or_else(|| chrono::Utc::now());
238 let elapsed = last_heartbeat_time.signed_duration_since(chrono::Utc::now());
239 let elapsed_secs = elapsed.num_seconds() as u64;
240
241 if elapsed_secs > config.heartbeat_timeout {
242 if let Some(session) = session_manager.get_session(&session_info.session_id).await {
243 let _ = session.close().await;
244 }
245 }
246 }
247
248 if shutdown_rx.try_recv().is_ok() {
249 shutdown = true;
250 }
251 }
252
253 tracing::info!("WebSocket server stopped");
254 }
255
256 async fn handle_session(
257 session: Arc<DMSCWSSession>,
258 mut rx: mpsc::Receiver<std::result::Result<Message, tokio_tungstenite::tungstenite::Error>>,
259 handler: Arc<RwLock<Option<Arc<dyn DMSCWSSessionHandler>>>>,
260 session_manager: Arc<DMSCWSSessionManager>,
261 stats: Arc<RwLock<DMSCWSServerStats>>,
262 ) {
263 let session_id = session.id.clone();
264
265 while let Some(message_result) = rx.recv().await {
266 match message_result {
267 Ok(message) => {
268 match message {
269 Message::Binary(data) => {
270 stats.write().await.record_message_received(data.len());
271
272 let handler_read = handler.read().await;
273 if let Some(handler) = &*handler_read {
274 if let Ok(response) = handler.on_message(&session_id, &data).await {
275 if session.send(&response).await.is_err() {
276 break;
277 }
278 stats.write().await.record_message_sent(response.len());
279 }
280 } else {
281 if session.send(&data).await.is_err() {
282 break;
283 }
284 stats.write().await.record_message_sent(data.len());
285 }
286 }
287 Message::Text(text) => {
288 let data = text.into_bytes();
289 stats.write().await.record_message_received(data.len());
290
291 let handler_read = handler.read().await;
292 if let Some(handler) = &*handler_read {
293 if let Ok(response) = handler.on_message(&session_id, &data).await {
294 if session.send(&response).await.is_err() {
295 break;
296 }
297 }
298 }
299 }
300 Message::Ping(ping_data) => {
301 if session.send(&ping_data).await.is_err() {
302 break;
303 }
304 }
305 Message::Pong(_) => {}
306 Message::Close(_) => {
307 break;
308 }
309 _ => {}
310 }
311 }
312 Err(e) => {
313 tracing::error!("WebSocket error for session {}: {}", session_id, e);
314 stats.write().await.record_message_error();
315 break;
316 }
317 }
318 }
319
320 session_manager.remove_session(&session_id).await;
321 stats.write().await.record_disconnection();
322
323 let handler_read = handler.read().await;
324 if let Some(handler) = &*handler_read {
325 let _ = handler.on_disconnect(&session_id).await;
326 }
327 }
328
329 pub async fn stop(&mut self) -> DMSCResult<()> {
330 *self.running.write().await = false;
331
332 let sessions = self.session_manager.get_all_sessions().await;
333 for session_info in sessions {
334 if let Some(session) = self.session_manager.get_session(&session_info.session_id).await {
335 let _ = session.close().await;
336 }
337 }
338
339 if let Some(tx) = self.shutdown_tx.take() {
340 tx.send(()).await.map_err(|e| WSError::Server {
341 message: format!("Shutdown error: {}", e)
342 })?;
343 }
344
345 tracing::info!("WebSocket server stopped");
346 Ok(())
347 }
348
349 pub fn get_stats(&self) -> DMSCWSServerStats {
350 self.stats.try_read()
351 .map(|guard| guard.clone())
352 .unwrap_or_else(|_| DMSCWSServerStats::new())
353 }
354
355 pub async fn get_session_info(&self, session_id: &str) -> Option<DMSCWSSessionInfo> {
356 self.session_manager.get_session(session_id).await
357 .map(|s| s.get_info())
358 }
359
360 pub async fn broadcast(&self, data: &[u8]) -> DMSCResult<usize> {
361 let count = self.session_manager.broadcast(data).await?;
362 self.stats.write().await.record_message_sent(data.len() * count);
363 Ok(count)
364 }
365
366 pub async fn is_running(&self) -> bool {
367 *self.running.read().await
368 }
369
370 pub async fn get_active_session_count(&self) -> usize {
371 self.session_manager.get_session_count().await
372 }
373}
374
375impl Clone for DMSCWSServer {
376 fn clone(&self) -> Self {
377 Self {
378 config: self.config.clone(),
379 stats: self.stats.clone(),
380 session_manager: self.session_manager.clone(),
381 event_tx: self.event_tx.clone(),
382 shutdown_tx: None,
383 running: self.running.clone(),
384 handler: self.handler.clone(),
385 }
386 }
387}
388
389impl Default for DMSCWSServer {
390 fn default() -> Self {
391 Self::new(DMSCWSServerConfig::default())
392 }
393}