1use super::*;
21use std::time::Duration;
22use std::sync::atomic::{AtomicU64, Ordering};
23
24#[cfg(feature = "pyo3")]
25#[allow(unused_imports)]
26use pyo3::prelude::*;
27
28pub struct DMSCGrpcClient {
29 channel: Option<tonic::transport::Channel>,
30 endpoint: String,
31 timeout: Duration,
32 stats: Arc<RwLock<DMSCGrpcStats>>,
33 request_id: Arc<AtomicU64>,
34 connected: Arc<RwLock<bool>>,
35 retry_count: u32,
36 retry_delay: Duration,
37}
38
#[cfg(feature = "pyo3")]
#[pyclass]
pub struct DMSCGrpcClientPy {
    inner: DMSCGrpcClient,
    // A single runtime for the whole lifetime of the wrapper. tonic's `Channel` spawns
    // the tasks that drive the connection on the runtime that is current during
    // `connect`; a runtime created per method call would cancel those tasks when it is
    // dropped, leaving the channel unusable for the next `call`.
    // Declared after `inner` so the client (and its channel) is dropped first.
    runtime: tokio::runtime::Runtime,
}

/// Converts any displayable error into a Python `RuntimeError`.
#[cfg(feature = "pyo3")]
fn py_runtime_error<E: std::fmt::Display>(err: E) -> PyErr {
    pyo3::exceptions::PyRuntimeError::new_err(err.to_string())
}

#[cfg(feature = "pyo3")]
#[pymethods]
impl DMSCGrpcClientPy {
    /// Creates a disconnected client for `endpoint`.
    ///
    /// Raises `RuntimeError` if the Tokio runtime backing the client cannot be created.
    #[new]
    fn new(endpoint: String) -> PyResult<Self> {
        let runtime = tokio::runtime::Runtime::new().map_err(py_runtime_error)?;
        Ok(Self {
            inner: DMSCGrpcClient::new(endpoint),
            runtime,
        })
    }

    /// Sets the connect/request timeout in seconds; used by the next `connect`.
    #[pyo3(signature = (timeout_secs=30))]
    fn with_timeout(&mut self, timeout_secs: u64) {
        self.inner.timeout = Duration::from_secs(timeout_secs);
    }

    /// Sets the number of retries and the delay between them in milliseconds.
    #[pyo3(signature = (count=3, delay_ms=100))]
    fn with_retry(&mut self, count: u32, delay_ms: u64) {
        self.inner.retry_count = count;
        self.inner.retry_delay = Duration::from_millis(delay_ms);
    }

    /// Returns a snapshot of the call statistics.
    fn get_stats(&self) -> DMSCGrpcStats {
        self.inner.get_stats()
    }

    /// Returns `True` while a channel is held.
    fn is_connected(&self) -> bool {
        self.inner.channel.is_some()
    }

    /// Returns the configured endpoint URI.
    fn get_endpoint(&self) -> String {
        self.inner.endpoint.clone()
    }

    /// Connects to the endpoint; raises `RuntimeError` on failure.
    fn connect(&mut self) -> PyResult<()> {
        self.runtime
            .block_on(self.inner.connect())
            .map_err(py_runtime_error)
    }

    /// Drops the channel and marks the client disconnected.
    fn disconnect(&mut self) {
        self.runtime.block_on(self.inner.disconnect());
    }

