dmsc/grpc/
client.rs

//! Copyright 2025-2026 Wenze Wei. All Rights Reserved.
//!
//! This file is part of DMSC.
//! The DMSC project belongs to the Dunimd Team.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! you may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.

//! # gRPC Client Implementation

20use super::*;
21use std::time::Duration;
22use std::sync::atomic::{AtomicU64, Ordering};
23
24#[cfg(feature = "pyo3")]
25#[allow(unused_imports)]
26use pyo3::prelude::*;
27
28pub struct DMSCGrpcClient {
29    channel: Option<tonic::transport::Channel>,
30    endpoint: String,
31    timeout: Duration,
32    stats: Arc<RwLock<DMSCGrpcStats>>,
33    request_id: Arc<AtomicU64>,
34    connected: Arc<RwLock<bool>>,
35    retry_count: u32,
36    retry_delay: Duration,
37}
38
/// Python-facing wrapper around [`DMSCGrpcClient`] that exposes a blocking API.
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCGrpcClientPy {
    inner: DMSCGrpcClient,
}

/// Returns the process-wide Tokio runtime used by the Python bindings, creating it on
/// first use.
///
/// A tonic `Channel` spawns its background tasks (request buffer worker, connection
/// driver) onto the runtime that is current while it connects. If that runtime is
/// dropped, those tasks are cancelled and the channel stops working. Every binding
/// method must therefore reuse one long-lived runtime rather than creating and
/// dropping a new one per call.
///
/// # Errors
///
/// Raises `RuntimeError` if the runtime cannot be built.
#[cfg(feature = "pyo3")]
fn py_runtime() -> PyResult<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    let rt = tokio::runtime::Runtime::new()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    // If another thread initialised the cell first, `get_or_init` keeps that runtime
    // and drops this one (outside any async context, so dropping is safe).
    Ok(RUNTIME.get_or_init(|| rt))
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcClientPy {
    /// Creates a disconnected client for `endpoint`.
    #[new]
    fn new(endpoint: String) -> Self {
        Self {
            inner: DMSCGrpcClient::new(endpoint),
        }
    }

    /// Sets the connect/request timeout in seconds.
    ///
    /// The value is applied when connecting, so it takes effect on the next `connect()`.
    #[pyo3(signature = (timeout_secs=30))]
    fn with_timeout(&mut self, timeout_secs: u64) {
        self.inner.timeout = Duration::from_secs(timeout_secs);
    }

    /// Sets the number of retries for retryable failures and the delay between them.
    #[pyo3(signature = (count=3, delay_ms=100))]
    fn with_retry(&mut self, count: u32, delay_ms: u64) {
        self.inner.retry_count = count;
        self.inner.retry_delay = Duration::from_millis(delay_ms);
    }

    /// Returns a best-effort snapshot of the call statistics.
    fn get_stats(&self) -> DMSCGrpcStats {
        self.inner.get_stats()
    }

    /// Reports whether the client currently holds a channel.
    fn is_connected(&self) -> bool {
        self.inner.channel.is_some()
    }

    /// Returns the endpoint URI this client targets.
    fn get_endpoint(&self) -> String {
        self.inner.endpoint.clone()
    }

    /// Connects to the endpoint, blocking until the connection is established.
    ///
    /// Raises `RuntimeError` on failure.
    fn connect(&mut self) -> PyResult<()> {
        py_runtime()?
            .block_on(self.inner.connect())
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Drops the channel and marks the client disconnected.
    fn disconnect(&mut self) {
        // Best-effort, as in the original API (no error is surfaced). After a successful
        // `connect` the runtime is already initialised, so this lookup cannot fail then.
        if let Ok(rt) = py_runtime() {
            rt.block_on(self.inner.disconnect());
        }
    }

    /// Performs a blocking unary call to `/{service_name}/{method}` with `data` as payload.
    ///
    /// Raises `RuntimeError` if the client is not connected or the call fails.
    #[pyo3(signature = (service_name, method, data))]
    fn call(&mut self, service_name: String, method: String, data: Vec<u8>) -> PyResult<Vec<u8>> {
        py_runtime()?
            .block_on(self.inner.call(&service_name, &method, &data))
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }
}
105
106impl DMSCGrpcClient {
107    pub fn new(endpoint: String) -> Self {
108        Self {
109            channel: None,
110            endpoint,
111            timeout: Duration::from_secs(30),
112            stats: Arc::new(RwLock::new(DMSCGrpcStats::new())),
113            request_id: Arc::new(AtomicU64::new(0)),
114            connected: Arc::new(RwLock::new(false)),
115            retry_count: 3,
116            retry_delay: Duration::from_millis(100),
117        }
118    }
119
120    pub fn with_timeout(mut self, timeout: Duration) -> Self {
121        self.timeout = timeout;
122        self
123    }
124
125    pub fn with_retry(mut self, count: u32, delay: Duration) -> Self {
126        self.retry_count = count;
127        self.retry_delay = delay;
128        self
129    }
130
131    pub fn get_stats(&self) -> DMSCGrpcStats {
132        self.stats.try_read()
133            .map(|guard| guard.clone())
134            .unwrap_or_else(|_| DMSCGrpcStats::new())
135    }
136
137    pub async fn connect(&mut self) -> DMSCResult<()> {
138        let endpoint = tonic::transport::Endpoint::from_shared(self.endpoint.clone())
139            .map_err(|e| GrpcError::ConnectionFailed {
140                message: format!("Invalid endpoint: {}", e)
141            })?
142            .connect_timeout(self.timeout)
143            .timeout(self.timeout);
144
145        let channel = endpoint.connect()
146            .await
147            .map_err(|e| GrpcError::ConnectionFailed {
148                message: format!("Connection failed: {}", e)
149            })?;
150
151        self.channel = Some(channel);
152        *self.connected.write().await = true;
153
154        tracing::info!("gRPC client connected to {}", self.endpoint);
155        Ok(())
156    }
157
158    pub async fn is_connected(&self) -> bool {
159        *self.connected.read().await && self.channel.is_some()
160    }
161
162    fn generate_request_id(&self) -> u64 {
163        self.request_id.fetch_add(1, Ordering::SeqCst)
164    }
165
166    pub async fn call(&mut self, service_name: &str, method: &str, data: &[u8]) -> DMSCResult<Vec<u8>> {
167        let channel = match &self.channel {
168            Some(ch) => ch.clone(),
169            None => {
170                return Err(GrpcError::Client {
171                    message: "Not connected to gRPC server".to_string()
172                }.into());
173            }
174        };
175
176        if !*self.connected.read().await {
177            return Err(GrpcError::Client {
178                message: "gRPC client not connected".to_string()
179            }.into());
180        }
181
182        let request_id = self.generate_request_id();
183        let path = format!("/{}/{}", service_name, method);
184
185        tracing::debug!("gRPC call: {} (request_id={})", path, request_id);
186
187        let mut last_error: Option<DMSCError> = None;
188        for attempt in 0..=self.retry_count {
189            if attempt > 0 {
190                tokio::time::sleep(self.retry_delay).await;
191                tracing::warn!("Retrying gRPC call (attempt {}/{})", attempt, self.retry_count);
192            }
193
194            match Self::execute_unary_call(channel.clone(), &path, data).await {
195                Ok(response) => {
196                    let mut stats = self.stats.write().await;
197                    stats.record_request(data.len());
198                    stats.record_response(response.len());
199                    return Ok(response);
200                }
201                Err(e) => {
202                    last_error = Some(e.clone());
203                    let mut stats = self.stats.write().await;
204                    stats.record_error();
205                    
206                    if !Self::is_retryable_error(&e) {
207                        return Err(e);
208                    }
209                }
210            }
211        }
212
213        Err(last_error.unwrap_or_else(|| GrpcError::Client {
214            message: "Unknown error after retries".to_string()
215        }.into()))
216    }
217
218    async fn execute_unary_call(
219        channel: tonic::transport::Channel,
220        path: &str,
221        data: &[u8],
222    ) -> DMSCResult<Vec<u8>> {
223        use tonic::client::Grpc;
224        use tonic::codec::ProstCodec;
225        
226        let codec = ProstCodec::<Vec<u8>, Vec<u8>>::new();
227        let mut client = Grpc::new(channel);
228        
229        let request = tonic::Request::new(data.to_vec());
230        let path_and_query: http::uri::PathAndQuery = path.parse()
231            .map_err(|e| GrpcError::Client {
232                message: format!("Invalid path: {}", e)
233            })?;
234        
235        let response = client.unary(request, path_and_query, codec)
236            .await
237            .map_err(|e| GrpcError::Client {
238                message: format!("RPC call failed: {}", e)
239            })?;
240
241        Ok(response.into_inner())
242    }
243
244    fn is_retryable_error(error: &DMSCError) -> bool {
245        let error_str = error.to_string();
246        error_str.contains("UNAVAILABLE") ||
247        error_str.contains("DEADLINE_EXCEEDED") ||
248        error_str.contains("RESOURCE_EXHAUSTED")
249    }
250
251    pub async fn disconnect(&mut self) {
252        self.channel.take();
253        *self.connected.write().await = false;
254        tracing::info!("gRPC client disconnected from {}", self.endpoint);
255    }
256}
257
258impl Drop for DMSCGrpcClient {
259    fn drop(&mut self) {
260        if self.channel.is_some() {
261            if let Ok(rt) = tokio::runtime::Runtime::new() {
262                rt.block_on(async {
263                    self.disconnect().await;
264                });
265            }
266        }
267    }
268}
269
270impl Default for DMSCGrpcClient {
271    fn default() -> Self {
272        Self::new("http://127.0.0.1:50051".to_string())
273    }
274}
275
276impl Clone for DMSCGrpcClient {
277    fn clone(&self) -> Self {
278        Self {
279            channel: self.channel.clone(),
280            endpoint: self.endpoint.clone(),
281            timeout: self.timeout,
282            stats: self.stats.clone(),
283            request_id: self.request_id.clone(),
284            connected: self.connected.clone(),
285            retry_count: self.retry_count,
286            retry_delay: self.retry_delay,
287        }
288    }
289}