// ri/core/concurrent.rs
//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
//!
//! This file is part of Ri.
//! The Ri project belongs to the Dunimd Team.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! You may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.

//! # Sharded Lock Implementation
//!
//! This module provides a sharded lock data structure (`RiShardedLock`) that
//! improves concurrent performance by reducing lock contention. Instead of using
//! a single global lock, the data is partitioned into multiple shards, each with
//! its own lock.
//!
//! ## Key Benefits
//!
//! - **Reduced Lock Contention**: Multiple threads can access different shards simultaneously
//! - **Better Scalability**: Performance improves with more shards for high-concurrency scenarios
//! - **Uniform Distribution**: Uses hash-based sharding for even key distribution
//! - **Thread Safety**: All operations are thread-safe using async RwLock
//!
//! ## Design Principles
//!
//! 1. **Sharding Strategy**: Keys are distributed using `hash(key) % shard_count`
//! 2. **Lock Granularity**: Each shard has its own RwLock for fine-grained locking
//! 3. **Default Shard Count**: 16 shards by default, configurable based on workload
//! 4. **Zero-Cost Abstraction**: Sharding adds minimal overhead to operations
//!
//! ## Usage Example
//!
//! ```rust,ignore
//! use ri::core::concurrent::RiShardedLock;
//! use std::collections::HashMap;
//!
//! let sharded_map = RiShardedLock::<String, String>::new(16);
//!
//! // Insert a value
//! sharded_map.insert("key1".to_string(), "value1".to_string()).await;
//!
//! // Get a value
//! let value = sharded_map.get("key1").await;
//!
//! // Remove a value
//! sharded_map.remove("key1").await;
//! ```
57use std::borrow::Borrow;
58use std::collections::HashMap;
59use std::hash::{Hash, Hasher};
60use std::collections::hash_map::DefaultHasher;
61use std::sync::Arc;
62use tokio::sync::RwLock;
63
64#[cfg(feature = "pyo3")]
65use pyo3::pyclass;
66
/// Default number of shards; used by `with_default_shards` and whenever a
/// caller requests a shard count of zero.
const DEFAULT_SHARD_COUNT: usize = 16;
68
/// Computes a 64-bit hash of `key` using the standard library's
/// `DefaultHasher`. Deterministic within a single process run.
fn calculate_hash<K: Hash + ?Sized>(key: &K) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
74
/// Maps `key` to a shard slot in `[0, shard_count)` via `hash % shard_count`.
///
/// Hashing is inlined here (same `DefaultHasher` construction as
/// `calculate_hash`), so the routing result is identical to hashing through
/// that helper. Panics if `shard_count` is zero; all callers in this module
/// guarantee a non-zero count.
#[inline]
fn get_shard_index<K: Hash + ?Sized>(key: &K, shard_count: usize) -> usize {
    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    (state.finish() as usize) % shard_count
}
80
81struct Shard<K, V> {
82    data: RwLock<HashMap<K, V>>,
83}
84
85impl<K, V> Shard<K, V> {
86    fn new() -> Self {
87        Self {
88            data: RwLock::new(HashMap::new()),
89        }
90    }
91}
92
/// A concurrent hash map partitioned into independently locked shards.
///
/// Keys are routed to shards by `hash(key) % shard_count`, so operations on
/// keys that land in different shards never contend on the same lock.
pub struct RiShardedLock<K, V> {
    // One lock-guarded partition per shard; a key's shard is
    // `hash(key) % shard_count`.
    shards: Vec<Arc<Shard<K, V>>>,
    // Cached `shards.len()`; always non-zero (a zero request in `new` falls
    // back to the default count).
    shard_count: usize,
}
97
98impl<K, V> RiShardedLock<K, V>
99where
100    K: Hash + Eq + Clone + Send + Sync + 'static,
101    V: Clone + Send + Sync + 'static,
102{
103    pub fn new(shard_count: usize) -> Self {
104        let actual_shard_count = if shard_count == 0 { DEFAULT_SHARD_COUNT } else { shard_count };
105        let shards = (0..actual_shard_count)
106            .map(|_| Arc::new(Shard::new()))
107            .collect();
108
109        Self {
110            shards,
111            shard_count: actual_shard_count,
112        }
113    }
114
115    pub fn with_default_shards() -> Self {
116        Self::new(DEFAULT_SHARD_COUNT)
117    }
118
119    #[inline]
120    fn get_shard(&self, key: &K) -> &Arc<Shard<K, V>> {
121        let index = get_shard_index(key, self.shard_count);
122        &self.shards[index]
123    }
124
125    pub async fn insert(&self, key: K, value: V) -> Option<V> {
126        let shard = self.get_shard(&key);
127        let mut data = shard.data.write().await;
128        data.insert(key, value)
129    }
130
131    pub async fn get<Q>(&self, key: &Q) -> Option<V>
132    where
133        K: Borrow<Q>,
134        Q: Hash + Eq + ?Sized,
135    {
136        let shard_index = get_shard_index(key, self.shard_count);
137        let shard = &self.shards[shard_index];
138        let data = shard.data.read().await;
139        data.get(key).cloned()
140    }
141
142    pub async fn get_mut<F, R, Q>(&self, key: &Q, f: F) -> Option<R>
143    where
144        F: FnOnce(&mut V) -> R,
145        K: Borrow<Q>,
146        Q: Hash + Eq + ?Sized,
147    {
148        let shard_index = get_shard_index(key, self.shard_count);
149        let shard = &self.shards[shard_index];
150        let mut data = shard.data.write().await;
151        data.get_mut(key).map(f)
152    }
153
154    pub async fn remove<Q>(&self, key: &Q) -> Option<V>
155    where
156        K: Borrow<Q>,
157        Q: Hash + Eq + ?Sized,
158    {
159        let shard_index = get_shard_index(key, self.shard_count);
160        let shard = &self.shards[shard_index];
161        let mut data = shard.data.write().await;
162        data.remove(key)
163    }
164
165    pub async fn contains_key<Q>(&self, key: &Q) -> bool
166    where
167        K: Borrow<Q>,
168        Q: Hash + Eq + ?Sized,
169    {
170        let shard_index = get_shard_index(key, self.shard_count);
171        let shard = &self.shards[shard_index];
172        let data = shard.data.read().await;
173        data.contains_key(key)
174    }
175
176    pub async fn len(&self) -> usize {
177        let mut total = 0;
178        for shard in &self.shards {
179            let data = shard.data.read().await;
180            total += data.len();
181        }
182        total
183    }
184
185    pub async fn is_empty(&self) -> bool {
186        self.len().await == 0
187    }
188
189    pub async fn clear(&self) {
190        for shard in &self.shards {
191            let mut data = shard.data.write().await;
192            data.clear();
193        }
194    }
195
196    pub async fn retain<F>(&self, mut f: F)
197    where
198        F: FnMut(&K, &mut V) -> bool,
199    {
200        for shard in &self.shards {
201            let mut data = shard.data.write().await;
202            data.retain(|k, v| f(k, v));
203        }
204    }
205
206    pub async fn for_each<F>(&self, mut f: F)
207    where
208        F: FnMut(&K, &V),
209    {
210        for shard in &self.shards {
211            let data = shard.data.read().await;
212            for (k, v) in data.iter() {
213                f(k, v);
214            }
215        }
216    }
217
218    pub async fn for_each_mut<F>(&self, mut f: F)
219    where
220        F: FnMut(&K, &mut V),
221    {
222        for shard in &self.shards {
223            let mut data = shard.data.write().await;
224            for (k, v) in data.iter_mut() {
225                f(k, v);
226            }
227        }
228    }
229
230    pub async fn collect_all(&self) -> HashMap<K, V> {
231        let mut result = HashMap::new();
232        for shard in &self.shards {
233            let data = shard.data.read().await;
234            for (k, v) in data.iter() {
235                result.insert(k.clone(), v.clone());
236            }
237        }
238        result
239    }
240
241    pub async fn collect_where<F>(&self, mut predicate: F) -> Vec<V>
242    where
243        F: FnMut(&K, &V) -> bool,
244    {
245        let mut result = Vec::new();
246        for shard in &self.shards {
247            let data = shard.data.read().await;
248            for (k, v) in data.iter() {
249                if predicate(k, v) {
250                    result.push(v.clone());
251                }
252            }
253        }
254        result
255    }
256
257    pub async fn count_where<F>(&self, mut predicate: F) -> usize
258    where
259        F: FnMut(&K, &V) -> bool,
260    {
261        let mut count = 0;
262        for shard in &self.shards {
263            let data = shard.data.read().await;
264            for (k, v) in data.iter() {
265                if predicate(k, v) {
266                    count += 1;
267                }
268            }
269        }
270        count
271    }
272
273    pub async fn remove_where<F>(&self, mut predicate: F) -> usize
274    where
275        F: FnMut(&K, &V) -> bool,
276    {
277        let mut removed_count = 0;
278        for shard in &self.shards {
279            let mut data = shard.data.write().await;
280            let before_len = data.len();
281            data.retain(|k, v| !predicate(k, v));
282            removed_count += before_len - data.len();
283        }
284        removed_count
285    }
286
287    pub async fn update<F, R>(&self, key: &K, f: F) -> Option<R>
288    where
289        F: FnOnce(Option<&mut V>) -> R,
290    {
291        let shard = self.get_shard(key);
292        let mut data = shard.data.write().await;
293        Some(f(data.get_mut(key)))
294    }
295
296    pub async fn get_or_insert<F>(&self, key: K, default: F) -> V
297    where
298        F: FnOnce() -> V,
299    {
300        let shard = self.get_shard(&key);
301        let mut data = shard.data.write().await;
302        data.entry(key).or_insert_with(default).clone()
303    }
304
305    pub async fn get_or_insert_with_key<F>(&self, key: K, default: F) -> V
306    where
307        F: FnOnce(&K) -> V,
308    {
309        let shard = self.get_shard(&key);
310        let mut data = shard.data.write().await;
311        data.entry(key.clone()).or_insert_with(|| default(&key)).clone()
312    }
313
314    pub fn shard_count(&self) -> usize {
315        self.shard_count
316    }
317}
318
impl<K, V> Default for RiShardedLock<K, V>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Builds a map with the default shard count, delegating to
    /// `with_default_shards`.
    fn default() -> Self {
        Self::with_default_shards()
    }
}
328
// NOTE(review): neither guard type is constructed anywhere in this module —
// presumably reserved for a future API that hands out per-shard lock guards;
// confirm intent before removing.

