dmsc/grpc/
mod.rs

1//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18//! # gRPC Support
19
20use crate::core::DMSCResult;
21use async_trait::async_trait;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24#[cfg(feature = "grpc")]
25use std::collections::HashMap;
26
27#[cfg(feature = "pyo3")]
28use pyo3::prelude::*;
29
30#[cfg(feature = "grpc")]
31mod server;
32#[cfg(feature = "grpc")]
33mod client;
34
35#[cfg(feature = "grpc")]
36pub use server::DMSCGrpcServer;
37#[cfg(feature = "grpc")]
38pub use client::DMSCGrpcClient;
39
40#[cfg(all(feature = "grpc", feature = "pyo3"))]
41pub use server::DMSCGrpcServerPy;
42#[cfg(all(feature = "grpc", feature = "pyo3"))]
43pub use client::DMSCGrpcClientPy;
44
/// Settings describing a gRPC endpoint.
///
/// When the `pyo3` feature is enabled the struct is also exported as a Python
/// class whose properties mirror the public fields below.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCGrpcConfig {
    /// Address string of the endpoint (`"127.0.0.1"` by default).
    pub addr: String,
    /// Port number of the endpoint (`50051` by default).
    pub port: u16,
    /// Configured request concurrency limit (`100` by default).
    // NOTE(review): how the limit is enforced is decided by the server/client
    // modules, which are not visible from this file.
    pub max_concurrent_requests: u32,
    /// TLS flag (`false` by default).
    pub enable_tls: bool,
    /// Optional certificate path; presumably only meaningful with `enable_tls`.
    pub cert_path: Option<String>,
    /// Optional key path; presumably only meaningful with `enable_tls`.
    pub key_path: Option<String>,
}

/// Python-facing constructor and property accessors.
///
/// Every getter hands Python an owned copy of the field; every setter replaces
/// the field wholesale with the supplied value.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcConfig {
    /// Builds a configuration populated with the `Default` values.
    #[new]
    fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy of the address string.
    #[getter]
    fn get_addr(&self) -> String {
        self.addr.to_owned()
    }

    /// Replaces the address string.
    #[setter]
    fn set_addr(&mut self, addr: String) {
        self.addr = addr;
    }

    /// Returns the port number.
    #[getter]
    fn get_port(&self) -> u16 {
        self.port
    }

    /// Replaces the port number.
    #[setter]
    fn set_port(&mut self, port: u16) {
        self.port = port;
    }

    /// Returns the configured concurrency limit.
    #[getter]
    fn get_max_concurrent_requests(&self) -> u32 {
        self.max_concurrent_requests
    }

    /// Replaces the configured concurrency limit.
    #[setter]
    fn set_max_concurrent_requests(&mut self, max_concurrent_requests: u32) {
        self.max_concurrent_requests = max_concurrent_requests;
    }

    /// Returns the TLS flag.
    #[getter]
    fn get_enable_tls(&self) -> bool {
        self.enable_tls
    }

    /// Replaces the TLS flag.
    #[setter]
    fn set_enable_tls(&mut self, enable_tls: bool) {
        self.enable_tls = enable_tls;
    }

    /// Returns a copy of the certificate path, if one is set.
    #[getter]
    fn get_cert_path(&self) -> Option<String> {
        self.cert_path.as_deref().map(str::to_owned)
    }

    /// Replaces the certificate path (`None` clears it).
    #[setter]
    fn set_cert_path(&mut self, cert_path: Option<String>) {
        self.cert_path = cert_path;
    }

    /// Returns a copy of the key path, if one is set.
    #[getter]
    fn get_key_path(&self) -> Option<String> {
        self.key_path.as_deref().map(str::to_owned)
    }

    /// Replaces the key path (`None` clears it).
    #[setter]
    fn set_key_path(&mut self, key_path: Option<String>) {
        self.key_path = key_path;
    }
}

impl Default for DMSCGrpcConfig {
    /// Loopback address, port 50051, a limit of 100 requests, TLS off and no
    /// certificate or key path.
    fn default() -> Self {
        let addr = String::from("127.0.0.1");
        let port = 50051;
        let max_concurrent_requests = 100;

        Self {
            addr,
            port,
            max_concurrent_requests,
            enable_tls: false,
            cert_path: None,
            key_path: None,
        }
    }
}
137
/// A request handler that can be stored in a [`DMSCGrpcServiceRegistry`].
///
/// Implementors must be `Send + Sync` because they are shared behind
/// `Arc<dyn DMSCGrpcService>`.
#[cfg(feature = "grpc")]
#[async_trait]
pub trait DMSCGrpcService: Send + Sync {
    /// Name under which the service is registered.
    ///
    /// The registry uses this value as the map key, so two services reporting
    /// the same name replace one another.
    fn service_name(&self) -> &'static str;

