dmsc/gateway/rate_limiter.rs
1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//! http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18#![allow(non_snake_case)]
19
20//! # Rate Limiter Module
21//!
22//! This module provides rate limiting functionality for the DMSC gateway, allowing for
23//! controlling the rate of requests from clients to prevent abuse and ensure fair usage.
24//!
25//! ## Key Components
26//!
27//! - **DMSCRateLimitConfig**: Configuration for rate limiting behavior
28//! - **DMSCRateLimiter**: Token bucket based rate limiter implementation
29//! - **DMSCSlidingWindowRateLimiter**: Sliding window based rate limiter for fine-grained control
30//! - **DMSCRateLimitStats**: Metrics for monitoring rate limiter performance
31//!
32//! ## Design Principles
33//!
34//! 1. **Token Bucket Algorithm**: Implements the token bucket algorithm for smooth rate limiting
35//! 2. **Sliding Window**: Provides a sliding window implementation for more precise control
36//! 3. **Thread Safe**: Uses Arc and RwLock for safe operation in multi-threaded environments
37//! 4. **Configurable**: Allows fine-tuning of requests per second, burst size, and window duration
38//! 5. **Metrics Collection**: Tracks and reports rate limiter statistics
39//! 6. **Async Compatibility**: Built with async/await patterns for modern Rust applications
40//! 7. **Burst Support**: Allows for temporary bursts of requests beyond the steady rate
41//! 8. **Key-Based Limiting**: Supports rate limiting by client IP or custom keys
42//!
43//! ## Usage
44//!
45//! ```rust
46//! use dmsc::prelude::*;
47//!
48//! async fn example() {
49//! // Create a rate limiter with default configuration
50//! let mut limiter = DMSCRateLimiter::new(DMSCRateLimitConfig::default());
51//!
52//! // Check if a request should be allowed
53//! let client_ip = "192.168.1.1";
54//! if limiter.check_rate_limit(client_ip, 1) {
55//! println!("Request allowed");
56//! } else {
57//! println!("Request rate limited");
58//! }
59//!
60//! // Get rate limit stats for a client
61//! if let Some(stats) = limiter.get_stats(client_ip) {
62//! println!("Current tokens: {}, Total requests: {}",
63//! stats.current_tokens, stats.total_requests);
64//! }
65//!
66//! // Create a sliding window rate limiter
67//! let sliding_limiter = DMSCSlidingWindowRateLimiter::new(100, 60);
68//! if sliding_limiter.allow_request() {
69//! println!("Sliding window request allowed");
70//! }
71//! }
72//! ```
73
74use std::collections::HashMap;
75use std::sync::Arc;
76use std::sync::atomic::{AtomicUsize, Ordering};
77use tokio::sync::RwLock;
78use std::time::{Duration, Instant};
79
/// Tunable parameters of the rate limiter.
///
/// The three values are plain data: the steady refill rate, the bucket
/// capacity used for bursts, and a window length in seconds.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
#[derive(Debug, Clone)]
pub struct DMSCRateLimitConfig {
    /// Steady-state rate: how many requests are permitted per second.
    pub requests_per_second: u32,

    /// Upper bound on the number of requests that may pass in one burst.
    pub burst_size: u32,

