// dmsc/gateway/load_balancer.rs
1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//! http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18#![allow(non_snake_case)]
19
20//! # Load Balancer Module
21//!
22//! This module provides a robust load balancer implementation for distributing incoming requests
23//! across multiple backend servers. It supports various load balancing strategies and includes
24//! health checking, connection management, and detailed statistics.
25//!
26//! ## Key Components
27//!
28//! - **DMSCLoadBalancerStrategy**: Enum representing different load balancing algorithms
29//! - **DMSCBackendServer**: Represents a backend server with configuration and health status
30//! - **DMSCLoadBalancer**: Main load balancer implementation
31//! - **DMSCLoadBalancerServerStats**: Metrics for monitoring server performance
32//!
33//! ## Design Principles
34//!
35//! 1. **Multiple Strategies**: Supports RoundRobin, WeightedRoundRobin, LeastConnections, Random, and IpHash
36//! 2. **Health Checking**: Automatic periodic health checks to ensure traffic is only sent to healthy servers
37//! 3. **Connection Management**: Tracks active connections and enforces max connections per server
38//! 4. **Detailed Statistics**: Collects metrics on requests, failures, and response times
39//! 5. **Thread Safety**: Uses Arc and RwLock for safe operation in multi-threaded environments
40//! 6. **Scalability**: Designed to handle large numbers of servers and requests
41//! 7. **Configurable**: Allows fine-tuning of server weights, max connections, and health check paths
42//! 8. **Async Compatibility**: Built with async/await patterns for modern Rust applications
43//!
44//! ## Usage
45//!
46//! ```rust
47//! use dmsc::prelude::*;
48//! use std::sync::Arc;
49//!
50//! async fn example() -> DMSCResult<()> {
51//! // Create a load balancer with Round Robin strategy
52//! let lb = Arc::new(DMSCLoadBalancer::new(DMSCLoadBalancerStrategy::RoundRobin));
53//!
54//! // Add backend servers
55//! lb.add_server(DMSCBackendServer::new("server1".to_string(), "http://localhost:8081".to_string())
56//! .with_weight(2)
57//! .with_max_connections(200))
58//! .await;
59//!
60//! lb.add_server(DMSCBackendServer::new("server2".to_string(), "http://localhost:8082".to_string())
61//! .with_weight(1)
62//! .with_max_connections(100))
63//! .await;
64//!
65//! // Start periodic health checks every 30 seconds
66//! lb.clone().start_health_checks(30).await;
67//!
68//! // Select a server for a client request
69//! let server = lb.select_server(Some("192.168.1.1")).await?;
70//! println!("Selected server: {}", server.url);
71//!
72//! // Record response time when done
73//! lb.record_response_time(&server.id, 150).await;
74//!
75//! // Release the server when the request is complete
76//! lb.release_server(&server.id).await;
77//!
78//! // Get server statistics
79//! let stats = lb.get_all_stats().await;
80//! println!("Server stats: {:?}", stats);
81//!
82//! Ok(())
83//! }
84//! ```
85
86use crate::core::DMSCResult;
87use std::collections::HashMap;
88use std::sync::Arc;
89use std::sync::atomic::{AtomicUsize, Ordering};
90use std::time::Instant;
91use tokio::sync::RwLock;
92use std::sync::RwLock as StdRwLock;
93
94#[cfg(feature = "gateway")]
95use hyper;
96
/// Load balancing strategies supported by DMSC.
///
/// These strategies determine how the load balancer selects which backend server
/// to route incoming requests to. Each variant maps to a dedicated selection
/// method on `DMSCLoadBalancer`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub enum DMSCLoadBalancerStrategy {
    /// **Round Robin**: Sequentially selects the next available server in rotation.
    ///
    /// Simple and fair distribution, ideal for servers with similar capabilities.
    RoundRobin,

    /// **Weighted Round Robin**: Selects servers based on assigned weights.
    ///
    /// Allows more powerful servers to handle a larger share of traffic.
    WeightedRoundRobin,

    /// **Least Connections**: Selects the server with the fewest active connections
    /// relative to its weight.
    ///
    /// Ideal for handling varying request durations, ensuring balanced load.
    LeastConnections,

    /// **Random**: Randomly selects an available server.
    ///
    /// Simple implementation with good distribution characteristics.
    Random,

    /// **IP Hash**: Uses the client IP address to consistently route to the same server.
    ///
    /// Maintains session persistence by mapping clients to specific servers.
    IpHash,

    /// **Least Response Time**: Selects the server with the lowest recorded response time.
    ///
    /// Ideal for optimizing user experience by directing traffic to the fastest servers.
    LeastResponseTime,

    /// **Consistent Hash**: Uses a consistent hashing algorithm for server selection.
    ///
    /// Provides stable mapping between requests and servers, minimizing disruption
    /// when servers are added or removed.
    ConsistentHash,
}
139
/// Represents a backend server in the load balancer.
///
/// This struct contains all the configuration and state information for a backend
/// server. Configuration fields (`weight`, `max_connections`, `health_check_path`)
/// are set at construction time; `is_healthy` is updated at runtime by the load
/// balancer's health-check and failure-tracking logic.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCBackendServer {
    /// Unique identifier for the server
    pub id: String,

    /// Base URL of the server (e.g., "http://localhost:8080")
    pub url: String,

    /// Weight assigned to the server for weighted load balancing strategies
    pub weight: u32,

    /// Maximum number of concurrent connections allowed to this server
    pub max_connections: usize,

    /// Path to check for health status (e.g., "/health"); appended to `url`
    pub health_check_path: String,

    /// Current health status of the server (true = healthy, false = unhealthy);
    /// managed by the load balancer, starts as healthy
    pub is_healthy: bool,
}
165
166impl DMSCBackendServer {
167 /// Creates a new backend server with the specified ID and URL.
168 ///
169 /// # Parameters
170 ///
171 /// - `id`: Unique identifier for the server
172 /// - `url`: Base URL of the server
173 ///
174 /// # Returns
175 ///
176 /// A new `DMSCBackendServer` instance with default values
177 pub fn new(id: String, url: String) -> Self {
178 Self {
179 id,
180 url,
181 weight: 1,
182 max_connections: 100,
183 health_check_path: "/health".to_string(),
184 is_healthy: true,
185 }
186 }
187}
188
// Python bindings for `DMSCBackendServer`, compiled only when the "pyo3"
// feature is enabled. With the feature off this is an empty impl block.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pymethods)]
impl DMSCBackendServer {
    /// Python constructor: delegates to `DMSCBackendServer::new`.
    #[cfg(feature = "pyo3")]
    #[new]
    fn py_new(id: String, url: String) -> Self {
        Self::new(id, url)
    }

