dmsc/ws/
client.rs

1//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18//! # WebSocket Client Implementation
19
20use super::*;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22
23#[cfg(feature = "pyo3")]
24#[allow(unused_imports)]
25use pyo3::prelude::*;
26
/// Tunable settings for a `DMSCWSClient`.
///
/// NOTE(review): the interval/timeout fields carry no unit in their type; they are
/// presumably seconds — confirm against the code that consumes them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSClientConfig {
    /// Heartbeat period (default `30`).
    /// NOTE(review): not read by `DMSCWSClient` in this module.
    pub heartbeat_interval: u64,
    /// Heartbeat timeout (default `60`).
    /// NOTE(review): not read by `DMSCWSClient` in this module.
    pub heartbeat_timeout: u64,
    /// Upper bound for a single message/frame in bytes (default 64 KiB).
    pub max_message_size: usize,
    /// Connection timeout (default `10`).
    pub connect_timeout: u64,
    /// Whether to reconnect automatically (default `false`).
    /// NOTE(review): not read by `DMSCWSClient` in this module.
    pub auto_reconnect: bool,
    /// Delay between reconnect attempts (default `5`).
    /// NOTE(review): not read by `DMSCWSClient` in this module.
    pub reconnect_interval: u64,
}

/// Python accessors: one getter/setter pair per field; the constructor takes no
/// arguments and yields the defaults.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientConfig {
    #[new]
    fn new() -> Self {
        <Self as Default>::default()
    }

    // --- heartbeat_interval ---
    #[getter]
    fn get_heartbeat_interval(&self) -> u64 {
        self.heartbeat_interval
    }

    #[setter]
    fn set_heartbeat_interval(&mut self, value: u64) {
        self.heartbeat_interval = value;
    }

    // --- heartbeat_timeout ---
    #[getter]
    fn get_heartbeat_timeout(&self) -> u64 {
        self.heartbeat_timeout
    }

    #[setter]
    fn set_heartbeat_timeout(&mut self, value: u64) {
        self.heartbeat_timeout = value;
    }

    // --- max_message_size ---
    #[getter]
    fn get_max_message_size(&self) -> usize {
        self.max_message_size
    }

    #[setter]
    fn set_max_message_size(&mut self, value: usize) {
        self.max_message_size = value;
    }

    // --- connect_timeout ---
    #[getter]
    fn get_connect_timeout(&self) -> u64 {
        self.connect_timeout
    }

    #[setter]
    fn set_connect_timeout(&mut self, value: u64) {
        self.connect_timeout = value;
    }

    // --- auto_reconnect ---
    #[getter]
    fn get_auto_reconnect(&self) -> bool {
        self.auto_reconnect
    }

    #[setter]
    fn set_auto_reconnect(&mut self, value: bool) {
        self.auto_reconnect = value;
    }

    // --- reconnect_interval ---
    #[getter]
    fn get_reconnect_interval(&self) -> u64 {
        self.reconnect_interval
    }

    #[setter]
    fn set_reconnect_interval(&mut self, value: u64) {
        self.reconnect_interval = value;
    }
}

impl Default for DMSCWSClientConfig {
    /// Conservative defaults: 64 KiB messages, auto-reconnect disabled.
    fn default() -> Self {
        Self {
            max_message_size: 64 * 1024,
            connect_timeout: 10,
            heartbeat_interval: 30,
            heartbeat_timeout: 60,
            auto_reconnect: false,
            reconnect_interval: 5,
        }
    }
}
119
/// Cumulative activity counters for a `DMSCWSClient`.
///
/// All counters start at zero and `last_connected_at` starts as `None`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSClientStats {
    /// Number of successful connections.
    pub total_connections: u64,
    /// Number of messages recorded as sent.
    pub total_messages_sent: u64,
    /// Number of messages recorded as received.
    pub total_messages_received: u64,
    /// Sum of payload sizes recorded as sent, in bytes.
    pub total_bytes_sent: u64,
    /// Sum of payload sizes recorded as received, in bytes.
    pub total_bytes_received: u64,
    /// Number of failed connection attempts recorded.
    pub connection_errors: u64,
    /// Number of message-level failures recorded.
    pub message_errors: u64,
    /// Unix time (seconds) of the most recent successful connection, if known.
    pub last_connected_at: Option<u64>,
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientStats {
    #[getter]
    fn get_total_connections(&self) -> u64 {
        self.total_connections
    }

    #[getter]
    fn get_total_messages_sent(&self) -> u64 {
        self.total_messages_sent
    }

    #[getter]
    fn get_total_messages_received(&self) -> u64 {
        self.total_messages_received
    }

    #[getter]
    fn get_total_bytes_sent(&self) -> u64 {
        self.total_bytes_sent
    }

    #[getter]
    fn get_total_bytes_received(&self) -> u64 {
        self.total_bytes_received
    }

    #[getter]
    fn get_connection_errors(&self) -> u64 {
        self.connection_errors
    }

    #[getter]
    fn get_message_errors(&self) -> u64 {
        self.message_errors
    }

    #[getter]
    fn get_last_connected_at(&self) -> Option<u64> {
        self.last_connected_at
    }
}

