1use super::*;
21use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
22
23#[cfg(feature = "pyo3")]
24#[allow(unused_imports)]
25use pyo3::prelude::*;
26
/// Settings that parameterize a `DMSCWSClient`.
///
/// With the `pyo3` feature enabled the type is also exported as a Python
/// class.
///
/// NOTE(review): in the client code visible next to this definition only
/// `max_message_size` is consulted; the heartbeat, timeout and reconnect
/// fields are stored but no consumer is shown here. Their units are
/// presumably seconds -- TODO confirm against whatever reads them.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSClientConfig {
    /// Heartbeat period (unit presumably seconds -- TODO confirm).
    pub heartbeat_interval: u64,
    /// Heartbeat timeout (unit presumably seconds -- TODO confirm).
    pub heartbeat_timeout: u64,
    /// Size limit handed to the WebSocket layer as both the maximum message
    /// size and the maximum frame size when connecting.
    pub max_message_size: usize,
    /// Connection timeout (unit presumably seconds -- TODO confirm).
    pub connect_timeout: u64,
    /// Whether the client should reconnect automatically.
    pub auto_reconnect: bool,
    /// Delay between reconnect attempts (unit presumably seconds -- TODO confirm).
    pub reconnect_interval: u64,
}
37
// Python-facing surface of the configuration: a no-argument constructor plus
// one read/write property per field. Each getter returns the field by value
// and each setter overwrites it without validation.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientConfig {
    /// Builds a configuration populated with the `Default` values.
    #[new]
    fn new() -> Self {
        <Self as Default>::default()
    }

    /// Current `heartbeat_interval` value.
    #[getter]
    fn get_heartbeat_interval(&self) -> u64 {
        self.heartbeat_interval
    }

    /// Replaces `heartbeat_interval`.
    #[setter]
    fn set_heartbeat_interval(&mut self, value: u64) {
        self.heartbeat_interval = value;
    }

    /// Current `heartbeat_timeout` value.
    #[getter]
    fn get_heartbeat_timeout(&self) -> u64 {
        self.heartbeat_timeout
    }

    /// Replaces `heartbeat_timeout`.
    #[setter]
    fn set_heartbeat_timeout(&mut self, value: u64) {
        self.heartbeat_timeout = value;
    }

    /// Current `max_message_size` value.
    #[getter]
    fn get_max_message_size(&self) -> usize {
        self.max_message_size
    }

    /// Replaces `max_message_size`.
    #[setter]
    fn set_max_message_size(&mut self, value: usize) {
        self.max_message_size = value;
    }

    /// Current `connect_timeout` value.
    #[getter]
    fn get_connect_timeout(&self) -> u64 {
        self.connect_timeout
    }

    /// Replaces `connect_timeout`.
    #[setter]
    fn set_connect_timeout(&mut self, value: u64) {
        self.connect_timeout = value;
    }

    /// Current `auto_reconnect` flag.
    #[getter]
    fn get_auto_reconnect(&self) -> bool {
        self.auto_reconnect
    }

    /// Replaces the `auto_reconnect` flag.
    #[setter]
    fn set_auto_reconnect(&mut self, value: bool) {
        self.auto_reconnect = value;
    }

    /// Current `reconnect_interval` value.
    #[getter]
    fn get_reconnect_interval(&self) -> u64 {
        self.reconnect_interval
    }

    /// Replaces `reconnect_interval`.
    #[setter]
    fn set_reconnect_interval(&mut self, value: u64) {
        self.reconnect_interval = value;
    }
}
106
107impl Default for DMSCWSClientConfig {
108 fn default() -> Self {
109 Self {
110 heartbeat_interval: 30,
111 heartbeat_timeout: 60,
112 max_message_size: 65536,
113 connect_timeout: 10,
114 auto_reconnect: false,
115 reconnect_interval: 5,
116 }
117 }
118}
119
/// Running counters describing a WebSocket client's activity.
///
/// With the `pyo3` feature enabled the type is also exported as a Python
/// class.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCWSClientStats {
    /// Number of connections recorded so far.
    pub total_connections: u64,
    /// Number of sent messages recorded so far.
    pub total_messages_sent: u64,
    /// Number of received messages recorded so far.
    pub total_messages_received: u64,
    /// Sum of the sizes reported for sent messages.
    pub total_bytes_sent: u64,
    /// Sum of the sizes reported for received messages.
    pub total_bytes_received: u64,
    /// Number of connection errors recorded so far.
    pub connection_errors: u64,
    /// Number of message errors recorded so far.
    pub message_errors: u64,
    /// Unix timestamp of the most recently recorded connection, or `None`
    /// when no connection has been recorded yet.
    pub last_connected_at: Option<u64>,
}
132
// Python-facing surface of the statistics: read-only properties, one per
// counter. No setters are exposed.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientStats {
    /// Number of connections recorded so far.
    #[getter]
    fn get_total_connections(&self) -> u64 {
        self.total_connections
    }

    /// Unix timestamp of the latest recorded connection, if any.
    #[getter]
    fn get_last_connected_at(&self) -> Option<u64> {
        self.last_connected_at
    }

    /// Number of sent messages recorded so far.
    #[getter]
    fn get_total_messages_sent(&self) -> u64 {
        self.total_messages_sent
    }

    /// Number of received messages recorded so far.
    #[getter]
    fn get_total_messages_received(&self) -> u64 {
        self.total_messages_received
    }

    /// Sum of the sizes reported for sent messages.
    #[getter]
    fn get_total_bytes_sent(&self) -> u64 {
        self.total_bytes_sent
    }

    /// Sum of the sizes reported for received messages.
    #[getter]
    fn get_total_bytes_received(&self) -> u64 {
        self.total_bytes_received
    }

    /// Number of connection errors recorded so far.
    #[getter]
    fn get_connection_errors(&self) -> u64 {
        self.connection_errors
    }

    /// Number of message errors recorded so far.
    #[getter]
    fn get_message_errors(&self) -> u64 {
        self.message_errors
    }
}
176
177impl DMSCWSClientStats {
178 pub fn new() -> Self {
179 Self {
180 total_connections: 0,
181 total_messages_sent: 0,
182 total_messages_received: 0,
183 total_bytes_sent: 0,
184 total_bytes_received: 0,
185 connection_errors: 0,
186 message_errors: 0,
187 last_connected_at: None,
188 }
189 }
190
191 fn record_connection(&mut self) {
192 self.total_connections += 1;
193 self.last_connected_at = Some(chrono::Utc::now().timestamp() as u64);
194 }
195
196 #[allow(dead_code)]
197 fn record_message_sent(&mut self, size: usize) {
198 self.total_messages_sent += 1;
199 self.total_bytes_sent += size as u64;
200 }
201
202 #[allow(dead_code)]
203 fn record_message_received(&mut self, size: usize) {
204 self.total_messages_received += 1;
205 self.total_bytes_received += size as u64;
206 }
207
208 #[allow(dead_code)]
209 fn record_connection_error(&mut self) {
210 self.connection_errors += 1;
211 }
212
213 #[allow(dead_code)]
214 fn record_message_error(&mut self) {
215 self.message_errors += 1;
216 }
217}
218
219impl Default for DMSCWSClientStats {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225pub struct DMSCWSClient {
226 config: DMSCWSClientConfig,
227 stats: Arc<RwLock<DMSCWSClientStats>>,
228 connected: Arc<RwLock<bool>>,
229 server_url: String,
230}
231
/// Python class wrapping a [`DMSCWSClient`]; only compiled with the `pyo3`
/// feature.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCWSClientPy {
    /// The Rust client every Python-visible method delegates to.
    inner: DMSCWSClient,
}
237
// Python-visible methods. The async operations of the wrapped client are run
// to completion synchronously so they can be called from ordinary Python code.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCWSClientPy {
    /// Creates a client for `server_url` with the default configuration.
    #[new]
    fn new(server_url: String) -> Self {
        Self {
            inner: DMSCWSClient::new(server_url),
        }
    }

    /// Creates a client for `server_url` with the supplied configuration.
    #[staticmethod]
    fn with_config(server_url: String, config: DMSCWSClientConfig) -> Self {
        Self {
            inner: DMSCWSClient::with_config(server_url, config),
        }
    }

    /// Returns a snapshot of the wrapped client's statistics.
    fn get_stats(&self) -> DMSCWSClientStats {
        self.inner.get_stats()
    }

    /// Reports whether the wrapped client is currently marked as connected.
    ///
    /// Returns `false` if the flag cannot be read without waiting.
    fn is_connected(&self) -> bool {
        // Read the flag with the non-blocking `try_read` (the same approach
        // `DMSCWSClient::get_stats` takes) instead of going through
        // `Handle::try_current()`: a call arriving from Python is normally
        // not inside a Tokio runtime, where `try_current()` fails and the
        // method would always answer `false`; inside a runtime,
        // `Handle::block_on` would panic instead.
        self.inner
            .connected
            .try_read()
            .map(|flag| *flag)
            .unwrap_or(false)
    }

    /// Connects the wrapped client, blocking until the attempt finishes.
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or the
    /// connection attempt fails.
    fn connect(&mut self) -> PyResult<()> {
        Self::block_on(self.inner.connect())
    }

    /// Forwards binary `data` to the wrapped client's `send`.
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or `send`
    /// returns an error.
    fn send(&self, data: Vec<u8>) -> PyResult<()> {
        Self::block_on(self.inner.send(&data))
    }

    /// Forwards `text` to the wrapped client's `send_text`.
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or `send_text`
    /// returns an error.
    fn send_text(&self, text: String) -> PyResult<()> {
        Self::block_on(self.inner.send_text(&text))
    }

    /// Closes the wrapped client, blocking until it finishes.
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or `close`
    /// returns an error.
    fn close(&mut self) -> PyResult<()> {
        Self::block_on(self.inner.close())
    }
}