    /// Python setter for the server weight (in-place, unlike the Rust builder).
    #[cfg(feature = "pyo3")]
    fn set_weight(&mut self, weight: u32) {
        self.weight = weight;
    }

    /// Python setter for the maximum concurrent connections.
    #[cfg(feature = "pyo3")]
    fn set_max_connections(&mut self, max_connections: usize) {
        self.max_connections = max_connections;
    }

    /// Python setter for the health check path.
    #[cfg(feature = "pyo3")]
    fn set_health_check_path(&mut self, path: String) {
        self.health_check_path = path;
    }
}
212
/// Internal server statistics tracking.
///
/// Tracks real-time statistics for one backend server: active connections,
/// request counts, failures, and the most recent response time. All counters
/// are atomics so they can be updated through a shared reference; `last_used`
/// uses a std `RwLock` (never held across an `.await`).
#[derive(Debug)]
struct ServerStats {
    /// Number of currently active connections to the server
    active_connections: AtomicUsize,

    /// Total number of requests sent to the server since it was added
    total_requests: AtomicUsize,

    /// Number of failed requests to the server
    failed_requests: AtomicUsize,

    /// Most recent response time in milliseconds (latest sample, not an average)
    response_time_ms: AtomicUsize,

    /// Timestamp of when the server was last used
    last_used: StdRwLock<Instant>,
}
235
236impl ServerStats {
237 /// Creates a new server statistics instance with default values.
238 ///
239 /// Initializes all counters to zero and sets the last used time to now.
240 ///
241 /// # Returns
242 ///
243 /// A new `ServerStats` instance with default values
244 fn new() -> Self {
245 Self {
246 active_connections: AtomicUsize::new(0),
247 total_requests: AtomicUsize::new(0),
248 failed_requests: AtomicUsize::new(0),
249 response_time_ms: AtomicUsize::new(0),
250 last_used: StdRwLock::new(Instant::now()),
251 }
252 }
253
254 /// Gets the current number of active connections to the server.
255 ///
256 /// # Returns
257 ///
258 /// The number of active connections as a `usize`
259 fn get_active_connections(&self) -> usize {
260 self.active_connections.load(Ordering::Relaxed)
261 }
262
263 /// Increments the active connection count and updates request statistics.
264 ///
265 /// This method should be called when a new connection is established to the server.
266 ///
267 /// - Increments active_connections by 1
268 /// - Increments total_requests by 1
269 /// - Updates last_used to the current time
270 fn increment_connections(&self) {
271 self.active_connections.fetch_add(1, Ordering::Relaxed);
272 self.total_requests.fetch_add(1, Ordering::Relaxed);
273 if let Ok(mut last_used) = self.last_used.write() {
274 *last_used = Instant::now();
275 }
276 }
277
278 /// Decrements the active connection count.
279 ///
280 /// This method should be called when a connection to the server is closed.
281 fn decrement_connections(&self) {
282 self.active_connections.fetch_sub(1, Ordering::Relaxed);
283 }
284
285 /// Records a failed request to the server.
286 ///
287 /// This method should be called when a request to the server fails.
288 ///
289 /// - Increments failed_requests by 1
290 /// - Decrements active_connections by 1 (since the connection failed)
291 fn record_failure(&self) {
292 self.failed_requests.fetch_add(1, Ordering::Relaxed);
293 self.decrement_connections();
294 }
295
296 /// Records the response time for a successful request.
297 ///
298 /// This method should be called when a request to the server completes successfully.
299 ///
300 /// # Parameters
301 ///
302 /// - `response_time_ms`: Response time in milliseconds
303 fn record_response_time(&self, response_time_ms: u64) {
304 self.response_time_ms.store(response_time_ms as usize, Ordering::Relaxed);
305 }
306
307 /// Gets a snapshot of the current server statistics.
308 ///
309 /// This method converts the internal statistics into a public-facing `DMSCLoadBalancerServerStats` struct.
310 ///
311 /// # Returns
312 ///
313 /// A `DMSCLoadBalancerServerStats` struct containing the current statistics
314 fn get_stats(&self) -> DMSCLoadBalancerServerStats {
315 DMSCLoadBalancerServerStats {
316 active_connections: self.get_active_connections(),
317 total_requests: self.total_requests.load(Ordering::Relaxed),
318 failed_requests: self.failed_requests.load(Ordering::Relaxed),
319 response_time_ms: self.response_time_ms.load(Ordering::Relaxed),
320 }
321 }
322}
323
/// Load balancer server statistics for monitoring and reporting.
///
/// A plain-data snapshot of one backend server's metrics, produced from the
/// internal atomic counters by `ServerStats::get_stats`. Values are not live;
/// fetch a new snapshot to observe changes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCLoadBalancerServerStats {
    /// Number of currently active connections to the server
    pub active_connections: usize,

    /// Total number of requests sent to the server since it was added
    pub total_requests: usize,

    /// Number of failed requests to the server
    pub failed_requests: usize,

    /// Most recent response time in milliseconds (latest sample, not an average)
    pub response_time_ms: usize,
}
343
/// Main load balancer implementation.
///
/// Distributes requests across backend servers using the configured strategy,
/// with health checking, per-server connection limits, and statistics.
/// All mutable state is behind tokio `RwLock`s, so a shared `Arc<DMSCLoadBalancer>`
/// can be used from many tasks concurrently.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCLoadBalancer {
    /// Load balancing strategy to use
    strategy: DMSCLoadBalancerStrategy,

    /// List of backend servers
    servers: RwLock<Vec<DMSCBackendServer>>,

    /// Statistics for each backend server, keyed by server ID
    server_stats: RwLock<HashMap<String, Arc<ServerStats>>>,

    /// Monotonic counter driving the (weighted) round robin strategies
    round_robin_counter: AtomicUsize,
}
362
363impl Clone for DMSCLoadBalancer {
364 /// Creates a clone of the load balancer.
365 ///
366 /// Note: The clone will have the same strategy and counter, but empty servers and stats
367 /// since we can't await in the Clone trait.
368 fn clone(&self) -> Self {
369 Self {
370 strategy: self.strategy.clone(),
371 servers: RwLock::new(Vec::new()),
372 server_stats: RwLock::new(HashMap::new()),
373 round_robin_counter: AtomicUsize::new(self.round_robin_counter.load(Ordering::Relaxed)),
374 }
375 }
376}
377
378impl DMSCLoadBalancer {
379 /// Creates a new load balancer with the specified strategy.
380 ///
381 /// # Parameters
382 ///
383 /// - `strategy`: The load balancing strategy to use
384 ///
385 /// # Returns
386 ///
387 /// A new `DMSCLoadBalancer` instance with the specified strategy
388 pub fn new(strategy: DMSCLoadBalancerStrategy) -> Self {
389 Self {
390 strategy,
391 servers: RwLock::new(Vec::new()),
392 server_stats: RwLock::new(HashMap::new()),
393 round_robin_counter: AtomicUsize::new(0),
394 }
395 }
396
397 /// Adds a backend server to the load balancer.
398 ///
399 /// # Parameters
400 ///
401 /// - `server`: The backend server to add
402 pub async fn add_server(&self, server: DMSCBackendServer) {
403 let mut servers = self.servers.write().await;
404 let mut stats = self.server_stats.write().await;
405
406 servers.push(server.clone());
407 stats.insert(server.id.clone(), Arc::new(ServerStats::new()));
408 }
409
410 /// Removes a backend server from the load balancer.
411 ///
412 /// # Parameters
413 ///
414 /// - `server_id`: The ID of the server to remove
415 ///
416 /// # Returns
417 ///
418 /// `true` if the server was removed, `false` otherwise
419 pub async fn remove_server(&self, server_id: &str) -> bool {
420 let mut servers = self.servers.write().await;
421 let mut stats = self.server_stats.write().await;
422
423 let initial_len = servers.len();
424 servers.retain(|s| s.id != server_id);
425 stats.remove(server_id);
426
427 servers.len() < initial_len
428 }
429
430 /// Gets a list of all healthy backend servers.
431 ///
432 /// # Returns
433 ///
434 /// A vector of healthy `DMSCBackendServer` instances
435 pub async fn get_healthy_servers(&self) -> Vec<DMSCBackendServer> {
436 let servers = self.servers.read().await;
437 servers.iter()
438 .filter(|s| s.is_healthy)
439 .cloned()
440 .collect()
441 }
442
443 /// Selects the most appropriate backend server for a client request.
444 ///
445 /// This method applies the configured load balancing strategy to select a server,
446 /// considering only healthy servers with available connections.
447 ///
448 /// # Parameters
449 ///
450 /// - `client_ip`: Optional client IP address for IP Hash strategy
451 ///
452 /// # Returns
453 ///
454 /// A `DMSCResult<DMSCBackendServer>` with the selected server, or an error if no servers are available
455 pub async fn select_server(&self, client_ip: Option<&str>) -> DMSCResult<DMSCBackendServer> {
456 let healthy_servers = self.get_healthy_servers().await;
457
458 if healthy_servers.is_empty() {
459 return Err(crate::core::DMSCError::Other("No healthy servers available".to_string()));
460 }
461
462 let stats = self.server_stats.read().await;
463
464 // Filter servers that have available connections
465 let available_servers: Vec<DMSCBackendServer> = healthy_servers.into_iter()
466 .filter(|server| {
467 if let Some(server_stats) = stats.get(&server.id) {
468 let connections = server_stats.get_active_connections();
469 connections < server.max_connections
470 } else {
471 true // If no stats, assume server is available
472 }
473 })
474 .collect();
475
476 if available_servers.is_empty() {
477 return Err(crate::core::DMSCError::Other("No servers with available connections".to_string()));
478 }
479
480 let server = match self.strategy {
481 DMSCLoadBalancerStrategy::RoundRobin => self.select_round_robin(&available_servers).await,
482 DMSCLoadBalancerStrategy::WeightedRoundRobin => self.select_weighted_round_robin(&available_servers).await,
483 DMSCLoadBalancerStrategy::LeastConnections => self.select_least_connections(&available_servers).await,
484 DMSCLoadBalancerStrategy::Random => self.select_random(&available_servers),
485 DMSCLoadBalancerStrategy::IpHash => self.select_ip_hash(&available_servers, client_ip),
486 DMSCLoadBalancerStrategy::LeastResponseTime => self.select_least_response_time(&available_servers).await,
487 DMSCLoadBalancerStrategy::ConsistentHash => self.select_consistent_hash(&available_servers, client_ip),
488 };
489
490 if let Some(server) = server {
491 // Increment connection count
492 if let Some(stats) = self.server_stats.read().await.get(&server.id) {
493 stats.increment_connections();
494 }
495 Ok(server)
496 } else {
497 Err(crate::core::DMSCError::Other("Failed to select server".to_string()))
498 }
499 }
500
501 /// Selects a server using the Round Robin strategy.
502 ///
503 /// This method sequentially selects the next available server in rotation.
504 ///
505 /// # Parameters
506 ///
507 /// - `servers`: List of available servers
508 ///
509 /// # Returns
510 ///
511 /// The selected server, or `None` if no servers are available
512 async fn select_round_robin(&self, servers: &[DMSCBackendServer]) -> Option<DMSCBackendServer> {
513 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
514 let index = counter % servers.len();
515 servers.get(index).cloned()
516 }
517
518 /// Selects a server using the Smooth Weighted Round Robin strategy.
519 ///
520 /// This method uses a smooth weighted round robin algorithm to distribute traffic more evenly
521 /// across servers, avoiding the problem of sudden traffic spikes to high-weight servers.
522 ///
523 /// # Parameters
524 ///
525 /// - `servers`: List of available servers with weights
526 ///
527 /// # Returns
528 ///
529 /// The selected server, or `None` if no servers are available
530 async fn select_weighted_round_robin(&self, servers: &[DMSCBackendServer]) -> Option<DMSCBackendServer> {
531 if servers.is_empty() {
532 return None;
533 }
534
535 // Simple weighted round robin implementation with improved distribution
536 let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
537 let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
538 let weighted_index = counter % total_weight as usize;
539
540 let mut accumulated_weight = 0;
541 for server in servers {
542 accumulated_weight += server.weight as usize;
543 if weighted_index < accumulated_weight {
544 return Some(server.clone());
545 }
546 }
547
548 servers.first().cloned()
549 }
550
551 /// Selects a server using the Weighted Least Connections strategy.
552 ///
553 /// This method selects the server with the fewest active connections relative to its weight,
554 /// ensuring balanced load across servers with different capacities.
555 ///
556 /// # Parameters
557 ///
558 /// - `servers`: List of available servers
559 ///
560 /// # Returns
561 ///
562 /// The selected server, or `None` if no servers are available
563 async fn select_least_connections(&self, servers: &[DMSCBackendServer]) -> Option<DMSCBackendServer> {
564 let stats = self.server_stats.read().await;
565
566 let mut best_server = None;
567 let mut best_score = f64::MAX; // Lower score is better (connections per weight unit)
568
569 for server in servers {
570 if let Some(server_stats) = stats.get(&server.id) {
571 let connections = server_stats.get_active_connections();
572
573 // Skip servers that have reached max connections
574 if connections >= server.max_connections {
575 continue;
576 }
577
578 // Calculate score as connections per weight unit
579 // Use a small epsilon to avoid division by zero
580 let weight = server.weight as f64 + 0.001;
581 let score = connections as f64 / weight;
582
583 if score < best_score {
584 best_score = score;
585 best_server = Some(server.clone());
586 }
587 }
588 }
589
590 best_server.or_else(|| servers.first().cloned())
591 }
592
593 /// Selects a server using the Random strategy.
594 ///
595 /// This method randomly selects an available server, providing good distribution characteristics.
596 ///
597 /// # Parameters
598 ///
599 /// - `servers`: List of available servers
600 ///
601 /// # Returns
602 ///
603 /// The selected server, or `None` if no servers are available
604 fn select_random(&self, servers: &[DMSCBackendServer]) -> Option<DMSCBackendServer> {
605 use rand::Rng;
606 let mut rng = rand::thread_rng();
607 let index = rng.gen_range(0..servers.len());
608 servers.get(index).cloned()
609 }
610
611 /// Selects a server using the Least Response Time strategy.
612 ///
613 /// This method selects the server with the lowest response time, optimizing for user experience.
614 ///
615 /// # Parameters
616 ///
617 /// - `servers`: List of available servers
618 ///
619 /// # Returns
620 ///
621 /// The selected server, or `None` if no servers are available
622 async fn select_least_response_time(&self, servers: &[DMSCBackendServer]) -> Option<DMSCBackendServer> {
623 let stats = self.server_stats.read().await;
624
625 let mut best_server = None;
626 let mut min_response_time = u64::MAX;
627
628 for server in servers {
629 if let Some(server_stats) = stats.get(&server.id) {
630 let response_time = server_stats.response_time_ms.load(Ordering::Relaxed) as u64;
631 if response_time < min_response_time {
632 min_response_time = response_time;
633 best_server = Some(server.clone());
634 }
635 }
636 }
637
638 best_server.or_else(|| servers.first().cloned())
639 }
640
641 /// Selects a server using the Consistent Hash strategy.
642 ///
643 /// This method uses a consistent hashing algorithm to map requests to servers, minimizing
644 /// disruption when servers are added or removed.
645 ///
646 /// # Parameters
647 ///
648 /// - `servers`: List of available servers
649 /// - `client_ip`: Optional client IP address for hashing
650 ///
651 /// # Returns
652 ///
653 /// The selected server, or `None` if no servers are available
654 fn select_ip_hash(&self, servers: &[DMSCBackendServer], client_ip: Option<&str>) -> Option<DMSCBackendServer> {
655 if let Some(ip) = client_ip {
656 let hash = self.hash_ip(ip);
657 let index = hash as usize % servers.len();
658 servers.get(index).cloned()
659 } else {
660 self.select_random(servers)
661 }
662 }
663
664 /// Hashes an IP address for the IP Hash strategy.
665 ///
666 /// # Parameters
667 ///
668 /// - `ip`: IP address to hash
669 ///
670 /// # Returns
671 ///
672 /// A 64-bit hash value of the IP address
673 fn hash_ip(&self, ip: &str) -> u64 {
674 use std::collections::hash_map::DefaultHasher;
675 use std::hash::{Hash, Hasher};
676
677 let mut hasher = DefaultHasher::new();
678 ip.hash(&mut hasher);
679 hasher.finish()
680 }
681
682 /// Selects a server using the Consistent Hash strategy.
683 ///
684 /// This method uses a consistent hashing algorithm to map requests to servers, minimizing
685 /// disruption when servers are added or removed.
686 ///
687 /// # Parameters
688 ///
689 /// - `servers`: List of available servers
690 /// - `client_ip`: Optional client IP address for hashing
691 ///
692 /// # Returns
693 ///
694 /// The selected server, or `None` if no servers are available
695 fn select_consistent_hash(&self, servers: &[DMSCBackendServer], client_ip: Option<&str>) -> Option<DMSCBackendServer> {
696 if servers.is_empty() {
697 return None;
698 }
699
700 let key = client_ip.unwrap_or("127.0.0.1");
701
702 // Create a sorted list of server hashes
703 let mut server_hashes: Vec<(u64, DMSCBackendServer)> = servers
704 .iter()
705 .map(|server| {
706 let hash = self.hash_ip(&server.id);
707 (hash, server.clone())
708 })
709 .collect();
710
711 // Sort server hashes
712 server_hashes.sort_by(|a, b| a.0.cmp(&b.0));
713
714 // Calculate hash for the key
715 let key_hash = self.hash_ip(key);
716
717 // Find the first server with hash >= key_hash
718 for (server_hash, server) in &server_hashes {
719 if *server_hash >= key_hash {
720 return Some(server.clone());
721 }
722 }
723
724 // If no server with hash >= key_hash, return the first server
725 server_hashes.first().map(|(_, server)| server.clone())
726 }
727
728 /// Releases a server after a request is completed.
729 ///
730 /// This method decrements the active connection count for the specified server.
731 ///
732 /// # Parameters
733 ///
734 /// - `server_id`: ID of the server to release
735 pub async fn release_server(&self, server_id: &str) {
736 if let Some(stats) = self.server_stats.read().await.get(server_id) {
737 stats.decrement_connections();
738 }
739 }
740
741 /// Records a failed request to a server.
742 ///
743 /// This method increments the failed request count and may mark the server as unhealthy
744 /// if the failure rate exceeds a threshold.
745 ///
746 /// # Parameters
747 ///
748 /// - `server_id`: ID of the server that failed
749 pub async fn record_server_failure(&self, server_id: &str) {
750 if let Some(stats) = self.server_stats.read().await.get(server_id) {
751 stats.record_failure();
752 }
753
754 // Mark server as unhealthy if too many failures
755 let mut servers = self.servers.write().await;
756 if let Some(server) = servers.iter_mut().find(|s| s.id == server_id) {
757 // Simple heuristic: mark unhealthy if failure rate > 50%
758 if let Some(stats) = self.server_stats.read().await.get(server_id) {
759 let total = stats.total_requests.load(Ordering::Relaxed);
760 let failed = stats.failed_requests.load(Ordering::Relaxed);
761
762 if total > 10 && (failed as f64 / total as f64) > 0.5 {
763 server.is_healthy = false;
764 }
765 }
766 }
767 }
768
769 /// Records the response time for a successful request.
770 ///
771 /// # Parameters
772 ///
773 /// - `server_id`: ID of the server that handled the request
774 /// - `response_time_ms`: Response time in milliseconds
775 pub async fn record_response_time(&self, server_id: &str, response_time_ms: u64) {
776 if let Some(stats) = self.server_stats.read().await.get(server_id) {
777 stats.record_response_time(response_time_ms);
778 }
779 }
780
781 /// Gets statistics for a specific server.
782 ///
783 /// # Parameters
784 ///
785 /// - `server_id`: ID of the server to get statistics for
786 ///
787 /// # Returns
788 ///
789 /// An `Option<DMSCLoadBalancerServerStats>` with the server statistics, or `None` if the server doesn't exist
790 pub async fn get_server_stats(&self, server_id: &str) -> Option<DMSCLoadBalancerServerStats> {
791 self.server_stats.read().await
792 .get(server_id)
793 .map(|stats| stats.get_stats())
794 }
795
796 /// Gets statistics for all servers.
797 ///
798 /// # Returns
799 ///
800 /// A `HashMap<String, DMSCLoadBalancerServerStats>` with statistics for all servers
801 pub async fn get_all_stats(&self) -> HashMap<String, DMSCLoadBalancerServerStats> {
802 let stats = self.server_stats.read().await;
803 let mut result = HashMap::new();
804
805 for (server_id, server_stats) in stats.iter() {
806 result.insert(server_id.clone(), server_stats.get_stats());
807 }
808
809 result
810 }
811
812 /// Marks a server as healthy or unhealthy.
813 ///
814 /// # Parameters
815 ///
816 /// - `server_id`: ID of the server to update
817 /// - `healthy`: New health status (true = healthy, false = unhealthy)
818 pub async fn mark_server_healthy(&self, server_id: &str, healthy: bool) {
819 let mut servers = self.servers.write().await;
820 if let Some(server) = servers.iter_mut().find(|s| s.id == server_id) {
821 server.is_healthy = healthy;
822 }
823 }
824
    /// Performs an HTTP health check on a server.
    ///
    /// Sends a GET request to `server.url + server.health_check_path` and
    /// reports healthy only for a 2xx response. Unknown server IDs, invalid
    /// URLs, and transport errors all report unhealthy.
    ///
    /// NOTE(review): the `servers` read guard is held for the duration of the
    /// HTTP request, so writers (e.g. `mark_server_healthy`) wait until the
    /// probe completes.
    ///
    /// # Parameters
    ///
    /// - `server_id`: ID of the server to check
    ///
    /// # Returns
    ///
    /// `true` if the server responded with a 2xx status, `false` otherwise
    #[cfg(feature = "gateway")]
    pub async fn perform_health_check(&self, server_id: &str) -> bool {
        let servers = self.servers.read().await;

        if let Some(server) = servers.iter().find(|s| s.id == server_id) {
            let health_check_url = format!("{}{}", server.url, server.health_check_path);

            // Validate the URL before issuing the request.
            let uri = match hyper::Uri::from_maybe_shared(health_check_url.clone()) {
                Ok(uri) => uri,
                Err(e) => {
                    // Best-effort warning for an invalid health check URL;
                    // logging failures are deliberately ignored.
                    if let Ok(fs) = crate::fs::DMSCFileSystem::new_auto_root() {
                        let logger = crate::log::DMSCLogger::new(&crate::log::DMSCLogConfig::default(), fs);
                        let _ = logger.warn("load_balancer", format!("Invalid health check URL for server {server_id}: {e}"));
                    }
                    return false;
                }
            };

            match hyper::Client::builder().build::<_, hyper::Body>(hyper::client::HttpConnector::new()).get(uri).await {
                Ok(response) => {
                    // Consider server healthy if status code is 2xx
                    (200..300).contains(&response.status().as_u16())
                },
                Err(_) => false,
            }
        } else {
            // Unknown server ID: treat as unhealthy.
            false
        }
    }
867
    /// Fallback health check used when the `gateway` feature is disabled.
    ///
    /// Always reports healthy: without the feature there is no HTTP client to
    /// probe the backend, so servers are optimistically assumed to be up.
    #[cfg(not(feature = "gateway"))]
    pub async fn perform_health_check(&self, _server_id: &str) -> bool {
        // If gateway feature is not enabled, assume all servers are healthy
        true
    }
873
874 /// Starts periodic health checks for all servers.
875 ///
876 /// This method spawns a background task that performs health checks on all servers
877 /// at the specified interval.
878 ///
879 /// # Parameters
880 ///
881 /// - `interval_secs`: Interval between health checks in seconds
882 ///
883 /// # Returns
884 ///
885 /// A `tokio::task::JoinHandle` for the background health check task
886 pub async fn start_health_checks(self: Arc<Self>, interval_secs: u64) -> tokio::task::JoinHandle<()> {
887 let this = self.clone();
888
889 tokio::spawn(async move {
890 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(interval_secs));
891
892 loop {
893 interval.tick().await;
894
895 let servers = this.servers.read().await;
896 let server_ids: Vec<String> = servers.iter().map(|s| s.id.clone()).collect();
897
898 for server_id in server_ids {
899 let is_healthy = this.perform_health_check(&server_id).await;
900 let _ = this.mark_server_healthy(&server_id, is_healthy).await;
901
902 // If server is unhealthy, record the failure
903 if !is_healthy {
904 let _ = this.record_server_failure(&server_id).await;
905 }
906 }
907 }
908 })
909 }
910
    /// Gets the current load balancing strategy.
    ///
    /// # Returns
    ///
    /// A reference to the current `DMSCLoadBalancerStrategy`
    pub fn get_strategy(&self) -> &DMSCLoadBalancerStrategy {
        &self.strategy
    }
919
    /// Sets the load balancing strategy.
    ///
    /// Requires `&mut self` (exclusive access), so it cannot be called through
    /// a shared `Arc<DMSCLoadBalancer>`.
    ///
    /// # Parameters
    ///
    /// - `strategy`: The new load balancing strategy to use
    pub async fn set_strategy(&mut self, strategy: DMSCLoadBalancerStrategy) {
        // NOTE(review): contains no awaits; `async` is kept for API stability.
        self.strategy = strategy;
    }
928
929 /// Gets the total number of servers.
930 ///
931 /// # Returns
932 ///
933 /// The total number of servers in the load balancer
934 pub async fn get_server_count(&self) -> usize {
935 self.servers.read().await.len()
936 }
937
938 /// Gets the number of healthy servers.
939 ///
940 /// # Returns
941 ///
942 /// The number of healthy servers in the load balancer
943 pub async fn get_healthy_server_count(&self) -> usize {
944 self.get_healthy_servers().await.len()
945 }
946}