dmsc/grpc/
server.rs

1//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18//! # gRPC Server Implementation
19
20use super::*;
21use tokio::sync::mpsc;
22use std::net::SocketAddr;
23use tokio::io::{AsyncReadExt, AsyncWriteExt};
24
25#[cfg(feature = "pyo3")]
26#[allow(unused_imports)]
27use pyo3::prelude::*;
28
29pub struct DMSCGrpcServer {
30    config: DMSCGrpcConfig,
31    stats: Arc<RwLock<DMSCGrpcStats>>,
32    registry: DMSCGrpcServiceRegistry,
33    shutdown_tx: Option<mpsc::Sender<()>>,
34    running: Arc<RwLock<bool>>,
35}
36
/// Python-facing wrapper around [`DMSCGrpcServer`].
///
/// Every Python method delegates to the wrapped Rust server.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCGrpcServerPy {
    /// The wrapped Rust server.
    inner: DMSCGrpcServer,
}
42
/// Returns the process-wide Tokio runtime that drives servers started from Python.
///
/// `DMSCGrpcServer::start` spawns its accept loop onto the current runtime and then
/// returns, so that runtime must outlive the call. A runtime created inside `start`
/// and dropped on return would shut the spawned server task down immediately. Keeping
/// one runtime for the whole process also lets `stop` signal the very task that
/// `start` spawned.
#[cfg(feature = "pyo3")]
fn grpc_py_runtime() -> PyResult<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(runtime) = RUNTIME.get() {
        return Ok(runtime);
    }
    let runtime = tokio::runtime::Runtime::new()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    // If another thread initialised the cell first, `set` hands our runtime back and it
    // is dropped here (outside any async context); either way the cell is now filled.
    let _ = RUNTIME.set(runtime);
    Ok(RUNTIME.get().expect("runtime cell was initialised just above"))
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcServerPy {
    /// Creates a server using `DMSCGrpcConfig::default()`.
    #[new]
    fn new() -> Self {
        Self {
            inner: DMSCGrpcServer::new(DMSCGrpcConfig::default()),
        }
    }

    /// Returns a snapshot of the server statistics.
    fn get_stats(&self) -> DMSCGrpcStats {
        self.inner.get_stats()
    }

    /// Reports whether the server is running (non-blocking).
    fn is_running(&self) -> bool {
        self.inner.is_running_sync()
    }

    /// Names of the registered services.
    fn list_services(&self) -> Vec<String> {
        self.inner.list_services()
    }

    /// Starts the server on the shared runtime; raises `RuntimeError` on failure.
    fn start(&mut self) -> PyResult<()> {
        let runtime = grpc_py_runtime()?;
        runtime
            .block_on(self.inner.start())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Stops the server on the shared runtime; raises `RuntimeError` on failure.
    fn stop(&mut self) -> PyResult<()> {
        let runtime = grpc_py_runtime()?;
        runtime
            .block_on(self.inner.stop())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }
}
83
84impl DMSCGrpcServer {
85    pub fn new(config: DMSCGrpcConfig) -> Self {
86        Self {
87            config,
88            stats: Arc::new(RwLock::new(DMSCGrpcStats::new())),
89            registry: DMSCGrpcServiceRegistry::new(),
90            shutdown_tx: None,
91            running: Arc::new(RwLock::new(false)),
92        }
93    }
94
95    pub fn get_stats(&self) -> DMSCGrpcStats {
96        self.stats.try_read()
97            .map(|guard| guard.clone())
98            .unwrap_or_else(|_| DMSCGrpcStats::new())
99    }
100
101    pub fn is_running_sync(&self) -> bool {
102        self.stats.try_read()
103            .map(|guard| guard.active_connections > 0)
104            .unwrap_or(false)
105    }
106
107    pub fn list_services(&self) -> Vec<String> {
108        self.registry.list_services()
109    }
110
111    pub async fn start(&mut self) -> DMSCResult<()> {
112        let addr: SocketAddr = format!("{}:{}", self.config.addr, self.config.port).parse()
113            .map_err(|e| GrpcError::Server { message: format!("Invalid address: {}", e) })?;
114
115        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
116        self.shutdown_tx = Some(shutdown_tx);
117
118        *self.running.write().await = true;
119
120        let stats = self.stats.clone();
121        let registry = self.registry.clone();
122        let running = self.running.clone();
123        let max_concurrent = self.config.max_concurrent_requests as usize;
124
125        tokio::spawn(async move {
126            let _ = Self::run_server(addr, stats, registry, shutdown_rx, running, max_concurrent).await;
127        });
128
129        tracing::info!("gRPC server started on {}", addr);
130        Ok(())
131    }
132
133    async fn run_server(
134        addr: SocketAddr,
135        stats: Arc<RwLock<DMSCGrpcStats>>,
136        registry: DMSCGrpcServiceRegistry,
137        mut shutdown_rx: mpsc::Receiver<()>,
138        running: Arc<RwLock<bool>>,
139        max_concurrent: usize,
140    ) {
141        let listener = match tokio::net::TcpListener::bind(&addr).await {
142            Ok(l) => l,
143            Err(e) => {
144                tracing::error!("Failed to bind gRPC server to {}: {}", addr, e);
145                return;
146            }
147        };
148        
149        tracing::info!("gRPC server listening on {}", addr);
150
151        let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
152
153        loop {
154            tokio::select! {
155                _ = shutdown_rx.recv() => {
156                    tracing::info!("gRPC server shutting down");
157                    break;
158                }
159                result = listener.accept() => {
160                    match result {
161                        Ok((stream, peer_addr)) => {
162                            let permit = semaphore.clone().acquire_owned().await;
163                            if let Ok(permit) = permit {
164                                stats.write().await.active_connections += 1;
165                                
166                                let stats_clone = stats.clone();
167                                let registry_clone = registry.clone();
168                                
169                                tokio::spawn(async move {
170                                    Self::handle_connection(stream, peer_addr, stats_clone, registry_clone).await;
171                                    drop(permit);
172                                });
173                            }
174                        }
175                        Err(e) => {
176                            tracing::error!("Failed to accept connection: {}", e);
177                        }
178                    }
179                }
180            }
181        }
182
183        *running.write().await = false;
184    }
185
186    async fn handle_connection(
187        mut stream: tokio::net::TcpStream,
188        peer_addr: SocketAddr,
189        stats: Arc<RwLock<DMSCGrpcStats>>,
190        registry: DMSCGrpcServiceRegistry,
191    ) {
192        tracing::debug!("gRPC client connected from {}", peer_addr);
193
194        let mut buffer = vec![0u8; 65536];
195        
196        loop {
197            let n = match stream.read(&mut buffer).await {
198                Ok(0) => break,
199                Ok(n) => n,
200                Err(e) => {
201                    tracing::debug!("Read error from {}: {}", peer_addr, e);
202                    break;
203                }
204            };
205
206            let request_data = &buffer[..n];
207            
208            if let Err(e) = Self::process_request(&mut stream, request_data, &stats, &registry).await {
209                tracing::error!("Error processing request from {}: {}", peer_addr, e);
210                break;
211            }
212        }
213
214        let mut stats_guard = stats.write().await;
215        if stats_guard.active_connections > 0 {
216            stats_guard.active_connections -= 1;
217        }
218    }
219
220    async fn process_request(
221        stream: &mut tokio::net::TcpStream,
222        request_data: &[u8],
223        stats: &Arc<RwLock<DMSCGrpcStats>>,
224        registry: &DMSCGrpcServiceRegistry,
225    ) -> DMSCResult<()> {
226        let request_str = String::from_utf8_lossy(request_data);
227        
228        let (service_name, method_name) = Self::parse_request_path(&request_str)?;
229        
230        tracing::debug!("gRPC request: {}/{}", service_name, method_name);
231
232        let services = registry.services.read().await;
233        let service = services.get(&service_name).cloned();
234        drop(services);
235
236        match service {
237            Some(svc) => {
238                let body_start = Self::find_body_start(request_data);
239                let body_data = if body_start < request_data.len() {
240                    &request_data[body_start..]
241                } else {
242                    &request_data[0..0]
243                };
244
245                stats.write().await.record_request(body_data.len());
246
247                match svc.handle_request(&method_name, body_data).await {
248                    Ok(response_data) => {
249                        stats.write().await.record_response(response_data.len());
250                        
251                        let grpc_response = Self::build_grpc_response(&response_data);
252                        stream.write_all(&grpc_response).await
253                            .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
254                    }
255                    Err(e) => {
256                        stats.write().await.record_error();
257                        
258                        let error_response = Self::build_grpc_error_response(&e.to_string());
259                        stream.write_all(&error_response).await
260                            .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
261                    }
262                }
263            }
264            None => {
265                stats.write().await.record_error();
266                
267                let error_response = Self::build_grpc_error_response(&format!("Service not found: {}", service_name));
268                stream.write_all(&error_response).await
269                    .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
270            }
271        }
272
273        Ok(())
274    }
275
276    fn parse_request_path(request_str: &str) -> DMSCResult<(String, String)> {
277        for line in request_str.lines() {
278            if line.contains(":path") {
279                let parts: Vec<&str> = line.split_whitespace().collect();
280                if parts.len() >= 2 {
281                    let full_path = parts[1].trim_start_matches('/');
282                    let path_parts: Vec<&str> = full_path.splitn(2, '/').collect();
283                    if path_parts.len() == 2 {
284                        return Ok((path_parts[0].to_string(), path_parts[1].to_string()));
285                    }
286                }
287            }
288        }
289        
290        Err(GrpcError::Server { message: "Invalid request path".to_string() }.into())
291    }
292
293    fn find_body_start(buffer: &[u8]) -> usize {
294        let mut pos = 0;
295        while pos + 3 < buffer.len() {
296            if buffer[pos] == b'\r' && buffer[pos + 1] == b'\n' && buffer[pos + 2] == b'\r' && buffer[pos + 3] == b'\n' {
297                return pos + 4;
298            }
299            pos += 1;
300        }
301        buffer.len()
302    }
303
304    fn build_grpc_response(data: &[u8]) -> Vec<u8> {
305        let mut response = Vec::new();
306        
307        let header = "HTTP/2.0 200 OK\r\ncontent-type: application/grpc\r\n\r\n";
308        response.extend_from_slice(header.as_bytes());
309        
310        let len = data.len() as u32;
311        response.push(0u8);
312        response.extend_from_slice(&len.to_be_bytes()[1..4]);
313        response.extend_from_slice(data);
314        
315        let trailers = "\r\ngrpc-status: 0\r\n\r\n";
316        response.extend_from_slice(trailers.as_bytes());
317        
318        response
319    }
320
321    fn build_grpc_error_response(message: &str) -> Vec<u8> {
322        let mut response = Vec::new();
323        
324        let header = "HTTP/2.0 200 OK\r\ncontent-type: application/grpc\r\n\r\n";
325        response.extend_from_slice(header.as_bytes());
326        
327        let trailers = format!("\r\ngrpc-status: 2\r\ngrpc-message: {}\r\n\r\n", message);
328        response.extend_from_slice(trailers.as_bytes());
329        
330        response
331    }
332
333    pub async fn stop(&mut self) -> DMSCResult<()> {
334        *self.running.write().await = false;
335
336        if let Some(tx) = self.shutdown_tx.take() {
337            tx.send(()).await.map_err(|e| GrpcError::Server {
338                message: format!("Shutdown error: {}", e)
339            })?;
340        }
341
342        tracing::info!("gRPC server stopped");
343        Ok(())
344    }
345
346    pub async fn is_running(&self) -> bool {
347        *self.running.read().await
348    }
349}
350
351impl Clone for DMSCGrpcServer {
352    fn clone(&self) -> Self {
353        Self {
354            config: self.config.clone(),
355            stats: self.stats.clone(),
356            registry: self.registry.clone(),
357            shutdown_tx: None,
358            running: self.running.clone(),
359        }
360    }
361}
362
363impl Default for DMSCGrpcServer {
364    fn default() -> Self {
365        Self::new(DMSCGrpcConfig::default())
366    }
367}