    /// Length of the rate limiting window, expressed in seconds.
    pub window_seconds: u64,
}
96
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCRateLimitConfig {
    /// Python constructor: builds a configuration holding the default values.
    #[new]
    fn py_new() -> Self {
        <Self as Default>::default()
    }

    /// Static factory that builds a configuration from explicit values.
    #[staticmethod]
    fn py_new_with_values(requests_per_second: u32, burst_size: u32, window_seconds: u64) -> Self {
        DMSCRateLimitConfig {
            requests_per_second,
            burst_size,
            window_seconds,
        }
    }

    /// Returns the configured steady-state requests per second.
    fn get_requests_per_second(&self) -> u32 {
        self.requests_per_second
    }

    /// Overwrites the steady-state requests per second with `value`.
    fn set_requests_per_second(&mut self, value: u32) {
        self.requests_per_second = value;
    }

    /// Returns the configured burst capacity.
    fn get_burst_size(&self) -> u32 {
        self.burst_size
    }

    /// Overwrites the burst capacity with `value`.
    fn set_burst_size(&mut self, value: u32) {
        self.burst_size = value;
    }

    /// Returns the configured window length in seconds.
    fn get_window_seconds(&self) -> u64 {
        self.window_seconds
    }

    /// Overwrites the window length (in seconds) with `value`.
    fn set_window_seconds(&mut self, value: u64) {
        self.window_seconds = value;
    }
}
138
139impl Default for DMSCRateLimitConfig {
140 /// Creates a default rate limit configuration.
141 ///
142 /// Default values:
143 /// - requests_per_second: 10 requests per second
144 /// - burst_size: 20 requests (temporary burst capacity)
145 /// - window_seconds: 60 seconds window duration
146 fn default() -> Self {
147 Self {
148 requests_per_second: 10,
149 burst_size: 20,
150 window_seconds: 60,
151 }
152 }
153}
154
155/// Internal token bucket for rate limiting.
156///
157/// This struct implements the token bucket algorithm for rate limiting, tracking
158/// available tokens, last update time, and request count.
159#[derive(Debug)]
160struct RateLimitBucket {
161 /// Current number of available tokens in the bucket
162 tokens: AtomicUsize,
163
164 /// Timestamp of the last token refill
165 last_update: RwLock<Instant>,
166
167 /// Total number of requests processed by this bucket
168 request_count: AtomicUsize,
169}
170
171impl RateLimitBucket {
172 /// Creates a new token bucket with the specified initial tokens.
173 ///
174 /// # Parameters
175 ///
176 /// - `tokens`: Initial number of tokens in the bucket
177 ///
178 /// # Returns
179 ///
180 /// A new `RateLimitBucket` instance
181 fn new(tokens: usize) -> Self {
182 Self {
183 tokens: AtomicUsize::new(tokens),
184 last_update: RwLock::new(Instant::now()),
185 request_count: AtomicUsize::new(0),
186 }
187 }
188
189 /// Attempts to consume tokens from the bucket.
190 ///
191 /// This method refills tokens based on time elapsed since the last update,
192 /// then attempts to consume the requested number of tokens.
193 ///
194 /// # Parameters
195 ///
196 /// - `tokens`: Number of tokens to consume
197 /// - `config`: Rate limit configuration for token refill
198 ///
199 /// # Returns
200 ///
201 /// `true` if tokens were successfully consumed, `false` otherwise
202 async fn try_consume(&self, tokens: usize, config: &DMSCRateLimitConfig) -> bool {
203 let now = Instant::now();
204 let mut last_update = self.last_update.write().await;
205
206 // Refill tokens based on time elapsed
207 let elapsed = now.duration_since(*last_update).as_secs_f64();
208 let tokens_to_add = (elapsed * config.requests_per_second as f64) as usize;
209
210 if tokens_to_add > 0 {
211 let current_tokens = self.tokens.load(Ordering::Relaxed);
212 let new_tokens = std::cmp::min(current_tokens + tokens_to_add, config.burst_size as usize);
213 self.tokens.store(new_tokens, Ordering::Relaxed);
214 *last_update = now;
215 }
216
217 // Try to consume tokens
218 let current_tokens = self.tokens.load(Ordering::Relaxed);
219 if current_tokens >= tokens {
220 self.tokens.fetch_sub(tokens, Ordering::Relaxed);
221 self.request_count.fetch_add(1, Ordering::Relaxed);
222 true
223 } else {
224 false
225 }
226 }
227
228 /// Gets the current statistics for this bucket.
229 ///
230 /// # Returns
231 ///
232 /// A `DMSCRateLimitStats` struct containing current tokens and total requests
233 fn get_stats(&self) -> DMSCRateLimitStats {
234 DMSCRateLimitStats {
235 current_tokens: self.tokens.load(Ordering::Relaxed),
236 total_requests: self.request_count.load(Ordering::Relaxed),
237 }
238 }
239}
240
/// Snapshot of a single rate limit bucket, intended for monitoring.
///
/// Carries the token count available at the time of the snapshot together
/// with the bucket's running request counter.
#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass(get_all, set_all))]
#[derive(Debug, Clone)]
pub struct DMSCRateLimitStats {
    /// Tokens that were available in the bucket when the snapshot was taken.
    pub current_tokens: usize,

