dmsc/gateway/
rate_limiter.rs

1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18#![allow(non_snake_case)]
19
20//! # Rate Limiter Module
21//! 
22//! This module provides rate limiting functionality for the DMSC gateway, allowing for
23//! controlling the rate of requests from clients to prevent abuse and ensure fair usage.
24//! 
25//! ## Key Components
26//! 
27//! - **DMSCRateLimitConfig**: Configuration for rate limiting behavior
28//! - **DMSCRateLimiter**: Token bucket based rate limiter implementation
29//! - **DMSCSlidingWindowRateLimiter**: Sliding window based rate limiter for fine-grained control
30//! - **DMSCRateLimitStats**: Metrics for monitoring rate limiter performance
31//! 
32//! ## Design Principles
33//! 
34//! 1. **Token Bucket Algorithm**: Implements the token bucket algorithm for smooth rate limiting
35//! 2. **Sliding Window**: Provides a sliding window implementation for more precise control
36//! 3. **Thread Safe**: Uses Arc and RwLock for safe operation in multi-threaded environments
37//! 4. **Configurable**: Allows fine-tuning of requests per second, burst size, and window duration
38//! 5. **Metrics Collection**: Tracks and reports rate limiter statistics
39//! 6. **Async Compatibility**: Built with async/await patterns for modern Rust applications
40//! 7. **Burst Support**: Allows for temporary bursts of requests beyond the steady rate
41//! 8. **Key-Based Limiting**: Supports rate limiting by client IP or custom keys
42//! 
43//! ## Usage
44//! 
//! ```rust
//! use dmsc::prelude::*;
//! 
//! fn example() {
//!     // Create a rate limiter with default configuration
//!     let limiter = DMSCRateLimiter::new(DMSCRateLimitConfig::default());
//!     
//!     // Check if a request should be allowed (synchronous API)
//!     let client_ip = "192.168.1.1";
//!     if limiter.check_rate_limit(client_ip, 1) {
//!         println!("Request allowed");
//!     } else {
//!         println!("Request rate limited");
//!     }
//!     
//!     // Get rate limit stats for a client
//!     if let Some(stats) = limiter.get_stats(client_ip) {
//!         println!("Current tokens: {}, Total requests: {}", 
//!             stats.current_tokens, stats.total_requests);
//!     }
//!     
//!     // Create a sliding window rate limiter
//!     let sliding_limiter = DMSCSlidingWindowRateLimiter::new(100, 60);
//!     if sliding_limiter.allow_request() {
//!         println!("Sliding window request allowed");
//!     }
//! }
//! ```
73
74use std::collections::HashMap;
75use std::sync::Arc;
76use std::sync::atomic::{AtomicUsize, Ordering};
77use tokio::sync::RwLock;
78use std::time::{Duration, Instant};
79
/// Configuration for rate limiting behavior.
/// 
/// This struct defines the parameters that control how the rate limiter behaves,
/// including the steady rate, burst capacity, and window duration.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
#[derive(Debug, Clone)]
pub struct DMSCRateLimitConfig {
    /// Maximum number of requests allowed per second in steady state
    /// (this is the token refill rate of each bucket).
    pub requests_per_second: u32,
    
    /// Maximum number of requests allowed in a burst (temporary spike);
    /// also the bucket capacity, so tokens never accumulate beyond it.
    pub burst_size: u32,
    
    /// Duration of the rate limiting window in seconds.
    /// NOTE(review): not read by the token-bucket code in this file;
    /// presumably intended for window-based limiting — confirm usage.
    pub window_seconds: u64,
}
96
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCRateLimitConfig {
    /// Python `__init__`: builds a config with the Rust `Default` values.
    #[new]
    fn py_new() -> Self {
        Self::default()
    }
    
    /// Python static factory taking all three fields explicitly.
    #[staticmethod]
    fn py_new_with_values(requests_per_second: u32, burst_size: u32, window_seconds: u64) -> Self {
        Self {
            requests_per_second,
            burst_size,
            window_seconds,
        }
    }
    
    /// Returns the steady-state rate (requests per second).
    fn get_requests_per_second(&self) -> u32 {
        self.requests_per_second
    }
    
    /// Sets the steady-state rate (requests per second).
    fn set_requests_per_second(&mut self, value: u32) {
        self.requests_per_second = value;
    }
    
    /// Returns the burst capacity (maximum tokens a bucket can hold).
    fn get_burst_size(&self) -> u32 {
        self.burst_size
    }
    
    /// Sets the burst capacity.
    fn set_burst_size(&mut self, value: u32) {
        self.burst_size = value;
    }
    
    /// Returns the window duration in seconds.
    fn get_window_seconds(&self) -> u64 {
        self.window_seconds
    }
    
    /// Sets the window duration in seconds.
    fn set_window_seconds(&mut self, value: u64) {
        self.window_seconds = value;
    }
}
138
139impl Default for DMSCRateLimitConfig {
140    /// Creates a default rate limit configuration.
141    /// 
142    /// Default values:
143    /// - requests_per_second: 10 requests per second
144    /// - burst_size: 20 requests (temporary burst capacity)
145    /// - window_seconds: 60 seconds window duration
146    fn default() -> Self {
147        Self {
148            requests_per_second: 10,
149            burst_size: 20,
150            window_seconds: 60,
151        }
152    }
153}
154
/// Internal token bucket for rate limiting.
/// 
/// This struct implements the token bucket algorithm for rate limiting, tracking
/// available tokens, last update time, and request count. All mutation happens in
/// `try_consume` while holding the `last_update` write lock, which serializes
/// refill and consumption.
#[derive(Debug)]
struct RateLimitBucket {
    /// Current number of available tokens; capped at the configured burst size.
    tokens: AtomicUsize,
    
