1use std::borrow::Borrow;
58use std::collections::HashMap;
59use std::hash::{Hash, Hasher};
60use std::collections::hash_map::DefaultHasher;
61use std::sync::Arc;
62use tokio::sync::RwLock;
63
64#[cfg(feature = "pyo3")]
65use pyo3::pyclass;
66
/// Shard count used by `RiShardedLock::with_default_shards` and substituted
/// by `RiShardedLock::new` when the caller asks for zero shards.
const DEFAULT_SHARD_COUNT: usize = 16;
68
/// Hashes `key` with a freshly constructed [`DefaultHasher`] and returns the
/// 64-bit digest.
///
/// `K` may be unsized (e.g. `str`), so borrowed key forms can be hashed
/// without first building an owned key.
fn calculate_hash<K: Hash + ?Sized>(key: &K) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    Hasher::finish(&state)
}
74
75#[inline]
76fn get_shard_index<K: Hash + ?Sized>(key: &K, shard_count: usize) -> usize {
77 let hash = calculate_hash(key);
78 (hash as usize) % shard_count
79}
80
81struct Shard<K, V> {
82 data: RwLock<HashMap<K, V>>,
83}
84
85impl<K, V> Shard<K, V> {
86 fn new() -> Self {
87 Self {
88 data: RwLock::new(HashMap::new()),
89 }
90 }
91}
92
93pub struct RiShardedLock<K, V> {
94 shards: Vec<Arc<Shard<K, V>>>,
95 shard_count: usize,
96}
97
98impl<K, V> RiShardedLock<K, V>
99where
100 K: Hash + Eq + Clone + Send + Sync + 'static,
101 V: Clone + Send + Sync + 'static,
102{
103 pub fn new(shard_count: usize) -> Self {
104 let actual_shard_count = if shard_count == 0 { DEFAULT_SHARD_COUNT } else { shard_count };
105 let shards = (0..actual_shard_count)
106 .map(|_| Arc::new(Shard::new()))
107 .collect();
108
109 Self {
110 shards,
111 shard_count: actual_shard_count,
112 }
113 }
114
115 pub fn with_default_shards() -> Self {
116 Self::new(DEFAULT_SHARD_COUNT)
117 }
118
119 #[inline]
120 fn get_shard(&self, key: &K) -> &Arc<Shard<K, V>> {
121 let index = get_shard_index(key, self.shard_count);
122 &self.shards[index]
123 }
124
125 pub async fn insert(&self, key: K, value: V) -> Option<V> {
126 let shard = self.get_shard(&key);
127 let mut data = shard.data.write().await;
128 data.insert(key, value)
129 }
130
131 pub async fn get<Q>(&self, key: &Q) -> Option<V>
132 where
133 K: Borrow<Q>,
134 Q: Hash + Eq + ?Sized,
135 {
136 let shard_index = get_shard_index(key, self.shard_count);
137 let shard = &self.shards[shard_index];
138 let data = shard.data.read().await;
139 data.get(key).cloned()
140 }
141
142 pub async fn get_mut<F, R, Q>(&self, key: &Q, f: F) -> Option<R>
143 where
144 F: FnOnce(&mut V) -> R,
145 K: Borrow<Q>,
146 Q: Hash + Eq + ?Sized,
147 {
148 let shard_index = get_shard_index(key, self.shard_count);
149 let shard = &self.shards[shard_index];
150 let mut data = shard.data.write().await;
151 data.get_mut(key).map(f)
152 }
153
154 pub async fn remove<Q>(&self, key: &Q) -> Option<V>
155 where
156 K: Borrow<Q>,
157 Q: Hash + Eq + ?Sized,
158 {
159 let shard_index = get_shard_index(key, self.shard_count);
160 let shard = &self.shards[shard_index];
161 let mut data = shard.data.write().await;
162 data.remove(key)
163 }
164
165 pub async fn contains_key<Q>(&self, key: &Q) -> bool
166 where
167 K: Borrow<Q>,
168 Q: Hash + Eq + ?Sized,
169 {
170 let shard_index = get_shard_index(key, self.shard_count);
171 let shard = &self.shards[shard_index];
172 let data = shard.data.read().await;
173 data.contains_key(key)
174 }
175
176 pub async fn len(&self) -> usize {
177 let mut total = 0;
178 for shard in &self.shards {
179 let data = shard.data.read().await;
180 total += data.len();
181 }
182 total
183 }
184
185 pub async fn is_empty(&self) -> bool {
186 self.len().await == 0
187 }
188
189 pub async fn clear(&self) {
190 for shard in &self.shards {
191 let mut data = shard.data.write().await;
192 data.clear();
193 }
194 }
195
196 pub async fn retain<F>(&self, mut f: F)
197 where
198 F: FnMut(&K, &mut V) -> bool,
199 {
200 for shard in &self.shards {
201 let mut data = shard.data.write().await;
202 data.retain(|k, v| f(k, v));
203 }
204 }
205
206 pub async fn for_each<F>(&self, mut f: F)
207 where
208 F: FnMut(&K, &V),
209 {
210 for shard in &self.shards {
211 let data = shard.data.read().await;
212 for (k, v) in data.iter() {
213 f(k, v);
214 }
215 }
216 }
217
218 pub async fn for_each_mut<F>(&self, mut f: F)
219 where
220 F: FnMut(&K, &mut V),
221 {
222 for shard in &self.shards {
223 let mut data = shard.data.write().await;
224 for (k, v) in data.iter_mut() {
225 f(k, v);
226 }
227 }
228 }
229
230 pub async fn collect_all(&self) -> HashMap<K, V> {
231 let mut result = HashMap::new();
232 for shard in &self.shards {
233 let data = shard.data.read().await;
234 for (k, v) in data.iter() {
235 result.insert(k.clone(), v.clone());
236 }
237 }
238 result
239 }
240
241 pub async fn collect_where<F>(&self, mut predicate: F) -> Vec<V>
242 where
243 F: FnMut(&K, &V) -> bool,
244 {
245 let mut result = Vec::new();
246 for shard in &self.shards {
247 let data = shard.data.read().await;
248 for (k, v) in data.iter() {
249 if predicate(k, v) {
250 result.push(v.clone());
251 }
252 }
253 }
254 result
255 }
256
257 pub async fn count_where<F>(&self, mut predicate: F) -> usize
258 where
259 F: FnMut(&K, &V) -> bool,
260 {
261 let mut count = 0;
262 for shard in &self.shards {
263 let data = shard.data.read().await;
264 for (k, v) in data.iter() {
265 if predicate(k, v) {
266 count += 1;
267 }
268 }
269 }
270 count
271 }
272
273 pub async fn remove_where<F>(&self, mut predicate: F) -> usize
274 where
275 F: FnMut(&K, &V) -> bool,
276 {
277 let mut removed_count = 0;
278 for shard in &self.shards {
279 let mut data = shard.data.write().await;
280 let before_len = data.len();
281 data.retain(|k, v| !predicate(k, v));
282 removed_count += before_len - data.len();
283 }
284 removed_count
285 }
286
287 pub async fn update<F, R>(&self, key: &K, f: F) -> Option<R>
288 where
289 F: FnOnce(Option<&mut V>) -> R,
290 {
291 let shard = self.get_shard(key);
292 let mut data = shard.data.write().await;
293 Some(f(data.get_mut(key)))
294 }
295
296 pub async fn get_or_insert<F>(&self, key: K, default: F) -> V
297 where
298 F: FnOnce() -> V,
299 {
300 let shard = self.get_shard(&key);
301 let mut data = shard.data.write().await;
302 data.entry(key).or_insert_with(default).clone()
303 }
304
305 pub async fn get_or_insert_with_key<F>(&self, key: K, default: F) -> V
306 where
307 F: FnOnce(&K) -> V,
308 {
309 let shard = self.get_shard(&key);
310 let mut data = shard.data.write().await;
311 data.entry(key.clone()).or_insert_with(|| default(&key)).clone()
312 }
313
314 pub fn shard_count(&self) -> usize {
315 self.shard_count
316 }
317}
318
319impl<K, V> Default for RiShardedLock<K, V>
320where
321 K: Hash + Eq + Clone + Send + Sync + 'static,
322 V: Clone + Send + Sync + 'static,
323{
324 fn default() -> Self {
325 Self::with_default_shards()
326 }
327}
328
329#[allow(dead_code)]
330pub struct RiShardedLockReadGuard<'a, K, V> {
331 shard_index: usize,
332 guard: tokio::sync::RwLockReadGuard<'a, HashMap<K, V>>,
333}
334
335#[allow(dead_code)]
336pub struct RiShardedLockWriteGuard<'a, K, V> {
337 shard_index: usize,
338 guard: tokio::sync::RwLockWriteGuard<'a, HashMap<K, V>>,
339}
340
/// Size statistics for a sharded map.
///
/// When the `pyo3` feature is enabled the type is also exported as a Python
/// class.
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct RiShardedLockStats {
    /// Number of shards.
    pub shard_count: usize,
    /// Total number of entries across all shards.
    pub total_entries: usize,
    /// Entry count per shard.
    pub shard_distribution: Vec<usize>,
}
347
348impl RiShardedLockStats {
349 pub fn new(shard_count: usize, total_entries: usize, shard_distribution: Vec<usize>) -> Self {
350 Self {
351 shard_count,
352 total_entries,
353 shard_distribution,
354 }
355 }
356
357 pub fn calc_load_factor(&self) -> f64 {
358 if self.shard_count == 0 {
359 return 0.0;
360 }
361 self.total_entries as f64 / self.shard_count as f64
362 }
363
364 pub fn calc_distribution_variance(&self) -> f64 {
365 if self.shard_count == 0 || self.total_entries == 0 {
366 return 0.0;
367 }
368 let mean = self.total_entries as f64 / self.shard_count as f64;
369 let variance: f64 = self.shard_distribution.iter()
370 .map(|&count| {
371 let diff = count as f64 - mean;
372 diff * diff
373 })
374 .sum::<f64>() / self.shard_count as f64;
375 variance
376 }
377}
378
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl RiShardedLockStats {
    /// Number of shards.
    #[getter]
    fn shard_count(&self) -> usize {
        self.shard_count
    }

    /// Total number of entries across all shards.
    #[getter]
    fn total_entries(&self) -> usize {
        self.total_entries
    }

    /// Entry count per shard, returned as a fresh copy.
    #[getter]
    fn shard_distribution(&self) -> Vec<usize> {
        self.shard_distribution.to_vec()
    }

    /// Average number of entries per shard; forwards to `calc_load_factor`.
    fn load_factor(&self) -> f64 {
        Self::calc_load_factor(self)
    }

    /// Variance of the per-shard entry counts; forwards to
    /// `calc_distribution_variance`.
    fn distribution_variance(&self) -> f64 {
        Self::calc_distribution_variance(self)
    }
}