1#![cfg(feature = "gateway")]
19#![allow(non_snake_case)]
20
21use crate::core::{DMSCResult, DMSCError};
22use crate::gateway::{DMSCGateway, DMSCGatewayConfig, DMSCGatewayRequest};
23use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
24use hyper::service::{make_service_fn, service_fn};
25use std::collections::HashMap;
26use std::convert::Infallible;
27use std::net::SocketAddr;
28use std::sync::Arc;
29use tokio::sync::RwLock;
30use tokio_rustls::rustls::ServerConfig;
31
32pub struct DMSCGatewayServer {
33 gateway: Arc<DMSCGateway>,
34 config: Arc<RwLock<DMSCGatewayConfig>>,
35 addr: SocketAddr,
36 tls_config: Option<ServerConfig>,
37 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
38}
39
40impl DMSCGatewayServer {
41 pub fn new(gateway: Arc<DMSCGateway>, config: Arc<RwLock<DMSCGatewayConfig>>, addr: SocketAddr) -> Self {
42 Self {
43 gateway,
44 config,
45 addr,
46 tls_config: None,
47 shutdown_tx: None,
48 }
49 }
50
51 pub fn with_tls(mut self, tls_config: ServerConfig) -> Self {
52 self.tls_config = Some(tls_config);
53 self
54 }
55
56 pub async fn serve(&mut self) -> DMSCResult<()> {
57 let addr = self.addr;
58 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
59 self.shutdown_tx = Some(shutdown_tx);
60
61 let gateway = self.gateway.clone();
62 let config = self.config.clone();
63
64 let service = make_service_fn(move |_conn| {
65 let gateway = gateway.clone();
66 let config = config.clone();
67 async move {
68 Ok::<_, Infallible>(service_fn(move |req: HyperRequest<Body>| {
69 Self::handle_request(req, gateway.clone(), config.clone())
70 }))
71 }
72 });
73
74 let server = Server::bind(&addr)
75 .http1_pipeline_flush(true)
76 .serve(service);
77
78 let graceful = server.with_graceful_shutdown(async {
79 shutdown_rx.await.ok();
80 });
81
82 graceful.await.map_err(|e| DMSCError::Other(format!("Server error: {}", e)))
83 }
84
85 async fn handle_request(
86 req: HyperRequest<Body>,
87 gateway: Arc<DMSCGateway>,
88 config: Arc<RwLock<DMSCGatewayConfig>>,
89 ) -> Result<HyperResponse<Body>, Infallible> {
90 let request_id = uuid::Uuid::new_v4().to_string();
91 let start = std::time::Instant::now();
92
93 let method = req.method().to_string();
94 let path = req.uri().path().to_string();
95 let remote_addr = req
96 .headers()
97 .get("X-Forwarded-For")
98 .and_then(|v| v.to_str().ok())
99 .map(|s| s.to_string())
100 .unwrap_or_else(|| {
101 req.extensions()
102 .get::<SocketAddr>()
103 .map(|a| a.to_string())
104 .unwrap_or_else(|| "unknown".to_string())
105 });
106
107 let mut headers = HashMap::new();
108 for (key, value) in req.headers() {
109 if let Ok(v) = value.to_str() {
110 headers.insert(key.as_str().to_string(), v.to_string());
111 }
112 }
113
114 let query_params = {
115 let uri = req.uri();
116 let query = uri.query().unwrap_or("");
117 let mut params = HashMap::new();
118 for pair in query.split('&') {
119 if let Some((key, value)) = pair.split_once('=') {
120 params.insert(
121 key.to_string(),
122 value.to_string(),
123 );
124 }
125 }
126 params
127 };
128
129 let body = match hyper::body::to_bytes(req.into_body()).await {
130 Ok(bytes) => {
131 if bytes.is_empty() {
132 None
133 } else {
134 Some(bytes.to_vec())
135 }
136 }
137 Err(_) => None,
138 };
139
140 let dmsc_request = DMSCGatewayRequest::new(
141 method.clone(),
142 path.clone(),
143 headers,
144 query_params,
145 body,
146 remote_addr.clone(),
147 );
148
149 let response = gateway.handle_request(dmsc_request).await;
150
151 let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
152
153 if config.read().await.enable_logging {
154 let log_level = &config.read().await.log_level;
155 match log_level.as_str() {
156 "debug" => {
157 log::debug!(
158 target: "DMSC.Gateway",
159 "{} {} {} {} {}ms",
160 method,
161 path,
162 response.status_code,
163 request_id,
164 duration_ms
165 );
166 }
167 "info" => {
168 log::info!(
169 target: "DMSC.Gateway",
170 "{} {} {} {}ms",
171 method,
172 path,
173 response.status_code,
174 duration_ms
175 );
176 }
177 "warn" => {
178 log::warn!(
179 target: "DMSC.Gateway",
180 "{} {} {} {}ms",
181 method,
182 path,
183 response.status_code,
184 duration_ms
185 );
186 }
187 "error" => {
188 log::error!(
189 target: "DMSC.Gateway",
190 "{} {} {} {}ms",
191 method,
192 path,
193 response.status_code,
194 duration_ms
195 );
196 }
197 _ => {}
198 }
199 }
200
201 let mut hyper_response = HyperResponse::builder()
202 .status(StatusCode::from_u16(response.status_code).unwrap_or(StatusCode::OK));
203
204 for (key, value) in response.headers {
205 if let (Ok(k), Ok(v)) = (key.parse::<hyper::header::HeaderName>(), value.parse::<hyper::header::HeaderValue>()) {
206 hyper_response = hyper_response.header(k, v);
207 }
208 }
209
210 let body = Body::from(response.body);
211 Ok(hyper_response.body(body).unwrap_or_else(|_| HyperResponse::default()))
212 }
213
214 pub async fn shutdown(&mut self) {
215 if let Some(tx) = self.shutdown_tx.take() {
216 let _ = tx.send(());
217 }
218 }
219}
220
221impl Drop for DMSCGatewayServer {
222 fn drop(&mut self) {
223 if let Some(tx) = self.shutdown_tx.take() {
224 let _ = tx.send(());
225 }
226 }
227}
228
229pub fn load_tls_config(
230 cert_path: &str,
231 key_path: &str,
232) -> DMSCResult<ServerConfig> {
233 let cert = std::fs::read(cert_path)
234 .map_err(|e| DMSCError::Config(format!("Failed to read TLS certificate: {}", e)))?;
235 let key = std::fs::read(key_path)
236 .map_err(|e| DMSCError::Config(format!("Failed to read TLS key: {}", e)))?;
237
238 let cert_chain = tokio_rustls::rustls::Certificate(cert);
239 let private_key = tokio_rustls::rustls::PrivateKey(key);
240
241 let mut server_config = tokio_rustls::rustls::ServerConfig::builder()
242 .with_safe_defaults()
243 .with_no_client_auth()
244 .with_single_cert(vec![cert_chain], private_key)
245 .map_err(|e| DMSCError::Config(format!("Failed to build TLS config: {}", e)))?;
246
247 server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
248
249 Ok(server_config)
250}