1use std::collections::HashMap;
69use std::fmt;
70use std::sync::Arc;
71use tokio::sync::RwLock;
72use tokio::time::{Duration, timeout};
73
74use crate::core::DMSCResult;
75
/// Description of a single method invocation: target method name, raw
/// argument bytes and a timeout in milliseconds.
///
/// NOTE(review): nothing visible in this file consumes this type
/// (`DMSCModuleRPC::call_method` takes the fields as separate arguments);
/// confirm where it is used.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCMethodCall {
    /// Name of the method to invoke.
    pub method_name: String,
    /// Serialized arguments for the method, as raw bytes.
    pub params: Vec<u8>,
    /// Call timeout in milliseconds; [`DMSCMethodCall::new`] sets 5000.
    pub timeout_ms: u64,
}

#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCMethodCall {
    /// Python constructor; applies the same defaults as [`DMSCMethodCall::new`].
    #[new]
    fn py_new(method_name: String, params: Vec<u8>) -> Self {
        DMSCMethodCall::new(method_name, params)
    }
}

impl DMSCMethodCall {
    /// Timeout applied by [`DMSCMethodCall::new`], in milliseconds.
    const DEFAULT_TIMEOUT_MS: u64 = 5000;

    /// Builds a call for `method_name` with `params` and the default timeout.
    pub fn new(method_name: String, params: Vec<u8>) -> Self {
        let timeout_ms = Self::DEFAULT_TIMEOUT_MS;
        Self {
            method_name,
            params,
            timeout_ms,
        }
    }

    /// Returns this call with its timeout replaced by `timeout_ms`.
    pub fn with_timeout_ms(self, timeout_ms: u64) -> Self {
        Self { timeout_ms, ..self }
    }
}
107
/// Outcome of a method invocation.
///
/// The constructors produce one of these shapes:
/// - [`success_data`](Self::success_data): `success == true` with `data`;
/// - [`error_msg`](Self::error_msg): `success == false` with an `error` text;
/// - [`timeout`](Self::timeout): a failure with `is_timeout == true` and a
///   fixed error text;
/// - [`new`](Self::new) / `default()`: a failure with no data and no error text.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCMethodResponse {
    /// Whether the call completed successfully.
    pub success: bool,
    /// Response payload; empty unless the call succeeded.
    pub data: Vec<u8>,
    /// Error description; empty on success.
    pub error: String,
    /// Set only when the call was abandoned because it timed out.
    pub is_timeout: bool,
}

#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCMethodResponse {
    /// Python constructor; yields the same empty response as `default()`.
    #[new]
    fn py_new() -> Self {
        Self::default()
    }
}

impl DMSCMethodResponse {
    /// Creates an empty, unsuccessful response with no data and no error text.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a successful response carrying `data`.
    pub fn success_data(data: Vec<u8>) -> Self {
        Self {
            success: true,
            data,
            ..Self::default()
        }
    }

    /// Creates a failed response carrying the error message `msg`.
    pub fn error_msg(msg: String) -> Self {
        Self {
            error: msg,
            ..Self::default()
        }
    }

    /// Creates a failed response flagged as a timeout.
    pub fn timeout() -> Self {
        Self {
            is_timeout: true,
            ..Self::error_msg("Method call timed out".to_string())
        }
    }