    /// Timestamp of the last token refill; its write lock also serializes
    /// concurrent `try_consume` calls.
    last_update: RwLock<Instant>,
    
    /// Number of requests admitted by this bucket (incremented only when
    /// tokens are successfully consumed).
    request_count: AtomicUsize,
}
170
171impl RateLimitBucket {
172    /// Creates a new token bucket with the specified initial tokens.
173    /// 
174    /// # Parameters
175    /// 
176    /// - `tokens`: Initial number of tokens in the bucket
177    /// 
178    /// # Returns
179    /// 
180    /// A new `RateLimitBucket` instance
181    fn new(tokens: usize) -> Self {
182        Self {
183            tokens: AtomicUsize::new(tokens),
184            last_update: RwLock::new(Instant::now()),
185            request_count: AtomicUsize::new(0),
186        }
187    }
188
189    /// Attempts to consume tokens from the bucket.
190    /// 
191    /// This method refills tokens based on time elapsed since the last update,
192    /// then attempts to consume the requested number of tokens.
193    /// 
194    /// # Parameters
195    /// 
196    /// - `tokens`: Number of tokens to consume
197    /// - `config`: Rate limit configuration for token refill
198    /// 
199    /// # Returns
200    /// 
201    /// `true` if tokens were successfully consumed, `false` otherwise
202    async fn try_consume(&self, tokens: usize, config: &DMSCRateLimitConfig) -> bool {
203        let now = Instant::now();
204        let mut last_update = self.last_update.write().await;
205        
206        // Refill tokens based on time elapsed
207        let elapsed = now.duration_since(*last_update).as_secs_f64();
208        let tokens_to_add = (elapsed * config.requests_per_second as f64) as usize;
209        
210        if tokens_to_add > 0 {
211            let current_tokens = self.tokens.load(Ordering::Relaxed);
212            let new_tokens = std::cmp::min(current_tokens + tokens_to_add, config.burst_size as usize);
213            self.tokens.store(new_tokens, Ordering::Relaxed);
214            *last_update = now;
215        }
216        
217        // Try to consume tokens
218        let current_tokens = self.tokens.load(Ordering::Relaxed);
219        if current_tokens >= tokens {
220            self.tokens.fetch_sub(tokens, Ordering::Relaxed);
221            self.request_count.fetch_add(1, Ordering::Relaxed);
222            true
223        } else {
224            false
225        }
226    }
227
228    /// Gets the current statistics for this bucket.
229    /// 
230    /// # Returns
231    /// 
232    /// A `DMSCRateLimitStats` struct containing current tokens and total requests
233    fn get_stats(&self) -> DMSCRateLimitStats {
234        DMSCRateLimitStats {
235            current_tokens: self.tokens.load(Ordering::Relaxed),
236            total_requests: self.request_count.load(Ordering::Relaxed),
237        }
238    }
239}
240
/// Statistics for rate limiting monitoring.
/// 
/// A point-in-time snapshot of one rate limiter bucket: the number of
/// tokens currently available and the total number of requests the bucket
/// has admitted.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass(get_all, set_all))]
#[derive(Debug, Clone)]
pub struct DMSCRateLimitStats {
    /// Number of tokens available in the bucket when the snapshot was taken
    pub current_tokens: usize,