    /// Performs a unary call to `/{service_name}/{method}` with `data` as payload and
    /// returns the response bytes; raises `RuntimeError` on failure.
    #[pyo3(signature = (service_name, method, data))]
    fn call(&mut self, service_name: String, method: String, data: Vec<u8>) -> PyResult<Vec<u8>> {
        self.runtime
            .block_on(self.inner.call(&service_name, &method, &data))
            .map_err(py_runtime_error)
    }
}
105
106impl DMSCGrpcClient {
107 pub fn new(endpoint: String) -> Self {
108 Self {
109 channel: None,
110 endpoint,
111 timeout: Duration::from_secs(30),
112 stats: Arc::new(RwLock::new(DMSCGrpcStats::new())),
113 request_id: Arc::new(AtomicU64::new(0)),
114 connected: Arc::new(RwLock::new(false)),
115 retry_count: 3,
116 retry_delay: Duration::from_millis(100),
117 }
118 }
119
120 pub fn with_timeout(mut self, timeout: Duration) -> Self {
121 self.timeout = timeout;
122 self
123 }
124
125 pub fn with_retry(mut self, count: u32, delay: Duration) -> Self {
126 self.retry_count = count;
127 self.retry_delay = delay;
128 self
129 }
130
131 pub fn get_stats(&self) -> DMSCGrpcStats {
132 self.stats.try_read()
133 .map(|guard| guard.clone())
134 .unwrap_or_else(|_| DMSCGrpcStats::new())
135 }
136
137 pub async fn connect(&mut self) -> DMSCResult<()> {
138 let endpoint = tonic::transport::Endpoint::from_shared(self.endpoint.clone())
139 .map_err(|e| GrpcError::ConnectionFailed {
140 message: format!("Invalid endpoint: {}", e)
141 })?
142 .connect_timeout(self.timeout)
143 .timeout(self.timeout);
144
145 let channel = endpoint.connect()
146 .await
147 .map_err(|e| GrpcError::ConnectionFailed {
148 message: format!("Connection failed: {}", e)
149 })?;
150
151 self.channel = Some(channel);
152 *self.connected.write().await = true;
153
154 tracing::info!("gRPC client connected to {}", self.endpoint);
155 Ok(())
156 }
157
158 pub async fn is_connected(&self) -> bool {
159 *self.connected.read().await && self.channel.is_some()
160 }
161
162 fn generate_request_id(&self) -> u64 {
163 self.request_id.fetch_add(1, Ordering::SeqCst)
164 }
165
166 pub async fn call(&mut self, service_name: &str, method: &str, data: &[u8]) -> DMSCResult<Vec<u8>> {
167 let channel = match &self.channel {
168 Some(ch) => ch.clone(),
169 None => {
170 return Err(GrpcError::Client {
171 message: "Not connected to gRPC server".to_string()
172 }.into());
173 }
174 };
175
176 if !*self.connected.read().await {
177 return Err(GrpcError::Client {
178 message: "gRPC client not connected".to_string()
179 }.into());
180 }
181
182 let request_id = self.generate_request_id();
183 let path = format!("/{}/{}", service_name, method);
184
185 tracing::debug!("gRPC call: {} (request_id={})", path, request_id);
186
187 let mut last_error: Option<DMSCError> = None;
188 for attempt in 0..=self.retry_count {
189 if attempt > 0 {
190 tokio::time::sleep(self.retry_delay).await;
191 tracing::warn!("Retrying gRPC call (attempt {}/{})", attempt, self.retry_count);
192 }
193
194 match Self::execute_unary_call(channel.clone(), &path, data).await {
195 Ok(response) => {
196 let mut stats = self.stats.write().await;
197 stats.record_request(data.len());
198 stats.record_response(response.len());
199 return Ok(response);
200 }
201 Err(e) => {
202 last_error = Some(e.clone());
203 let mut stats = self.stats.write().await;
204 stats.record_error();
205
206 if !Self::is_retryable_error(&e) {
207 return Err(e);
208 }
209 }
210 }
211 }
212
213 Err(last_error.unwrap_or_else(|| GrpcError::Client {
214 message: "Unknown error after retries".to_string()
215 }.into()))
216 }
217
218 async fn execute_unary_call(
219 channel: tonic::transport::Channel,
220 path: &str,
221 data: &[u8],
222 ) -> DMSCResult<Vec<u8>> {
223 use tonic::client::Grpc;
224 use tonic::codec::ProstCodec;
225
226 let codec = ProstCodec::<Vec<u8>, Vec<u8>>::new();
227 let mut client = Grpc::new(channel);
228
229 let request = tonic::Request::new(data.to_vec());
230 let path_and_query: http::uri::PathAndQuery = path.parse()
231 .map_err(|e| GrpcError::Client {
232 message: format!("Invalid path: {}", e)
233 })?;
234
235 let response = client.unary(request, path_and_query, codec)
236 .await
237 .map_err(|e| GrpcError::Client {
238 message: format!("RPC call failed: {}", e)
239 })?;
240
241 Ok(response.into_inner())
242 }
243
244 fn is_retryable_error(error: &DMSCError) -> bool {
245 let error_str = error.to_string();
246 error_str.contains("UNAVAILABLE") ||
247 error_str.contains("DEADLINE_EXCEEDED") ||
248 error_str.contains("RESOURCE_EXHAUSTED")
249 }
250
251 pub async fn disconnect(&mut self) {
252 self.channel.take();
253 *self.connected.write().await = false;
254 tracing::info!("gRPC client disconnected from {}", self.endpoint);
255 }
256}
257
258impl Drop for DMSCGrpcClient {
259 fn drop(&mut self) {
260 if self.channel.is_some() {
261 if let Ok(rt) = tokio::runtime::Runtime::new() {
262 rt.block_on(async {
263 self.disconnect().await;
264 });
265 }
266 }
267 }
268}
269
270impl Default for DMSCGrpcClient {
271 fn default() -> Self {
272 Self::new("http://127.0.0.1:50051".to_string())
273 }
274}
275
276impl Clone for DMSCGrpcClient {
277 fn clone(&self) -> Self {
278 Self {
279 channel: self.channel.clone(),
280 endpoint: self.endpoint.clone(),
281 timeout: self.timeout,
282 stats: self.stats.clone(),
283 request_id: self.request_id.clone(),
284 connected: self.connected.clone(),
285 retry_count: self.retry_count,
286 retry_delay: self.retry_delay,
287 }
288 }
289}