impl DMSCWSClientStats {
    /// Creates a zeroed set of counters with no recorded connection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts a successful connection and stamps `last_connected_at` with the current
    /// Unix time in seconds.
    ///
    /// If the system clock reports a time before the Unix epoch, the timestamp becomes
    /// `None` instead of wrapping a negative value into a huge `u64`.
    fn record_connection(&mut self) {
        self.total_connections += 1;
        self.last_connected_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .ok()
            .map(|since_epoch| since_epoch.as_secs());
    }

    /// Counts one outgoing message of `size` bytes.
    // Allowed: not every accounting hook is wired into the client's send paths yet.
    #[allow(dead_code)]
    fn record_message_sent(&mut self, size: usize) {
        self.total_messages_sent += 1;
        self.total_bytes_sent += size as u64;
    }

    /// Counts one incoming message of `size` bytes.
    // Allowed: the client has no receive path that calls this yet.
    #[allow(dead_code)]
    fn record_message_received(&mut self, size: usize) {
        self.total_messages_received += 1;
        self.total_bytes_received += size as u64;
    }

    /// Counts one failed connection attempt.
    // Allowed: kept callable even in builds where no connect path records failures.
    #[allow(dead_code)]
    fn record_connection_error(&mut self) {
        self.connection_errors += 1;
    }

    /// Counts one message-level failure.
    // Allowed: not every accounting hook is wired into the client yet.
    #[allow(dead_code)]
    fn record_message_error(&mut self) {
        self.message_errors += 1;
    }
}
224
225pub struct DMSCWSClient {
226    config: DMSCWSClientConfig,
227    stats: Arc<RwLock<DMSCWSClientStats>>,
228    connected: Arc<RwLock<bool>>,
229    server_url: String,
230}
231
232#[cfg(feature = "pyo3")]
233#[pyclass]
234pub struct DMSCWSClientPy {
235    inner: DMSCWSClient,
236}
237
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientPy {
    /// Creates a client for `server_url` with the default configuration.
    #[new]
    fn new(server_url: String) -> Self {
        Self {
            inner: DMSCWSClient::new(server_url),
        }
    }

    /// Creates a client for `server_url` with an explicit configuration.
    #[staticmethod]
    fn with_config(server_url: String, config: DMSCWSClientConfig) -> Self {
        Self {
            inner: DMSCWSClient::with_config(server_url, config),
        }
    }

    /// Returns a snapshot of the client's counters.
    fn get_stats(&self) -> DMSCWSClientStats {
        self.inner.get_stats()
    }

    /// Reports whether the wrapped client is marked connected.
    ///
    /// Reads the flag without blocking or needing a Tokio runtime on the calling
    /// thread. If the flag is momentarily held by a writer, this reports `false`,
    /// matching the best-effort read used by `get_stats`.
    fn is_connected(&self) -> bool {
        // The other methods run on a throwaway runtime, so the calling (Python) thread
        // never has an ambient Tokio handle to block on; read the flag directly.
        self.inner
            .connected
            .try_read()
            .map(|flag| *flag)
            .unwrap_or(false)
    }

    /// Connects to the server, raising `RuntimeError` on failure.
    fn connect(&mut self) -> PyResult<()> {
        block_on_for_python(self.inner.connect())
    }

    /// Sends binary `data`, raising `RuntimeError` on failure.
    fn send(&self, data: Vec<u8>) -> PyResult<()> {
        block_on_for_python(self.inner.send(&data))
    }

    /// Sends `text`, raising `RuntimeError` on failure.
    fn send_text(&self, text: String) -> PyResult<()> {
        block_on_for_python(self.inner.send_text(&text))
    }

    /// Closes the connection, raising `RuntimeError` on failure.
    fn close(&mut self) -> PyResult<()> {
        block_on_for_python(self.inner.close())
    }
}