    /// Running count of requests recorded by the bucket.
    pub total_requests: usize,
}
254
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl DMSCRateLimitStats {
    /// Python constructor: builds a statistics snapshot from explicit values.
    #[new]
    fn py_new(current_tokens: usize, total_requests: usize) -> Self {
        DMSCRateLimitStats {
            current_tokens,
            total_requests,
        }
    }

    /// Returns the token count stored in this snapshot.
    fn get_current_tokens(&self) -> usize {
        self.current_tokens
    }

    /// Returns the request counter stored in this snapshot.
    fn get_total_requests(&self) -> usize {
        self.total_requests
    }
}
274
275/// Token bucket based rate limiter implementation.
276///
277/// This struct implements the token bucket algorithm for rate limiting, allowing
278/// for both steady-state rate limiting and temporary bursts of requests.
279#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
280pub struct DMSCRateLimiter {
281 /// Configuration for rate limiting behavior
282 config: DMSCRateLimitConfig,
283
284 /// Map of key to token bucket instances
285 buckets: RwLock<HashMap<String, Arc<RateLimitBucket>>>,
286}
287
288impl DMSCRateLimiter {
289 /// Creates a new rate limiter with the specified configuration.
290 ///
291 /// # Parameters
292 ///
293 /// - `config`: The configuration for rate limiting behavior
294 ///
295 /// # Returns
296 ///
297 /// A new `DMSCRateLimiter` instance
298 pub fn new(config: DMSCRateLimitConfig) -> Self {
299 Self {
300 config,
301 buckets: RwLock::new(HashMap::new()),
302 }
303 }
304
305 /// Checks if a gateway request should be allowed based on rate limiting.
306 ///
307 /// This method uses the client IP address as the key for rate limiting.
308 ///
309 /// # Parameters
310 ///
311 /// - `request`: The gateway request to check
312 ///
313 /// # Returns
314 ///
315 /// `true` if the request should be allowed, `false` otherwise
316 pub async fn check_request(&self, request: &crate::gateway::DMSCGatewayRequest) -> bool {
317 // Use client IP as the key for rate limiting
318 let key = request.remote_addr.clone();
319 self.check_rate_limit(&key, 1)
320 }
321
322 /// Checks if a request with a custom key should be allowed based on rate limiting.
323 ///
324 /// This method attempts to consume tokens from the bucket associated with the given key.
325 /// If no bucket exists for the key, a new one is created.
326 ///
327 /// # Parameters
328 ///
329 /// - `key`: The key to use for rate limiting (e.g., client IP, API key)
330 /// - `tokens`: Number of tokens to consume for this request
331 ///
332 /// # Returns
333 ///
334 /// `true` if the request should be allowed, `false` otherwise
335 pub fn check_rate_limit(&self, key: &str, tokens: usize) -> bool {
336 futures::executor::block_on(async {
337 let buckets = self.buckets.read().await;
338
339 if let Some(bucket) = buckets.get(key) {
340 bucket.try_consume(tokens, &self.config).await
341 } else {
342 drop(buckets);
343 let mut buckets = self.buckets.write().await;
344
345 if let Some(bucket) = buckets.get(key) {
346 bucket.try_consume(tokens, &self.config).await
347 } else {
348 let bucket = Arc::new(RateLimitBucket::new(self.config.burst_size as usize));
349 let result = bucket.try_consume(tokens, &self.config).await;
350 buckets.insert(key.to_string(), bucket);
351 result
352 }
353 }
354 })
355 }
356
357 /// Gets rate limit statistics for a specific key.
358 ///
359 /// # Parameters
360 ///
361 /// - `key`: The key to get statistics for
362 ///
363 /// # Returns
364 ///
365 /// An `Option<DMSCRateLimitStats>` with the statistics, or `None` if no bucket exists for the key
366 pub fn get_stats(&self, key: &str) -> Option<DMSCRateLimitStats> {
367 futures::executor::block_on(async {
368 let buckets = self.buckets.read().await;
369 buckets.get(key).map(|bucket| bucket.get_stats())
370 })
371 }
372
373 /// Gets the remaining tokens for a specific key.
374 pub fn get_remaining(&self, key: &str) -> Option<f64> {
375 futures::executor::block_on(async {
376 let buckets = self.buckets.read().await;
377 buckets.get(key).map(|bucket| {
378 let stats = bucket.get_stats();
379 stats.current_tokens as f64
380 })
381 })
382 }
383
384 /// Gets rate limit statistics for all keys.
385 ///
386 /// # Returns
387 ///
388 /// A `HashMap<String, DMSCRateLimitStats>` with statistics for all keys
389 pub fn get_all_stats(&self) -> HashMap<String, DMSCRateLimitStats> {
390 futures::executor::block_on(async {
391 let buckets = self.buckets.read().await;
392 let mut stats = HashMap::new();
393
394 for (key, bucket) in buckets.iter() {
395 stats.insert(key.clone(), bucket.get_stats());
396 }
397
398 stats
399 })
400 }
401
402 /// Resets the rate limit bucket for a specific key.
403 ///
404 /// This method removes the bucket for the given key, effectively resetting the rate limit.
405 ///
406 /// # Parameters
407 ///
408 /// - `key`: The key to reset the bucket for
409 pub fn reset_bucket(&self, key: &str) {
410 futures::executor::block_on(async {
411 let mut buckets = self.buckets.write().await;
412 buckets.remove(key);
413 })
414 }
415
416 /// Clears all rate limit buckets.
417 ///
418 /// This method removes all buckets, effectively resetting rate limits for all keys.
419 pub fn clear_all_buckets(&self) {
420 futures::executor::block_on(async {
421 let mut buckets = self.buckets.write().await;
422 buckets.clear();
423 })
424 }
425
426 /// Gets the current rate limit configuration.
427 ///
428 /// # Returns
429 ///
430 /// A reference to the current `DMSCRateLimitConfig`
431 pub fn get_config(&self) -> DMSCRateLimitConfig {
432 self.config.clone()
433 }
434
435 /// Updates the rate limit configuration.
436 ///
437 /// This method updates the configuration and resets all buckets with the new settings.
438 ///
439 /// # Parameters
440 ///
441 /// - `config`: The new rate limit configuration
442 pub async fn update_config(&mut self, config: DMSCRateLimitConfig) {
443 self.config = config;
444
445 let mut buckets = self.buckets.write().await;
446 buckets.clear();
447 }
448
449 pub async fn check_multi(&self, keys: &[String], tokens: usize) -> Vec<bool> {
450 let mut results = Vec::with_capacity(keys.len());
451 for key in keys {
452 results.push(self.check_rate_limit(key, tokens));
453 }
454 results
455 }
456
457 pub async fn get_keys(&self) -> Vec<String> {
458 let buckets = self.buckets.read().await;
459 buckets.keys().cloned().collect()
460 }
461
462 pub fn bucket_count(&self) -> usize {
463 futures::executor::block_on(async {
464 let buckets = self.buckets.read().await;
465 buckets.len()
466 })
467 }
468}
469
470/// Sliding window rate limiter for fine-grained control.
471///
472/// This struct implements a sliding window rate limiter, which provides more precise
473/// rate limiting by tracking all requests within a sliding time window.
474#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
475pub struct DMSCSlidingWindowRateLimiter {
476 /// Maximum number of requests allowed within the window
477 max_requests: u32,
478 /// Duration of the sliding window
479 window_duration: Duration,
480 /// Vector of request timestamps within the window
481 requests: RwLock<Vec<Instant>>,
482}
483
484impl DMSCSlidingWindowRateLimiter {
485 /// Creates a new sliding window rate limiter.
486 ///
487 /// # Parameters
488 ///
489 /// - `max_requests`: Maximum number of requests allowed within the window
490 /// - `window_seconds`: Duration of the window in seconds
491 ///
492 /// # Returns
493 ///
494 /// A new `DMSCSlidingWindowRateLimiter` instance
495 pub fn new(max_requests: u32, window_seconds: u64) -> Self {
496 Self {
497 max_requests,
498 window_duration: Duration::from_secs(window_seconds),
499 requests: RwLock::new(Vec::new()),
500 }
501 }
502
503 /// Checks if a request should be allowed based on the sliding window.
504 ///
505 /// This method removes old requests outside the window, then checks if the number
506 /// of remaining requests is below the maximum allowed.
507 ///
508 /// # Returns
509 ///
510 /// `true` if the request should be allowed, `false` otherwise
511 pub fn allow_request(&self) -> bool {
512 futures::executor::block_on(async {
513 let mut requests = self.requests.write().await;
514 let now = Instant::now();
515
516 requests.retain(|×tamp| now.duration_since(timestamp) < self.window_duration);
517
518 if requests.len() < self.max_requests as usize {
519 requests.push(now);
520 true
521 } else {
522 false
523 }
524 })
525 }
526
527 /// Gets the current number of requests within the sliding window.
528 ///
529 /// This method removes old requests outside the window, then returns the count
530 /// of remaining requests.
531 ///
532 /// # Returns
533 ///
534 /// The number of requests within the current window
535 pub fn get_current_count(&self) -> usize {
536 futures::executor::block_on(async {
537 let mut requests = self.requests.write().await;
538 let now = Instant::now();
539
540 requests.retain(|×tamp| now.duration_since(timestamp) < self.window_duration);
541
542 requests.len()
543 })
544 }
545
546 /// Resets the sliding window by clearing all request timestamps.
547 pub fn reset(&self) {
548 futures::executor::block_on(async {
549 let mut requests = self.requests.write().await;
550 requests.clear();
551 })
552 }
553
554 pub fn get_max_requests(&self) -> u32 {
555 self.max_requests
556 }
557
558 pub fn get_window_seconds(&self) -> u64 {
559 self.window_duration.as_secs()
560 }
561}