// dmsc/database/pool.rs

//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
//!
//! This file is part of DMSC.
//! The DMSC project belongs to the Dunimd Team.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! you may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.
17
18use crate::core::DMSCResult;
19use crate::database::{DMSCDatabase, DMSCDatabaseConfig, DMSCDBResult, DMSCDBRow};
20use dashmap::DashMap;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::sync::Arc;
23use tokio::sync::Semaphore;
24use tokio::time::{Duration, Instant};
25
26#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
27#[derive(Debug, Clone, Default)]
28pub struct DatabaseMetrics {
29    pub active_connections: u64,
30    pub idle_connections: u64,
31    pub total_connections: u64,
32    pub queries_executed: u64,
33    pub query_duration_ms: f64,
34    pub errors: u64,
35}
36
37#[derive(Clone)]
38pub struct PooledDatabase {
39    id: u32,
40    inner: Arc<dyn DMSCDatabase>,
41    pool: Arc<DMSCDatabasePool>,
42}
43
44impl PooledDatabase {
45    pub fn new(id: u32, inner: Arc<dyn DMSCDatabase>, pool: Arc<DMSCDatabasePool>) -> Self {
46        Self { id, inner, pool }
47    }
48
49    pub fn id(&self) -> u32 {
50        self.id
51    }
52
53    pub async fn execute(&self, sql: &str) -> DMSCResult<u64> {
54        self.inner.execute(sql).await
55    }
56
57    pub async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
58        self.inner.query(sql).await
59    }
60
61    pub async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
62        self.inner.query_one(sql).await
63    }
64
65    pub async fn ping(&self) -> DMSCResult<bool> {
66        self.inner.ping().await
67    }
68
69    pub fn is_connected(&self) -> bool {
70        self.inner.is_connected()
71    }
72
73    pub fn pool_metrics(&self) -> DatabaseMetrics {
74        self.pool.metrics()
75    }
76}
77
78#[async_trait::async_trait]
79impl DMSCDatabase for PooledDatabase {
80    fn database_type(&self) -> crate::database::DatabaseType {
81        self.inner.database_type()
82    }
83
84    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
85        self.inner.execute(sql).await
86    }
87
88    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
89        self.inner.query(sql).await
90    }
91
92    async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
93        self.inner.query_one(sql).await
94    }
95
96    async fn ping(&self) -> DMSCResult<bool> {
97        self.inner.ping().await
98    }
99
100    fn is_connected(&self) -> bool {
101        self.inner.is_connected()
102    }
103
104    async fn close(&self) -> DMSCResult<()> {
105        self.pool.close().await
106    }
107
108    async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
109        self.inner.batch_execute(sql, params).await
110    }
111
112    async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
113        self.inner.batch_query(sql, params).await
114    }
115
116    async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
117        self.inner.execute_with_params(sql, params).await
118    }
119
120    async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
121        self.inner.query_with_params(sql, params).await
122    }
123
124    async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
125        self.inner.transaction().await
126    }
127}
128
129struct PoolConnection {
130    db: Arc<dyn DMSCDatabase>,
131    acquired_at: Instant,
132}
133
134#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
135pub struct DMSCDatabasePool {
136    config: DMSCDatabaseConfig,
137    connections: Arc<DashMap<u32, PoolConnection>>,
138    available: Arc<DashMap<u32, PoolConnection>>,
139    connection_ids: Arc<AtomicU64>,
140    semaphore: Arc<Semaphore>,
141    max_idle_time: Duration,
142    max_lifetime: Duration,
143    idle_connections: Arc<AtomicU64>,
144    active_connections: Arc<AtomicU64>,
145    total_connections: Arc<AtomicU64>,
146    queries_executed: Arc<AtomicU64>,
147    errors: Arc<AtomicU64>,
148}
149
150impl DMSCDatabasePool {
151    pub async fn new(config: DMSCDatabaseConfig) -> DMSCResult<Self> {
152        let pool = Self {
153            config: config.clone(),
154            connections: Arc::new(DashMap::new()),
155            available: Arc::new(DashMap::new()),
156            connection_ids: Arc::new(AtomicU64::new(0)),
157            semaphore: Arc::new(Semaphore::new(config.max_connections as usize)),
158            max_idle_time: Duration::from_secs(config.idle_timeout_secs),
159            max_lifetime: Duration::from_secs(config.max_lifetime_secs),
160            idle_connections: Arc::new(AtomicU64::new(0)),
161            active_connections: Arc::new(AtomicU64::new(0)),
162            total_connections: Arc::new(AtomicU64::new(0)),
163            queries_executed: Arc::new(AtomicU64::new(0)),
164            errors: Arc::new(AtomicU64::new(0)),
165        };
166
167        for _ in 0..config.min_idle_connections {
168            if let Ok(conn) = pool.create_connection().await {
169                let id = pool.connection_ids.fetch_add(1, Ordering::SeqCst) as u32;
170                pool.available.insert(id, PoolConnection { db: conn, acquired_at: Instant::now() });
171                pool.idle_connections.fetch_add(1, Ordering::SeqCst);
172                pool.total_connections.fetch_add(1, Ordering::SeqCst);
173            }
174        }
175
176        Ok(pool)
177    }
178
179    async fn create_connection(&self) -> DMSCResult<Arc<dyn DMSCDatabase>> {
180        match self.config.database_type {
181            #[cfg(feature = "postgres")]
182            crate::database::DatabaseType::Postgres => {
183                let connection_string = self.config.connection_string();
184                let db = crate::database::postgres::PostgresDatabase::new(&connection_string, self.config.clone()).await
185                    .map_err(|e| crate::core::DMSCError::Config(e.to_string()))?;
186                Ok(Arc::new(db) as Arc<dyn DMSCDatabase>)
187            }
188            #[cfg(feature = "mysql")]
189            crate::database::DatabaseType::MySQL => {
190                let connection_string = self.config.connection_string();
191                let db = crate::database::mysql::MySQLDatabase::new(&connection_string, self.config.clone()).await
192                    .map_err(|e| crate::core::DMSCError::Config(e.to_string()))?;
193                Ok(Arc::new(db) as Arc<dyn DMSCDatabase>)
194            }
195            #[cfg(feature = "sqlite")]
196            crate::database::DatabaseType::SQLite => {
197                let url = format!("sqlite:{}", self.config.database);
198                let db = tokio::runtime::Handle::current().block_on(
199                    crate::database::sqlite::SQLiteDatabase::new(&url, self.config.clone())
200                );
201                match db {
202                    Ok(db) => Ok(Arc::new(db) as Arc<dyn DMSCDatabase>),
203                    Err(e) => Err(crate::core::DMSCError::Config(e.to_string())),
204                }
205            }
206            _ => Err(crate::core::DMSCError::Config("Unsupported database type".to_string())),
207        }
208    }
209
210    pub async fn get(&self) -> DMSCResult<PooledDatabase> {
211        let _permit = self.semaphore.acquire().await.map_err(|e| crate::core::DMSCError::Config(e.to_string()))?;
212
213        let mut reused_db = None;
214        let mut reused_id = None;
215
216        let now = Instant::now();
217        
218        for entry in self.available.iter() {
219            let id = *entry.key();
220            let conn = entry.value();
221            if now.duration_since(conn.acquired_at) > self.max_idle_time || now.duration_since(conn.acquired_at) > self.max_lifetime {
222                self.available.remove(&id);
223                let _ = conn.db.close().await;
224            } else {
225                reused_db = Some(conn.db.clone());
226                reused_id = Some(id);
227                self.available.remove(&id);
228                self.idle_connections.fetch_sub(1, Ordering::SeqCst);
229                self.active_connections.fetch_add(1, Ordering::SeqCst);
230                break;
231            }
232        }
233
234        let (db, id) = if let Some((existing_db, existing_id)) = reused_db.zip(reused_id) {
235            (existing_db, existing_id)
236        } else {
237            match self.create_connection().await {
238                Ok(new_conn) => {
239                    let id = self.connection_ids.fetch_add(1, Ordering::SeqCst) as u32;
240                    self.total_connections.fetch_add(1, Ordering::SeqCst);
241                    self.active_connections.fetch_add(1, Ordering::SeqCst);
242                    (new_conn, id)
243                }
244                Err(e) => {
245                    self.errors.fetch_add(1, Ordering::SeqCst);
246                    return Err(e);
247                }
248            }
249        };
250
251        Ok(PooledDatabase::new(id, db, Arc::new(self.clone())))
252    }
253
254    pub async fn release(&self, db: PooledDatabase) {
255        self.active_connections.fetch_sub(1, Ordering::SeqCst);
256        self.idle_connections.fetch_add(1, Ordering::SeqCst);
257        
258        self.available.insert(db.id(), PoolConnection { 
259            db: db.inner.clone(),
260            acquired_at: Instant::now(),
261        });
262    }
263
264    pub async fn close(&self) -> DMSCResult<()> {
265        self.semaphore.close();
266        for entry in self.connections.iter() {
267            let _ = entry.value().db.close().await;
268        }
269        self.connections.clear();
270        self.available.clear();
271        Ok(())
272    }
273
274    pub fn metrics(&self) -> DatabaseMetrics {
275        DatabaseMetrics {
276            active_connections: self.active_connections.load(Ordering::SeqCst),
277            idle_connections: self.idle_connections.load(Ordering::SeqCst),
278            total_connections: self.total_connections.load(Ordering::SeqCst),
279            queries_executed: self.queries_executed.load(Ordering::SeqCst),
280            query_duration_ms: 0.0,
281            errors: self.errors.load(Ordering::SeqCst),
282        }
283    }
284}
285
286#[cfg(feature = "pyo3")]
287#[pyo3::prelude::pymethods]
288impl DMSCDatabasePool {
289    #[new]
290    fn py_new(config: DMSCDatabaseConfig) -> Self {
291        let pool = Self {
292            config: config.clone(),
293            connections: Arc::new(DashMap::new()),
294            available: Arc::new(DashMap::new()),
295            connection_ids: Arc::new(AtomicU64::new(0)),
296            semaphore: Arc::new(Semaphore::new(config.max_connections as usize)),
297            max_idle_time: Duration::from_secs(config.idle_timeout_secs),
298            max_lifetime: Duration::from_secs(config.max_lifetime_secs),
299            idle_connections: Arc::new(AtomicU64::new(0)),
300            active_connections: Arc::new(AtomicU64::new(0)),
301            total_connections: Arc::new(AtomicU64::new(0)),
302            queries_executed: Arc::new(AtomicU64::new(0)),
303            errors: Arc::new(AtomicU64::new(0)),
304        };
305        pool
306    }
307
308    fn status(&self) -> String {
309        format!(
310            "Pool status - Active: {}, Idle: {}, Total: {}, Queries: {}, Errors: {}",
311            self.active_connections.load(Ordering::SeqCst),
312            self.idle_connections.load(Ordering::SeqCst),
313            self.total_connections.load(Ordering::SeqCst),
314            self.queries_executed.load(Ordering::SeqCst),
315            self.errors.load(Ordering::SeqCst)
316        )
317    }
318
319    fn get_config(&self) -> DMSCDatabaseConfig {
320        self.config.clone()
321    }
322}
323
324impl Clone for DMSCDatabasePool {
325    fn clone(&self) -> Self {
326        Self {
327            config: self.config.clone(),
328            connections: self.connections.clone(),
329            available: self.available.clone(),
330            connection_ids: self.connection_ids.clone(),
331            semaphore: self.semaphore.clone(),
332            max_idle_time: self.max_idle_time,
333            max_lifetime: self.max_lifetime,
334            idle_connections: self.idle_connections.clone(),
335            active_connections: self.active_connections.clone(),
336            total_connections: self.total_connections.clone(),
337            queries_executed: self.queries_executed.clone(),
338            errors: self.errors.clone(),
339        }
340    }
341}