/// Drives `future` to completion on a freshly created Tokio runtime and maps both
/// runtime-construction errors and the future's error into a Python `RuntimeError`.
///
/// Each call creates and tears down its own runtime.
#[cfg(feature = "pyo3")]
fn block_on_for_python<F>(future: F) -> PyResult<()>
where
    F: std::future::Future<Output = DMSCResult<()>>,
{
    let runtime = tokio::runtime::Runtime::new()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    runtime
        .block_on(future)
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
}
301
302impl DMSCWSClient {
303    pub fn new(server_url: String) -> Self {
304        Self::with_config(server_url, DMSCWSClientConfig::default())
305    }
306
307    pub fn with_config(server_url: String, config: DMSCWSClientConfig) -> Self {
308        Self {
309            config,
310            stats: Arc::new(RwLock::new(DMSCWSClientStats::new())),
311            connected: Arc::new(RwLock::new(false)),
312            server_url,
313        }
314    }
315
316    pub fn get_stats(&self) -> DMSCWSClientStats {
317        self.stats.try_read()
318            .map(|guard| guard.clone())
319            .unwrap_or_else(|_| DMSCWSClientStats::new())
320    }
321
322    pub async fn is_connected(&self) -> bool {
323        *self.connected.read().await
324    }
325
326    pub async fn connect(&mut self) -> DMSCResult<()> {
327        if *self.connected.read().await {
328            return Ok(());
329        }
330
331        let url = self.server_url.parse::<http::Uri>().map_err(|e| WSError::Connection {
332            message: format!("Invalid WebSocket URL: {}", e)
333        })?;
334
335        let ws_config = WebSocketConfig {
336            max_message_size: Some(self.config.max_message_size),
337            max_frame_size: Some(self.config.max_message_size),
338            max_write_buffer_size: self.config.max_message_size,
339            ..Default::default()
340        };
341
342        let (_ws_stream, _response) = tokio_tungstenite::connect_async_with_config(
343            &url,
344            Some(ws_config),
345            true,
346        )
347        .await
348        .map_err(|e| WSError::Connection {
349            message: format!("Failed to connect to WebSocket server: {}", e)
350        })?;
351
352        *self.connected.write().await = true;
353        self.stats.write().await.record_connection();
354
355        tracing::info!("WebSocket client connected to {}", self.server_url);
356        Ok(())
357    }
358
359    pub async fn send(&self, _data: &[u8]) -> DMSCResult<()> {
360        if !*self.connected.read().await {
361            return Err(WSError::Connection {
362                message: "Not connected to WebSocket server".to_string()
363            }.into());
364        }
365        Ok(())
366    }
367
368    pub async fn send_text(&self, _text: &str) -> DMSCResult<()> {
369        if !*self.connected.read().await {
370            return Err(WSError::Connection {
371                message: "Not connected to WebSocket server".to_string()
372            }.into());
373        }
374        Ok(())
375    }
376
377    pub async fn close(&mut self) -> DMSCResult<()> {
378        *self.connected.write().await = false;
379        tracing::info!("WebSocket client disconnected from {}", self.server_url);
380        Ok(())
381    }
382
383    pub async fn disconnect(&mut self) {
384        let _ = self.close().await;
385    }
386}
387
388impl Clone for DMSCWSClient {
389    fn clone(&self) -> Self {
390        Self {
391            config: self.config.clone(),
392            stats: self.stats.clone(),
393            connected: self.connected.clone(),
394            server_url: self.server_url.clone(),
395        }
396    }
397}
398
399impl Default for DMSCWSClient {
400    fn default() -> Self {
401        Self::new("ws://127.0.0.1:8080".to_string())
402    }
403}