/// Read guard pinned to a single shard's data map.
#[allow(dead_code)]
pub struct RiShardedLockReadGuard<'a, K, V> {
    // Index of the shard this guard locks.
    shard_index: usize,
    guard: tokio::sync::RwLockReadGuard<'a, HashMap<K, V>>,
}

/// Write guard pinned to a single shard's data map.
#[allow(dead_code)]
pub struct RiShardedLockWriteGuard<'a, K, V> {
    // Index of the shard this guard locks.
    shard_index: usize,
    guard: tokio::sync::RwLockWriteGuard<'a, HashMap<K, V>>,
}
340
/// Snapshot of a sharded lock's occupancy, useful for diagnosing shard skew.
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct RiShardedLockStats {
    /// Number of shards in the structure the snapshot describes.
    pub shard_count: usize,
    /// Total entries across all shards at snapshot time.
    pub total_entries: usize,
    /// Entry count per shard, indexed by shard number.
    pub shard_distribution: Vec<usize>,
}

impl RiShardedLockStats {
    /// Bundles a snapshot; performs no validation of the inputs.
    pub fn new(shard_count: usize, total_entries: usize, shard_distribution: Vec<usize>) -> Self {
        Self {
            shard_count,
            total_entries,
            shard_distribution,
        }
    }