    /// Handles one request.
    ///
    /// # Arguments
    /// * `method` - name of the method being invoked.
    /// * `data` - raw request payload bytes.
    ///
    /// # Returns
    /// The raw response payload bytes.
    ///
    /// # Errors
    /// Whatever `DMSCResult` error the implementation reports.
    async fn handle_request(&self, method: &str, data: &[u8]) -> DMSCResult<Vec<u8>>;
}
144
/// Shared, name-keyed collection of gRPC services.
///
/// Cloning the registry clones the `Arc`, so every clone observes and mutates
/// the same underlying map.
#[cfg(feature = "grpc")]
#[derive(Clone)]
pub struct DMSCGrpcServiceRegistry {
    /// Service name -> service handle, guarded by a tokio `RwLock`.
    pub services: Arc<RwLock<HashMap<String, Arc<dyn DMSCGrpcService>>>>,
}

#[cfg(feature = "grpc")]
impl DMSCGrpcServiceRegistry {
    /// Creates a registry with no services.
    pub fn new() -> Self {
        let empty = HashMap::new();
        Self {
            services: Arc::new(RwLock::new(empty)),
        }
    }

    /// Stores `service` under the name it reports via `service_name()`.
    ///
    /// A service already registered under that name is replaced.
    ///
    /// # Panics
    /// Acquires the lock with `blocking_write`, which tokio documents as
    /// panicking when called from within an asynchronous execution context.
    pub fn register_service(&mut self, service: Arc<dyn DMSCGrpcService>) {
        // Resolve the key before taking the lock, as the lock is only needed
        // for the insertion itself.
        let key = String::from(service.service_name());
        self.services.blocking_write().insert(key, service);
    }

    /// Returns the names of all registered services.
    ///
    /// The order follows `HashMap` iteration order and is therefore
    /// unspecified.
    ///
    /// # Panics
    /// Acquires the lock with `blocking_read`, which tokio documents as
    /// panicking when called from within an asynchronous execution context.
    pub fn list_services(&self) -> Vec<String> {
        let guard = self.services.blocking_read();
        let mut names = Vec::with_capacity(guard.len());
        names.extend(guard.keys().cloned());
        names
    }
}

#[cfg(feature = "grpc")]
impl Default for DMSCGrpcServiceRegistry {
    /// Same as [`DMSCGrpcServiceRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
177
/// Python-visible wrapper around a [`DMSCGrpcServiceRegistry`].
#[cfg(all(feature = "grpc", feature = "pyo3"))]
#[pyclass]
pub struct DMSCGrpcServiceRegistryPy {
    /// The wrapped Rust registry that actually stores the services.
    registry: DMSCGrpcServiceRegistry,
}

#[cfg(all(feature = "grpc", feature = "pyo3"))]
#[pymethods]
impl DMSCGrpcServiceRegistryPy {
    /// Creates a wrapper around a fresh, empty registry.
    #[new]
    fn new() -> Self {
        let registry = DMSCGrpcServiceRegistry::new();
        Self { registry }
    }

    /// Wraps the Python object `handler` in a `DMSCGrpcPythonService` named
    /// `service_name` and registers it.
    fn register(&mut self, service_name: &str, handler: Py<PyAny>) {
        let service: Arc<dyn DMSCGrpcService> =
            Arc::new(DMSCGrpcPythonService::new(service_name, handler));
        self.registry.register_service(service);
    }

    /// Returns the names of all registered services.
    fn list_services(&self) -> Vec<String> {
        let registry = &self.registry;
        registry.list_services()
    }
}

#[cfg(all(feature = "grpc", feature = "pyo3"))]
impl Default for DMSCGrpcServiceRegistryPy {
    /// Same as the Python-facing constructor.
    fn default() -> Self {
        Self::new()
    }
}
210
/// Running counters for gRPC traffic.
///
/// All counters saturate at `u64::MAX`: a statistics update never panics on
/// overflow (debug builds) and never wraps back to a small value (release
/// builds).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct DMSCGrpcStats {
    /// Number of calls to `record_request`.
    pub requests_received: u64,
    /// Number of calls to `record_response`.
    pub requests_completed: u64,
    /// Number of calls to `record_error`.
    pub requests_failed: u64,
    /// Sum of the sizes passed to `record_request`.
    pub bytes_received: u64,
    /// Sum of the sizes passed to `record_response`.
    pub bytes_sent: u64,
    /// Requests recorded but not yet answered or failed; never drops below 0.
    pub active_connections: u64,
}

/// Read-only Python properties, one per counter.
#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcStats {
    /// Returns `requests_received`.
    #[getter]
    fn get_requests_received(&self) -> u64 {
        self.requests_received
    }

    /// Returns `requests_completed`.
    #[getter]
    fn get_requests_completed(&self) -> u64 {
        self.requests_completed
    }

    /// Returns `requests_failed`.
    #[getter]
    fn get_requests_failed(&self) -> u64 {
        self.requests_failed
    }

    /// Returns `bytes_received`.
    #[getter]
    fn get_bytes_received(&self) -> u64 {
        self.bytes_received
    }

    /// Returns `bytes_sent`.
    #[getter]
    fn get_bytes_sent(&self) -> u64 {
        self.bytes_sent
    }

    /// Returns `active_connections`.
    #[getter]
    fn get_active_connections(&self) -> u64 {
        self.active_connections
    }
}

