dmsc/gateway/
server.rs

1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18#![cfg(feature = "gateway")]
19#![allow(non_snake_case)]
20
21use crate::core::{DMSCResult, DMSCError};
22use crate::gateway::{DMSCGateway, DMSCGatewayConfig, DMSCGatewayRequest};
23use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
24use hyper::service::{make_service_fn, service_fn};
25use std::collections::HashMap;
26use std::convert::Infallible;
27use std::net::SocketAddr;
28use std::sync::Arc;
29use tokio::sync::RwLock;
30use tokio_rustls::rustls::ServerConfig;
31
32pub struct DMSCGatewayServer {
33    gateway: Arc<DMSCGateway>,
34    config: Arc<RwLock<DMSCGatewayConfig>>,
35    addr: SocketAddr,
36    tls_config: Option<ServerConfig>,
37    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
38}
39
40impl DMSCGatewayServer {
41    pub fn new(gateway: Arc<DMSCGateway>, config: Arc<RwLock<DMSCGatewayConfig>>, addr: SocketAddr) -> Self {
42        Self {
43            gateway,
44            config,
45            addr,
46            tls_config: None,
47            shutdown_tx: None,
48        }
49    }
50
51    pub fn with_tls(mut self, tls_config: ServerConfig) -> Self {
52        self.tls_config = Some(tls_config);
53        self
54    }
55
56    pub async fn serve(&mut self) -> DMSCResult<()> {
57        let addr = self.addr;
58        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
59        self.shutdown_tx = Some(shutdown_tx);
60
61        let gateway = self.gateway.clone();
62        let config = self.config.clone();
63
64        let service = make_service_fn(move |_conn| {
65            let gateway = gateway.clone();
66            let config = config.clone();
67            async move {
68                Ok::<_, Infallible>(service_fn(move |req: HyperRequest<Body>| {
69                    Self::handle_request(req, gateway.clone(), config.clone())
70                }))
71            }
72        });
73
74        let server = Server::bind(&addr)
75            .http1_pipeline_flush(true)
76            .serve(service);
77
78        let graceful = server.with_graceful_shutdown(async {
79            shutdown_rx.await.ok();
80        });
81
82        graceful.await.map_err(|e| DMSCError::Other(format!("Server error: {}", e)))
83    }
84
85    async fn handle_request(
86        req: HyperRequest<Body>,
87        gateway: Arc<DMSCGateway>,
88        config: Arc<RwLock<DMSCGatewayConfig>>,
89    ) -> Result<HyperResponse<Body>, Infallible> {
90        let request_id = uuid::Uuid::new_v4().to_string();
91        let start = std::time::Instant::now();
92
93        let method = req.method().to_string();
94        let path = req.uri().path().to_string();
95        let remote_addr = req
96            .headers()
97            .get("X-Forwarded-For")
98            .and_then(|v| v.to_str().ok())
99            .map(|s| s.to_string())
100            .unwrap_or_else(|| {
101                req.extensions()
102                    .get::<SocketAddr>()
103                    .map(|a| a.to_string())
104                    .unwrap_or_else(|| "unknown".to_string())
105            });
106
107        let mut headers = HashMap::new();
108        for (key, value) in req.headers() {
109            if let Ok(v) = value.to_str() {
110                headers.insert(key.as_str().to_string(), v.to_string());
111            }
112        }
113
114        let query_params = {
115            let uri = req.uri();
116            let query = uri.query().unwrap_or("");
117            let mut params = HashMap::new();
118            for pair in query.split('&') {
119                if let Some((key, value)) = pair.split_once('=') {
120                    params.insert(
121                        key.to_string(),
122                        value.to_string(),
123                    );
124                }
125            }
126            params
127        };
128
129        let body = match hyper::body::to_bytes(req.into_body()).await {
130            Ok(bytes) => {
131                if bytes.is_empty() {
132                    None
133                } else {
134                    Some(bytes.to_vec())
135                }
136            }
137            Err(_) => None,
138        };
139
140        let dmsc_request = DMSCGatewayRequest::new(
141            method.clone(),
142            path.clone(),
143            headers,
144            query_params,
145            body,
146            remote_addr.clone(),
147        );
148
149        let response = gateway.handle_request(dmsc_request).await;
150
151        let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
152
153        if config.read().await.enable_logging {
154            let log_level = &config.read().await.log_level;
155            match log_level.as_str() {
156                "debug" => {
157                    log::debug!(
158                        target: "DMSC.Gateway",
159                        "{} {} {} {} {}ms",
160                        method,
161                        path,
162                        response.status_code,
163                        request_id,
164                        duration_ms
165                    );
166                }
167                "info" => {
168                    log::info!(
169                        target: "DMSC.Gateway",
170                        "{} {} {} {}ms",
171                        method,
172                        path,
173                        response.status_code,
174                        duration_ms
175                    );
176                }
177                "warn" => {
178                    log::warn!(
179                        target: "DMSC.Gateway",
180                        "{} {} {} {}ms",
181                        method,
182                        path,
183                        response.status_code,
184                        duration_ms
185                    );
186                }
187                "error" => {
188                    log::error!(
189                        target: "DMSC.Gateway",
190                        "{} {} {} {}ms",
191                        method,
192                        path,
193                        response.status_code,
194                        duration_ms
195                    );
196                }
197                _ => {}
198            }
199        }
200
201        let mut hyper_response = HyperResponse::builder()
202            .status(StatusCode::from_u16(response.status_code).unwrap_or(StatusCode::OK));
203
204        for (key, value) in response.headers {
205            if let (Ok(k), Ok(v)) = (key.parse::<hyper::header::HeaderName>(), value.parse::<hyper::header::HeaderValue>()) {
206                hyper_response = hyper_response.header(k, v);
207            }
208        }
209
210        let body = Body::from(response.body);
211        Ok(hyper_response.body(body).unwrap_or_else(|_| HyperResponse::default()))
212    }
213
214    pub async fn shutdown(&mut self) {
215        if let Some(tx) = self.shutdown_tx.take() {
216            let _ = tx.send(());
217        }
218    }
219}
220
221impl Drop for DMSCGatewayServer {
222    fn drop(&mut self) {
223        if let Some(tx) = self.shutdown_tx.take() {
224            let _ = tx.send(());
225        }
226    }
227}
228
229pub fn load_tls_config(
230    cert_path: &str,
231    key_path: &str,
232) -> DMSCResult<ServerConfig> {
233    let cert = std::fs::read(cert_path)
234        .map_err(|e| DMSCError::Config(format!("Failed to read TLS certificate: {}", e)))?;
235    let key = std::fs::read(key_path)
236        .map_err(|e| DMSCError::Config(format!("Failed to read TLS key: {}", e)))?;
237
238    let cert_chain = tokio_rustls::rustls::Certificate(cert);
239    let private_key = tokio_rustls::rustls::PrivateKey(key);
240
241    let mut server_config = tokio_rustls::rustls::ServerConfig::builder()
242        .with_safe_defaults()
243        .with_no_client_auth()
244        .with_single_cert(vec![cert_chain], private_key)
245        .map_err(|e| DMSCError::Config(format!("Failed to build TLS config: {}", e)))?;
246
247    server_config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
248
249    Ok(server_config)
250}