    /// Average number of entries per shard; `0.0` when there are no shards.
    pub fn calc_load_factor(&self) -> f64 {
        match self.shard_count {
            0 => 0.0,
            n => self.total_entries as f64 / n as f64,
        }
    }

    /// Population variance of per-shard entry counts around the mean load;
    /// `0.0` when the snapshot has no shards or no entries.
    pub fn calc_distribution_variance(&self) -> f64 {
        if self.shard_count == 0 || self.total_entries == 0 {
            return 0.0;
        }
        let mean = self.total_entries as f64 / self.shard_count as f64;
        let sum_of_squares: f64 = self
            .shard_distribution
            .iter()
            .map(|&count| {
                let diff = count as f64 - mean;
                diff * diff
            })
            .sum();
        sum_of_squares / self.shard_count as f64
    }
}
378
// Python bindings: exposes the stats snapshot's fields and derived metrics to
// Python when the `pyo3` feature is enabled; compiled out otherwise.
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl RiShardedLockStats {
    /// Shard count of the snapshotted structure.
    #[getter]
    fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Total entries across all shards at snapshot time.
    #[getter]
    fn total_entries(&self) -> usize {
        self.total_entries
    }

    /// Per-shard entry counts (cloned for Python ownership).
    #[getter]
    fn shard_distribution(&self) -> Vec<usize> {
        self.shard_distribution.clone()
    }

    /// Average entries per shard; see `calc_load_factor`.
    fn load_factor(&self) -> f64 {
        self.calc_load_factor()
    }

    /// Variance of per-shard entry counts; see `calc_distribution_variance`.
    fn distribution_variance(&self) -> f64 {
        self.calc_distribution_variance()
    }
}