impl DMSCGrpcStats {
    /// Creates a statistics record with every counter at zero.
    pub fn new() -> Self {
        Self {
            requests_received: 0,
            requests_completed: 0,
            requests_failed: 0,
            bytes_received: 0,
            bytes_sent: 0,
            active_connections: 0,
        }
    }

    /// Records an incoming request carrying `size` payload bytes.
    ///
    /// Bumps `requests_received` and `active_connections` by one and adds
    /// `size` to `bytes_received`.
    pub fn record_request(&mut self, size: usize) {
        self.requests_received = self.requests_received.saturating_add(1);
        // `usize` is at most 64 bits wide on supported targets, so the cast is
        // lossless.
        self.bytes_received = self.bytes_received.saturating_add(size as u64);
        self.active_connections = self.active_connections.saturating_add(1);
    }

    /// Records a successful response carrying `size` payload bytes.
    ///
    /// Bumps `requests_completed`, adds `size` to `bytes_sent` and releases
    /// one active connection (staying at zero if none is outstanding).
    pub fn record_response(&mut self, size: usize) {
        self.requests_completed = self.requests_completed.saturating_add(1);
        self.bytes_sent = self.bytes_sent.saturating_add(size as u64);
        self.active_connections = self.active_connections.saturating_sub(1);
    }

    /// Records a failed request.
    ///
    /// Bumps `requests_failed` and releases one active connection (staying at
    /// zero if none is outstanding).
    pub fn record_error(&mut self) {
        self.requests_failed = self.requests_failed.saturating_add(1);
        self.active_connections = self.active_connections.saturating_sub(1);
    }
}

impl Default for DMSCGrpcStats {
    /// Same as [`DMSCGrpcStats::new`]: all counters start at zero.
    fn default() -> Self {
        Self::new()
    }
}
295
/// A [`DMSCGrpcService`] whose requests are answered by a Python callable.
#[cfg(all(feature = "grpc", feature = "pyo3"))]
#[pyclass]
pub struct DMSCGrpcPythonService {
    /// Service name, leaked exactly once in `new` so that `service_name()`
    /// can satisfy the trait's `&'static str` return type without allocating
    /// (and leaking) again on every call.
    service_name: &'static str,
    /// Python object invoked as `handler(method, data)` for each request.
    handler: Py<PyAny>,
}

#[cfg(all(feature = "grpc", feature = "pyo3"))]
impl DMSCGrpcPythonService {
    /// Creates a service called `service_name` that forwards requests to
    /// `handler`.
    ///
    /// The name is copied into a heap allocation that is intentionally never
    /// freed; the cost is one small allocation per constructed service.
    pub fn new(service_name: &str, handler: Py<PyAny>) -> Self {
        let leaked: &'static str = Box::leak(service_name.to_owned().into_boxed_str());
        Self {
            service_name: leaked,
            handler,
        }
    }
}

#[cfg(all(feature = "grpc", feature = "pyo3"))]
#[async_trait]
impl DMSCGrpcService for DMSCGrpcPythonService {
    /// Calls the Python handler with `(method, data)` and returns the bytes it
    /// produces.
    ///
    /// The handler runs synchronously inside this function; nothing is
    /// awaited while Python code executes.
    ///
    /// # Errors
    /// * `DMSCError::Other("Python handler error: …")` when the call raises.
    /// * `DMSCError::Other("Failed to extract response bytes: …")` when the
    ///   returned object cannot be extracted as `Vec<u8>`.
    async fn handle_request(&self, method: &str, data: &[u8]) -> DMSCResult<Vec<u8>> {
        let method_str = method.to_string();
        let data_vec = data.to_vec();

        // One interpreter attach covers both the call and the extraction, so
        // the returned Python object is also dropped while attached.
        pyo3::Python::attach(|py| -> DMSCResult<Vec<u8>> {
            let response = self
                .handler
                .call1(py, (method_str, data_vec))
                .map_err(|e| DMSCError::Other(format!("Python handler error: {:?}", e)))?;

            response.extract::<Vec<u8>>(py).map_err(|e| {
                DMSCError::Other(format!("Failed to extract response bytes: {:?}", e))
            })
        })
    }

    /// Returns the name supplied at construction.
    fn service_name(&self) -> &'static str {
        self.service_name
    }
}
342
343#[derive(Debug, thiserror::Error)]
344pub enum GrpcError {
345    #[error("Server error: {message}")]
346    Server { message: String },
347    #[error("Client error: {message}")]
348    Client { message: String },
349    #[error("Service not found: {service_name}")]
350    ServiceNotFound { service_name: String },
351    #[error("Connection failed: {message}")]
352    ConnectionFailed { message: String },
353    #[error("Request timeout")]
354    Timeout,
355}
356
357impl From<GrpcError> for DMSCError {
358    fn from(error: GrpcError) -> Self {
359        DMSCError::Other(format!("gRPC error: {}", error))
360    }
361}
362
363use crate::core::DMSCError;