    /// Total number of requests admitted by the bucket so far
    pub total_requests: usize,
}
254
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCRateLimitStats {
    /// Python `__init__`: builds a stats snapshot from explicit values.
    #[new]
    fn py_new(current_tokens: usize, total_requests: usize) -> Self {
        Self {
            current_tokens,
            total_requests,
        }
    }
    
    /// Returns the token count captured in this snapshot.
    fn get_current_tokens(&self) -> usize {
        self.current_tokens
    }
    
    /// Returns the admitted-request total captured in this snapshot.
    fn get_total_requests(&self) -> usize {
        self.total_requests
    }
}
274
/// Token bucket based rate limiter implementation.
/// 
/// Maintains one token bucket per key (typically a client IP), allowing a
/// steady-state request rate plus temporary bursts up to the configured
/// capacity.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCRateLimiter {
    /// Configuration shared by all buckets (rate, burst size, window)
    config: DMSCRateLimitConfig,
    
    /// Per-key buckets, created lazily on first use of a key
    buckets: RwLock<HashMap<String, Arc<RateLimitBucket>>>,
}
287
288impl DMSCRateLimiter {
289    /// Creates a new rate limiter with the specified configuration.
290    /// 
291    /// # Parameters
292    /// 
293    /// - `config`: The configuration for rate limiting behavior
294    /// 
295    /// # Returns
296    /// 
297    /// A new `DMSCRateLimiter` instance
298    pub fn new(config: DMSCRateLimitConfig) -> Self {
299        Self {
300            config,
301            buckets: RwLock::new(HashMap::new()),
302        }
303    }
304
305    /// Checks if a gateway request should be allowed based on rate limiting.
306    /// 
307    /// This method uses the client IP address as the key for rate limiting.
308    /// 
309    /// # Parameters
310    /// 
311    /// - `request`: The gateway request to check
312    /// 
313    /// # Returns
314    /// 
315    /// `true` if the request should be allowed, `false` otherwise
316    pub async fn check_request(&self, request: &crate::gateway::DMSCGatewayRequest) -> bool {
317        // Use client IP as the key for rate limiting
318        let key = request.remote_addr.clone();
319        self.check_rate_limit(&key, 1)
320    }
321
322    /// Checks if a request with a custom key should be allowed based on rate limiting.
323    /// 
324    /// This method attempts to consume tokens from the bucket associated with the given key.
325    /// If no bucket exists for the key, a new one is created.
326    /// 
327    /// # Parameters
328    /// 
329    /// - `key`: The key to use for rate limiting (e.g., client IP, API key)
330    /// - `tokens`: Number of tokens to consume for this request
331    /// 
332    /// # Returns
333    /// 
334    /// `true` if the request should be allowed, `false` otherwise
335    pub fn check_rate_limit(&self, key: &str, tokens: usize) -> bool {
336        futures::executor::block_on(async {
337            let buckets = self.buckets.read().await;
338            
339            if let Some(bucket) = buckets.get(key) {
340                bucket.try_consume(tokens, &self.config).await
341            } else {
342                drop(buckets);
343                let mut buckets = self.buckets.write().await;
344                
345                if let Some(bucket) = buckets.get(key) {
346                    bucket.try_consume(tokens, &self.config).await
347                } else {
348                    let bucket = Arc::new(RateLimitBucket::new(self.config.burst_size as usize));
349                    let result = bucket.try_consume(tokens, &self.config).await;
350                    buckets.insert(key.to_string(), bucket);
351                    result
352                }
353            }
354        })
355    }
356
357    /// Gets rate limit statistics for a specific key.
358    /// 
359    /// # Parameters
360    /// 
361    /// - `key`: The key to get statistics for
362    /// 
363    /// # Returns
364    /// 
365    /// An `Option<DMSCRateLimitStats>` with the statistics, or `None` if no bucket exists for the key
366    pub fn get_stats(&self, key: &str) -> Option<DMSCRateLimitStats> {
367        futures::executor::block_on(async {
368            let buckets = self.buckets.read().await;
369            buckets.get(key).map(|bucket| bucket.get_stats())
370        })
371    }
372    
373    /// Gets the remaining tokens for a specific key.
374    pub fn get_remaining(&self, key: &str) -> Option<f64> {
375        futures::executor::block_on(async {
376            let buckets = self.buckets.read().await;
377            buckets.get(key).map(|bucket| {
378                let stats = bucket.get_stats();
379                stats.current_tokens as f64
380            })
381        })
382    }
383
384    /// Gets rate limit statistics for all keys.
385    /// 
386    /// # Returns
387    /// 
388    /// A `HashMap<String, DMSCRateLimitStats>` with statistics for all keys
389    pub fn get_all_stats(&self) -> HashMap<String, DMSCRateLimitStats> {
390        futures::executor::block_on(async {
391            let buckets = self.buckets.read().await;
392            let mut stats = HashMap::new();
393            
394            for (key, bucket) in buckets.iter() {
395                stats.insert(key.clone(), bucket.get_stats());
396            }
397            
398            stats
399        })
400    }
401
402    /// Resets the rate limit bucket for a specific key.
403    /// 
404    /// This method removes the bucket for the given key, effectively resetting the rate limit.
405    /// 
406    /// # Parameters
407    /// 
408    /// - `key`: The key to reset the bucket for
409    pub fn reset_bucket(&self, key: &str) {
410        futures::executor::block_on(async {
411            let mut buckets = self.buckets.write().await;
412            buckets.remove(key);
413        })
414    }
415
416    /// Clears all rate limit buckets.
417    /// 
418    /// This method removes all buckets, effectively resetting rate limits for all keys.
419    pub fn clear_all_buckets(&self) {
420        futures::executor::block_on(async {
421            let mut buckets = self.buckets.write().await;
422            buckets.clear();
423        })
424    }
425
426    /// Gets the current rate limit configuration.
427    /// 
428    /// # Returns
429    /// 
430    /// A reference to the current `DMSCRateLimitConfig`
431    pub fn get_config(&self) -> DMSCRateLimitConfig {
432        self.config.clone()
433    }
434
435    /// Updates the rate limit configuration.
436    /// 
437    /// This method updates the configuration and resets all buckets with the new settings.
438    /// 
439    /// # Parameters
440    /// 
441    /// - `config`: The new rate limit configuration
442    pub async fn update_config(&mut self, config: DMSCRateLimitConfig) {
443        self.config = config;
444        
445        let mut buckets = self.buckets.write().await;
446        buckets.clear();
447    }
448
449    pub async fn check_multi(&self, keys: &[String], tokens: usize) -> Vec<bool> {
450        let mut results = Vec::with_capacity(keys.len());
451        for key in keys {
452            results.push(self.check_rate_limit(key, tokens));
453        }
454        results
455    }
456
457    pub async fn get_keys(&self) -> Vec<String> {
458        let buckets = self.buckets.read().await;
459        buckets.keys().cloned().collect()
460    }
461    
462    pub fn bucket_count(&self) -> usize {
463        futures::executor::block_on(async {
464            let buckets = self.buckets.read().await;
465            buckets.len()
466        })
467    }
468}
469
/// Sliding window rate limiter for fine-grained control.
/// 
/// Records the timestamp of every admitted request and allows a new request
/// only while fewer than `max_requests` timestamps fall within the trailing
/// window.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
pub struct DMSCSlidingWindowRateLimiter {
    /// Maximum number of requests admitted within one window
    max_requests: u32,
    /// Length of the trailing window
    window_duration: Duration,
    /// Timestamps of admitted requests; aged-out entries are pruned on each
    /// `allow_request`/`get_current_count` call
    requests: RwLock<Vec<Instant>>,
}
483
484impl DMSCSlidingWindowRateLimiter {
485    /// Creates a new sliding window rate limiter.
486    /// 
487    /// # Parameters
488    /// 
489    /// - `max_requests`: Maximum number of requests allowed within the window
490    /// - `window_seconds`: Duration of the window in seconds
491    /// 
492    /// # Returns
493    /// 
494    /// A new `DMSCSlidingWindowRateLimiter` instance
495    pub fn new(max_requests: u32, window_seconds: u64) -> Self {
496        Self {
497            max_requests,
498            window_duration: Duration::from_secs(window_seconds),
499            requests: RwLock::new(Vec::new()),
500        }
501    }
502
503    /// Checks if a request should be allowed based on the sliding window.
504    /// 
505    /// This method removes old requests outside the window, then checks if the number
506    /// of remaining requests is below the maximum allowed.
507    /// 
508    /// # Returns
509    /// 
510    /// `true` if the request should be allowed, `false` otherwise
511    pub fn allow_request(&self) -> bool {
512        futures::executor::block_on(async {
513            let mut requests = self.requests.write().await;
514            let now = Instant::now();
515            
516            requests.retain(|&timestamp| now.duration_since(timestamp) < self.window_duration);
517            
518            if requests.len() < self.max_requests as usize {
519                requests.push(now);
520                true
521            } else {
522                false
523            }
524        })
525    }
526
527    /// Gets the current number of requests within the sliding window.
528    /// 
529    /// This method removes old requests outside the window, then returns the count
530    /// of remaining requests.
531    /// 
532    /// # Returns
533    /// 
534    /// The number of requests within the current window
535    pub fn get_current_count(&self) -> usize {
536        futures::executor::block_on(async {
537            let mut requests = self.requests.write().await;
538            let now = Instant::now();
539            
540            requests.retain(|&timestamp| now.duration_since(timestamp) < self.window_duration);
541            
542            requests.len()
543        })
544    }
545
546    /// Resets the sliding window by clearing all request timestamps.
547    pub fn reset(&self) {
548        futures::executor::block_on(async {
549            let mut requests = self.requests.write().await;
550            requests.clear();
551        })
552    }
553    
554    pub fn get_max_requests(&self) -> u32 {
555        self.max_requests
556    }
557    
558    pub fn get_window_seconds(&self) -> u64 {
559        self.window_duration.as_secs()
560    }
561}