    /// Returns `true` if the call succeeded.
    pub fn is_success(&self) -> bool {
        self.success
    }
}
173
/// Type-erased synchronous method handler: takes the raw request bytes and
/// returns either the raw response bytes or an error. The `Send + Sync` bounds
/// let it be shared behind an `Arc` between tasks.
type DMSCMethodHandler = Arc<dyn Fn(Vec<u8>) -> DMSCResult<Vec<u8>> + Send + Sync>;
175
/// Asynchronous method handler, invoked by [`DMSCMethodRegistration::call`].
///
/// Implementations receive the raw request bytes and always produce a
/// [`DMSCMethodResponse`]; failures are reported inside the response rather
/// than through a `Result`.
#[async_trait::async_trait]
pub trait DMSCMethodHandlerAsync: Send + Sync {
    /// Handles one invocation with the given request payload.
    async fn call(&self, params: Vec<u8>) -> DMSCMethodResponse;
}
180
181struct SyncMethodHandler {
182 handler: DMSCMethodHandler,
183}
184
185#[async_trait::async_trait]
186impl DMSCMethodHandlerAsync for SyncMethodHandler {
187 async fn call(&self, params: Vec<u8>) -> DMSCMethodResponse {
188 match (self.handler)(params) {
189 Ok(data) => DMSCMethodResponse::success_data(data),
190 Err(e) => DMSCMethodResponse::error_msg(e.to_string()),
191 }
192 }
193}
194
195#[derive(Clone)]
196pub struct DMSCMethodRegistration {
197 name: String,
198 handler: Arc<dyn DMSCMethodHandlerAsync>,
199}
200
201impl DMSCMethodRegistration {
202 pub fn new<S: Into<String>>(
203 name: S,
204 handler: Arc<dyn DMSCMethodHandlerAsync>,
205 ) -> Self {
206 Self {
207 name: name.into(),
208 handler,
209 }
210 }
211
212 pub fn name(&self) -> &str {
213 &self.name
214 }
215
216 pub async fn call(&self, params: Vec<u8>, timeout_ms: u64) -> DMSCMethodResponse {
217 if timeout_ms == 0 {
218 self.handler.call(params).await
219 } else {
220 match timeout(Duration::from_millis(timeout_ms), self.handler.call(params)).await {
221 Ok(response) => response,
222 Err(_) => DMSCMethodResponse::timeout(),
223 }
224 }
225 }
226}
227
228#[derive(Clone)]
229#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
230pub struct DMSCModuleEndpoint {
231 module_name: String,
232 methods: Arc<RwLock<HashMap<String, DMSCMethodRegistration>>>,
233}
234
235#[cfg(feature = "pyo3")]
236#[pyo3::prelude::pymethods]
237impl DMSCModuleEndpoint {
238 #[new]
239 fn py_new(module_name: String) -> Self {
240 Self::new(&module_name)
241 }
242
243 #[pyo3(name = "get_module_name")]
244 fn py_get_module_name(&self) -> String {
245 self.module_name.clone()
246 }
247
248 #[pyo3(name = "list_methods")]
249 fn py_list_methods(&self) -> Vec<String> {
250 let methods = self.methods.blocking_read();
251 methods.keys().cloned().collect()
252 }
253}
254
255impl DMSCModuleEndpoint {
256 pub fn new(module_name: &str) -> Self {
257 Self {
258 module_name: module_name.to_string(),
259 methods: Arc::new(RwLock::new(HashMap::new())),
260 }
261 }
262
263 pub fn module_name(&self) -> &str {
264 &self.module_name
265 }
266
267 pub fn register_method<H>(&self, name: &str, handler: H) -> &Self
268 where
269 H: Fn(Vec<u8>) -> DMSCResult<Vec<u8>> + Send + Sync + 'static,
270 {
271 let registration = DMSCMethodRegistration::new(
272 name,
273 Arc::new(SyncMethodHandler {
274 handler: Arc::new(handler),
275 }),
276 );
277 let mut methods = self.methods.blocking_write();
278 methods.insert(name.to_string(), registration);
279 self
280 }
281
282 pub async fn register_method_async<H>(&self, name: &str, handler: H) -> &Self
283 where
284 H: Fn(Vec<u8>) -> DMSCResult<Vec<u8>> + Send + Sync + 'static,
285 {
286 self.register_method(name, handler)
287 }
288
289 pub async fn get_method(&self, name: &str) -> Option<DMSCMethodRegistration> {
290 let methods = self.methods.read().await;
291 methods.get(name).cloned()
292 }
293
294 pub async fn list_methods(&self) -> Vec<String> {
295 let methods = self.methods.read().await;
296 methods.keys().cloned().collect()
297 }
298}
299
300#[derive(Clone)]
301#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
302pub struct DMSCModuleRPC {
303 endpoints: Arc<RwLock<HashMap<String, Arc<DMSCModuleEndpoint>>>>,
304 default_timeout: Duration,
305}
306
307impl DMSCModuleRPC {
308 pub fn new() -> Self {
309 Self {
310 endpoints: Arc::new(RwLock::new(HashMap::new())),
311 default_timeout: Duration::from_millis(5000),
312 }
313 }
314
315 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
316 self.default_timeout = timeout;
317 self
318 }
319
320 pub async fn register_endpoint(&self, endpoint: DMSCModuleEndpoint) {
321 let mut endpoints = self.endpoints.write().await;
322 endpoints.insert(endpoint.module_name().to_string(), Arc::new(endpoint));
323 }
324
325 pub async fn unregister_endpoint(&self, module_name: &str) {
326 let mut endpoints = self.endpoints.write().await;
327 endpoints.remove(module_name);
328 }
329
330 pub async fn get_endpoint(&self, module_name: &str) -> Option<Arc<DMSCModuleEndpoint>> {
331 let endpoints = self.endpoints.read().await;
332 endpoints.get(module_name).cloned()
333 }
334
335 pub async fn call_method(
336 &self,
337 module_name: &str,
338 method_name: &str,
339 params: Vec<u8>,
340 timeout_ms: Option<u64>,
341 ) -> DMSCMethodResponse {
342 let endpoint = self.get_endpoint(module_name).await;
343
344 if let Some(ep) = endpoint {
345 if let Some(method) = ep.get_method(method_name).await {
346 let timeout = timeout_ms.unwrap_or(self.default_timeout.as_millis() as u64);
347 return method.call(params, timeout).await;
348 }
349 return DMSCMethodResponse::error_msg(format!(
350 "Method '{}' not found on module '{}'",
351 method_name, module_name
352 ));
353 }
354
355 DMSCMethodResponse::error_msg(format!(
356 "Module '{}' not found",
357 module_name
358 ))
359 }
360
361 pub async fn list_registered_modules(&self) -> Vec<String> {
362 let endpoints = self.endpoints.read().await;
363 endpoints.keys().cloned().collect()
364 }
365}
366
367impl Default for DMSCModuleRPC {
368 fn default() -> Self {
369 Self::new()
370 }
371}
372
373#[derive(Clone)]
374#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
375pub struct DMSCModuleClient {
376 rpc: Arc<DMSCModuleRPC>,
377}
378
379impl DMSCModuleClient {
380 pub fn new(rpc: Arc<DMSCModuleRPC>) -> Self {
381 Self { rpc }
382 }
383
384 pub async fn call(
385 &self,
386 module_name: &str,
387 method_name: &str,
388 params: Vec<u8>,
389 ) -> DMSCMethodResponse {
390 self.rpc.call_method(module_name, method_name, params, None).await
391 }
392
393 pub async fn call_with_timeout(
394 &self,
395 module_name: &str,
396 method_name: &str,
397 params: Vec<u8>,
398 timeout_ms: u64,
399 ) -> DMSCMethodResponse {
400 self.rpc
401 .call_method(module_name, method_name, params, Some(timeout_ms))
402 .await
403 }
404}
405
406impl fmt::Debug for DMSCModuleRPC {
407 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
408 f.debug_struct("DMSCModuleRPC")
409 .field("default_timeout", &self.default_timeout)
410 .finish()
411 }
412}