1use super::*;
21use tokio::sync::mpsc;
22use std::net::SocketAddr;
23use tokio::io::{AsyncReadExt, AsyncWriteExt};
24
25#[cfg(feature = "pyo3")]
26#[allow(unused_imports)]
27use pyo3::prelude::*;
28
29pub struct DMSCGrpcServer {
30 config: DMSCGrpcConfig,
31 stats: Arc<RwLock<DMSCGrpcStats>>,
32 registry: DMSCGrpcServiceRegistry,
33 shutdown_tx: Option<mpsc::Sender<()>>,
34 running: Arc<RwLock<bool>>,
35}
36
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCGrpcServerPy {
    inner: DMSCGrpcServer,
    /// Runtime the server's accept loop is spawned on.
    ///
    /// It must outlive `start()`: dropping a tokio `Runtime` shuts down every task spawned
    /// on it, so a runtime local to `start()` would kill the server as soon as `start()`
    /// returned. Created lazily on first use and dropped together with this object.
    runtime: Option<tokio::runtime::Runtime>,
}

#[cfg(feature = "pyo3")]
impl DMSCGrpcServerPy {
    /// Returns the runtime stored in `slot`, creating a new multi-threaded runtime on
    /// first use.
    ///
    /// # Errors
    ///
    /// Raises `RuntimeError` in Python if the runtime cannot be created.
    fn ensure_runtime(
        slot: &mut Option<tokio::runtime::Runtime>,
    ) -> PyResult<&tokio::runtime::Runtime> {
        if slot.is_none() {
            let runtime = tokio::runtime::Runtime::new()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            *slot = Some(runtime);
        }
        Ok(slot.as_ref().expect("runtime was initialised just above"))
    }
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcServerPy {
    /// Creates a server with the default configuration; no runtime is created yet.
    #[new]
    fn new() -> Self {
        Self {
            inner: DMSCGrpcServer::new(DMSCGrpcConfig::default()),
            runtime: None,
        }
    }

    /// Snapshot of the server statistics.
    fn get_stats(&self) -> DMSCGrpcStats {
        self.inner.get_stats()
    }

    /// Whether the server reports itself as running.
    fn is_running(&self) -> bool {
        self.inner.is_running_sync()
    }

    /// Names of the registered services.
    fn list_services(&self) -> Vec<String> {
        self.inner.list_services()
    }

    /// Starts the server on this object's runtime.
    ///
    /// The accept loop keeps running on that runtime after this call returns.
    /// Raises `RuntimeError` if the runtime cannot be created or `start` fails.
    fn start(&mut self) -> PyResult<()> {
        // Disjoint field borrows: `runtime` is borrowed from `self.runtime`,
        // while the future borrows `self.inner`.
        let runtime = Self::ensure_runtime(&mut self.runtime)?;
        runtime
            .block_on(self.inner.start())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Signals the server to stop, using the same runtime `start` used (or a new one if
    /// the server was never started).
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or `stop` fails.
    fn stop(&mut self) -> PyResult<()> {
        let runtime = Self::ensure_runtime(&mut self.runtime)?;
        runtime
            .block_on(self.inner.stop())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }
}
83
84impl DMSCGrpcServer {
85 pub fn new(config: DMSCGrpcConfig) -> Self {
86 Self {
87 config,
88 stats: Arc::new(RwLock::new(DMSCGrpcStats::new())),
89 registry: DMSCGrpcServiceRegistry::new(),
90 shutdown_tx: None,
91 running: Arc::new(RwLock::new(false)),
92 }
93 }
94
95 pub fn get_stats(&self) -> DMSCGrpcStats {
96 self.stats.try_read()
97 .map(|guard| guard.clone())
98 .unwrap_or_else(|_| DMSCGrpcStats::new())
99 }
100
101 pub fn is_running_sync(&self) -> bool {
102 self.stats.try_read()
103 .map(|guard| guard.active_connections > 0)
104 .unwrap_or(false)
105 }
106
107 pub fn list_services(&self) -> Vec<String> {
108 self.registry.list_services()
109 }
110
111 pub async fn start(&mut self) -> DMSCResult<()> {
112 let addr: SocketAddr = format!("{}:{}", self.config.addr, self.config.port).parse()
113 .map_err(|e| GrpcError::Server { message: format!("Invalid address: {}", e) })?;
114
115 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
116 self.shutdown_tx = Some(shutdown_tx);
117
118 *self.running.write().await = true;
119
120 let stats = self.stats.clone();
121 let registry = self.registry.clone();
122 let running = self.running.clone();
123 let max_concurrent = self.config.max_concurrent_requests as usize;
124
125 tokio::spawn(async move {
126 let _ = Self::run_server(addr, stats, registry, shutdown_rx, running, max_concurrent).await;
127 });
128
129 tracing::info!("gRPC server started on {}", addr);
130 Ok(())
131 }
132
133 async fn run_server(
134 addr: SocketAddr,
135 stats: Arc<RwLock<DMSCGrpcStats>>,
136 registry: DMSCGrpcServiceRegistry,
137 mut shutdown_rx: mpsc::Receiver<()>,
138 running: Arc<RwLock<bool>>,
139 max_concurrent: usize,
140 ) {
141 let listener = match tokio::net::TcpListener::bind(&addr).await {
142 Ok(l) => l,
143 Err(e) => {
144 tracing::error!("Failed to bind gRPC server to {}: {}", addr, e);
145 return;
146 }
147 };
148
149 tracing::info!("gRPC server listening on {}", addr);
150
151 let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrent));
152
153 loop {
154 tokio::select! {
155 _ = shutdown_rx.recv() => {
156 tracing::info!("gRPC server shutting down");
157 break;
158 }
159 result = listener.accept() => {
160 match result {
161 Ok((stream, peer_addr)) => {
162 let permit = semaphore.clone().acquire_owned().await;
163 if let Ok(permit) = permit {
164 stats.write().await.active_connections += 1;
165
166 let stats_clone = stats.clone();
167 let registry_clone = registry.clone();
168
169 tokio::spawn(async move {
170 Self::handle_connection(stream, peer_addr, stats_clone, registry_clone).await;
171 drop(permit);
172 });
173 }
174 }
175 Err(e) => {
176 tracing::error!("Failed to accept connection: {}", e);
177 }
178 }
179 }
180 }
181 }
182
183 *running.write().await = false;
184 }
185
186 async fn handle_connection(
187 mut stream: tokio::net::TcpStream,
188 peer_addr: SocketAddr,
189 stats: Arc<RwLock<DMSCGrpcStats>>,
190 registry: DMSCGrpcServiceRegistry,
191 ) {
192 tracing::debug!("gRPC client connected from {}", peer_addr);
193
194 let mut buffer = vec![0u8; 65536];
195
196 loop {
197 let n = match stream.read(&mut buffer).await {
198 Ok(0) => break,
199 Ok(n) => n,
200 Err(e) => {
201 tracing::debug!("Read error from {}: {}", peer_addr, e);
202 break;
203 }
204 };
205
206 let request_data = &buffer[..n];
207
208 if let Err(e) = Self::process_request(&mut stream, request_data, &stats, ®istry).await {
209 tracing::error!("Error processing request from {}: {}", peer_addr, e);
210 break;
211 }
212 }
213
214 let mut stats_guard = stats.write().await;
215 if stats_guard.active_connections > 0 {
216 stats_guard.active_connections -= 1;
217 }
218 }
219
220 async fn process_request(
221 stream: &mut tokio::net::TcpStream,
222 request_data: &[u8],
223 stats: &Arc<RwLock<DMSCGrpcStats>>,
224 registry: &DMSCGrpcServiceRegistry,
225 ) -> DMSCResult<()> {
226 let request_str = String::from_utf8_lossy(request_data);
227
228 let (service_name, method_name) = Self::parse_request_path(&request_str)?;
229
230 tracing::debug!("gRPC request: {}/{}", service_name, method_name);
231
232 let services = registry.services.read().await;
233 let service = services.get(&service_name).cloned();
234 drop(services);
235
236 match service {
237 Some(svc) => {
238 let body_start = Self::find_body_start(request_data);
239 let body_data = if body_start < request_data.len() {
240 &request_data[body_start..]
241 } else {
242 &request_data[0..0]
243 };
244
245 stats.write().await.record_request(body_data.len());
246
247 match svc.handle_request(&method_name, body_data).await {
248 Ok(response_data) => {
249 stats.write().await.record_response(response_data.len());
250
251 let grpc_response = Self::build_grpc_response(&response_data);
252 stream.write_all(&grpc_response).await
253 .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
254 }
255 Err(e) => {
256 stats.write().await.record_error();
257
258 let error_response = Self::build_grpc_error_response(&e.to_string());
259 stream.write_all(&error_response).await
260 .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
261 }
262 }
263 }
264 None => {
265 stats.write().await.record_error();
266
267 let error_response = Self::build_grpc_error_response(&format!("Service not found: {}", service_name));
268 stream.write_all(&error_response).await
269 .map_err(|e| GrpcError::Server { message: format!("Write error: {}", e) })?;
270 }
271 }
272
273 Ok(())
274 }
275
276 fn parse_request_path(request_str: &str) -> DMSCResult<(String, String)> {
277 for line in request_str.lines() {
278 if line.contains(":path") {
279 let parts: Vec<&str> = line.split_whitespace().collect();
280 if parts.len() >= 2 {
281 let full_path = parts[1].trim_start_matches('/');
282 let path_parts: Vec<&str> = full_path.splitn(2, '/').collect();
283 if path_parts.len() == 2 {
284 return Ok((path_parts[0].to_string(), path_parts[1].to_string()));
285 }
286 }
287 }
288 }
289
290 Err(GrpcError::Server { message: "Invalid request path".to_string() }.into())
291 }
292
293 fn find_body_start(buffer: &[u8]) -> usize {
294 let mut pos = 0;
295 while pos + 3 < buffer.len() {
296 if buffer[pos] == b'\r' && buffer[pos + 1] == b'\n' && buffer[pos + 2] == b'\r' && buffer[pos + 3] == b'\n' {
297 return pos + 4;
298 }
299 pos += 1;
300 }
301 buffer.len()
302 }
303
304 fn build_grpc_response(data: &[u8]) -> Vec<u8> {
305 let mut response = Vec::new();
306
307 let header = "HTTP/2.0 200 OK\r\ncontent-type: application/grpc\r\n\r\n";
308 response.extend_from_slice(header.as_bytes());
309
310 let len = data.len() as u32;
311 response.push(0u8);
312 response.extend_from_slice(&len.to_be_bytes()[1..4]);
313 response.extend_from_slice(data);
314
315 let trailers = "\r\ngrpc-status: 0\r\n\r\n";
316 response.extend_from_slice(trailers.as_bytes());
317
318 response
319 }
320
321 fn build_grpc_error_response(message: &str) -> Vec<u8> {
322 let mut response = Vec::new();
323
324 let header = "HTTP/2.0 200 OK\r\ncontent-type: application/grpc\r\n\r\n";
325 response.extend_from_slice(header.as_bytes());
326
327 let trailers = format!("\r\ngrpc-status: 2\r\ngrpc-message: {}\r\n\r\n", message);
328 response.extend_from_slice(trailers.as_bytes());
329
330 response
331 }
332
333 pub async fn stop(&mut self) -> DMSCResult<()> {
334 *self.running.write().await = false;
335
336 if let Some(tx) = self.shutdown_tx.take() {
337 tx.send(()).await.map_err(|e| GrpcError::Server {
338 message: format!("Shutdown error: {}", e)
339 })?;
340 }
341
342 tracing::info!("gRPC server stopped");
343 Ok(())
344 }
345
346 pub async fn is_running(&self) -> bool {
347 *self.running.read().await
348 }
349}
350
351impl Clone for DMSCGrpcServer {
352 fn clone(&self) -> Self {
353 Self {
354 config: self.config.clone(),
355 stats: self.stats.clone(),
356 registry: self.registry.clone(),
357 shutdown_tx: None,
358 running: self.running.clone(),
359 }
360 }
361}
362
363impl Default for DMSCGrpcServer {
364 fn default() -> Self {
365 Self::new(DMSCGrpcConfig::default())
366 }
367}