// Private helpers. They live outside the `#[pymethods]` block so they are not
// exported to Python.
#[cfg(feature = "pyo3")]
impl DMSCWSClientPy {
    /// Converts any printable error into a Python `RuntimeError`.
    fn to_py_err<E: ToString>(err: E) -> pyo3::PyErr {
        pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
    }

    /// Drives `future` to completion on a Tokio runtime created for this one
    /// call and dropped afterwards, mapping both runtime-creation failures
    /// and the future's own error to `RuntimeError`.
    fn block_on<F, T, E>(future: F) -> PyResult<T>
    where
        F: std::future::Future<Output = std::result::Result<T, E>>,
        E: ToString,
    {
        let runtime = tokio::runtime::Runtime::new().map_err(Self::to_py_err)?;
        runtime.block_on(future).map_err(Self::to_py_err)
    }
}
301
302impl DMSCWSClient {
303 pub fn new(server_url: String) -> Self {
304 Self::with_config(server_url, DMSCWSClientConfig::default())
305 }
306
307 pub fn with_config(server_url: String, config: DMSCWSClientConfig) -> Self {
308 Self {
309 config,
310 stats: Arc::new(RwLock::new(DMSCWSClientStats::new())),
311 connected: Arc::new(RwLock::new(false)),
312 server_url,
313 }
314 }
315
316 pub fn get_stats(&self) -> DMSCWSClientStats {
317 self.stats.try_read()
318 .map(|guard| guard.clone())
319 .unwrap_or_else(|_| DMSCWSClientStats::new())
320 }
321
322 pub async fn is_connected(&self) -> bool {
323 *self.connected.read().await
324 }
325
326 pub async fn connect(&mut self) -> DMSCResult<()> {
327 if *self.connected.read().await {
328 return Ok(());
329 }
330
331 let url = self.server_url.parse::<http::Uri>().map_err(|e| WSError::Connection {
332 message: format!("Invalid WebSocket URL: {}", e)
333 })?;
334
335 let ws_config = WebSocketConfig {
336 max_message_size: Some(self.config.max_message_size),
337 max_frame_size: Some(self.config.max_message_size),
338 max_write_buffer_size: self.config.max_message_size,
339 ..Default::default()
340 };
341
342 let (_ws_stream, _response) = tokio_tungstenite::connect_async_with_config(
343 &url,
344 Some(ws_config),
345 true,
346 )
347 .await
348 .map_err(|e| WSError::Connection {
349 message: format!("Failed to connect to WebSocket server: {}", e)
350 })?;
351
352 *self.connected.write().await = true;
353 self.stats.write().await.record_connection();
354
355 tracing::info!("WebSocket client connected to {}", self.server_url);
356 Ok(())
357 }
358
359 pub async fn send(&self, _data: &[u8]) -> DMSCResult<()> {
360 if !*self.connected.read().await {
361 return Err(WSError::Connection {
362 message: "Not connected to WebSocket server".to_string()
363 }.into());
364 }
365 Ok(())
366 }
367
368 pub async fn send_text(&self, _text: &str) -> DMSCResult<()> {
369 if !*self.connected.read().await {
370 return Err(WSError::Connection {
371 message: "Not connected to WebSocket server".to_string()
372 }.into());
373 }
374 Ok(())
375 }
376
377 pub async fn close(&mut self) -> DMSCResult<()> {
378 *self.connected.write().await = false;
379 tracing::info!("WebSocket client disconnected from {}", self.server_url);
380 Ok(())
381 }
382
383 pub async fn disconnect(&mut self) {
384 let _ = self.close().await;
385 }
386}
387
388impl Clone for DMSCWSClient {
389 fn clone(&self) -> Self {
390 Self {
391 config: self.config.clone(),
392 stats: self.stats.clone(),
393 connected: self.connected.clone(),
394 server_url: self.server_url.clone(),
395 }
396 }
397}
398
399impl Default for DMSCWSClient {
400 fn default() -> Self {
401 Self::new("ws://127.0.0.1:8080".to_string())
402 }
403}