dmsc/database/
sqlite.rs

//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
//!
//! This file is part of DMSC.
//! The DMSC project belongs to the Dunimd Team.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! You may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//!     http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.
17
18use async_trait::async_trait;
19use sqlx::sqlite::{SqlitePool, SqliteRow};
20use sqlx::{Transaction, Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26    DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27    DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct SQLiteDatabase {
34    pool: SqlitePool,
35    config: DMSCDatabaseConfig,
36}
37
38impl SQLiteDatabase {
39    pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40        let pool = SqlitePool::connect(database_url)
41            .await
42            .map_err(|e| DMSCError::Other(format!("Failed to connect to SQLite database: {}", e)))?;
43        
44        Ok(Self { pool, config })
45    }
46
47    fn row_to_dmsc_row(row: &SqliteRow) -> DMSCDBRow {
48        let columns: Vec<String> = (0..row.len())
49            .map(|i| row.column(i).name().to_string())
50            .collect();
51
52        let values: Vec<Option<serde_json::Value>> = (0..row.len())
53            .map(|idx| Self::value_to_json(row, idx))
54            .collect();
55
56        DMSCDBRow { columns, values }
57    }
58
59    fn value_to_json(row: &SqliteRow, idx: usize) -> Option<serde_json::Value> {
60        match row.try_get::<i64, _>(idx) {
61            Ok(v) => Some(serde_json::json!(v)),
62            Err(_) => {
63                match row.try_get::<f64, _>(idx) {
64                    Ok(v) => Some(serde_json::json!(v)),
65                    Err(_) => {
66                        match row.try_get::<String, _>(idx) {
67                            Ok(v) => Some(serde_json::json!(v)),
68                            Err(_) => {
69                                match row.try_get::<bool, _>(idx) {
70                                    Ok(v) => Some(serde_json::json!(v)),
71                                    Err(_) => {
72                                        match row.try_get::<Vec<u8>, _>(idx) {
73                                            Ok(v) => Some(serde_json::json!(v)),
74                                            Err(_) => None,
75                                        }
76                                    }
77                                }
78                            }
79                        }
80                    }
81                }
82            }
83        }
84    }
85}
86
87#[async_trait]
88impl DMSCDatabase for SQLiteDatabase {
89    fn database_type(&self) -> DatabaseType {
90        DatabaseType::SQLite
91    }
92
93    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94        let result = sqlx::query::<sqlx::Sqlite>(sql)
95            .execute(&self.pool)
96            .await
97            .map_err(|e| DMSCError::Other(format!("SQLite execute error: {}", e)))?;
98        Ok(result.rows_affected())
99    }
100
101    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102        let rows = sqlx::query::<sqlx::Sqlite>(sql)
103            .fetch_all(&self.pool)
104            .await
105            .map_err(|e| DMSCError::Other(format!("SQLite query error: {}", e)))?;
106
107        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108            .map(|row| Self::row_to_dmsc_row(row))
109            .collect();
110
111        Ok(DMSCDBResult::with_rows(dmsc_rows))
112    }
113
114    async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115        let row = sqlx::query::<sqlx::Sqlite>(sql)
116            .fetch_optional(&self.pool)
117            .await
118            .map_err(|e| DMSCError::Other(format!("SQLite query_one error: {}", e)))?;
119
120        Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121    }
122
123    async fn ping(&self) -> DMSCResult<bool> {
124        sqlx::query::<sqlx::Sqlite>("SELECT 1")
125            .fetch_one(&self.pool)
126            .await
127            .map(|_| true)
128            .map_err(|e| DMSCError::Other(format!("SQLite ping error: {}", e)))
129    }
130
131    fn is_connected(&self) -> bool {
132        !self.pool.is_closed()
133    }
134
135    async fn close(&self) -> DMSCResult<()> {
136        self.pool.close().await;
137        Ok(())
138    }
139
140    async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141        let mut results = Vec::with_capacity(params.len());
142        for param_set in params {
143            let result = self.execute_with_params(sql, param_set).await?;
144            results.push(result);
145        }
146        Ok(results)
147    }
148
149    async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150        let mut results = Vec::with_capacity(params.len());
151        for param_set in params {
152            let result = self.query_with_params(sql, param_set).await?;
153            results.push(result);
154        }
155        Ok(results)
156    }
157
158    async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159        let mut query = sqlx::query::<sqlx::Sqlite>(sql);
160        
161        for param in params {
162            let param_str = param.to_string();
163            query = query.bind(param_str);
164        }
165        
166        let result = query
167            .execute(&self.pool)
168            .await
169            .map_err(|e| DMSCError::Other(format!("SQLite execute_with_params error: {}", e)))?;
170        Ok(result.rows_affected())
171    }
172
173    async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
174        let mut query = sqlx::query::<sqlx::Sqlite>(sql);
175        
176        for param in params {
177            let param_str = param.to_string();
178            query = query.bind(param_str);
179        }
180        
181        let rows = query
182            .fetch_all(&self.pool)
183            .await
184            .map_err(|e| DMSCError::Other(format!("SQLite query_with_params error: {}", e)))?;
185
186        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
187            .map(|row| Self::row_to_dmsc_row(row))
188            .collect();
189
190        Ok(DMSCDBResult::with_rows(dmsc_rows))
191    }
192
193    async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
194        let tx = self.pool.begin().await
195            .map_err(|e| DMSCError::Other(format!("SQLite transaction begin error: {}", e)))?;
196
197        Ok(Box::new(SQLiteTransaction::new(tx)))
198    }
199}
200
201struct SQLiteTransaction {
202    tx: Arc<Mutex<Option<Transaction<'static, sqlx::Sqlite>>>>,
203}
204
205impl SQLiteTransaction {
206    pub fn new(tx: Transaction<'static, sqlx::Sqlite>) -> Self {
207        Self {
208            tx: Arc::new(Mutex::new(Some(tx))),
209        }
210    }
211}
212
213#[async_trait::async_trait]
214impl crate::database::DMSCDatabaseTransaction for SQLiteTransaction {
215    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
216        let mut guard = self.tx.lock().await;
217        let tx = guard.as_mut()
218            .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
219        
220        let result = sqlx::query::<sqlx::Sqlite>(sql)
221            .execute(&mut **tx)
222            .await
223            .map_err(|e| DMSCError::Other(format!("SQLite transaction execute error: {}", e)))?;
224        Ok(result.rows_affected())
225    }
226
227    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
228        let mut guard = self.tx.lock().await;
229        let tx = guard.as_mut()
230            .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
231        
232        let rows = sqlx::query::<sqlx::Sqlite>(sql)
233            .fetch_all(&mut **tx)
234            .await
235            .map_err(|e| DMSCError::Other(format!("SQLite transaction query error: {}", e)))?;
236
237        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
238            .map(|row| SQLiteDatabase::row_to_dmsc_row(row))
239            .collect();
240
241        Ok(DMSCDBResult::with_rows(dmsc_rows))
242    }
243
244    async fn commit(&self) -> DMSCResult<()> {
245        let mut guard = self.tx.lock().await;
246        let tx = guard.take()
247            .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
248        
249        tx.commit().await
250            .map_err(|e| DMSCError::Other(format!("SQLite transaction commit error: {}", e)))
251    }
252
253    async fn rollback(&self) -> DMSCResult<()> {
254        let mut guard = self.tx.lock().await;
255        let tx = guard.take()
256            .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
257        
258        tx.rollback().await
259            .map_err(|e| DMSCError::Other(format!("SQLite transaction rollback error: {}", e)))
260    }
261
262    async fn close(&self) -> DMSCResult<()> {
263        self.rollback().await
264    }
265}
266
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl SQLiteDatabase {
    /// Python constructor: connects to the SQLite database at `path` (URL `sqlite:<path>`).
    ///
    /// The pool is capped at `max_connections`, and the same value is recorded in the
    /// stored config.
    ///
    /// # Errors
    /// - Raises `ValueError` if `max_connections` is 0.
    /// - Raises `RuntimeError` if the shared Tokio runtime cannot be created or the
    ///   connection fails.
    #[staticmethod]
    pub fn from_path(path: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        // A pool must be allowed at least one connection to be usable at all.
        if max_connections == 0 {
            return Err(pyo3::PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "max_connections must be at least 1",
            ));
        }

        let rt = sqlite_sync_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to create Tokio runtime: {}",
                e
            ))
        })?;

        let url = format!("sqlite:{}", path);

        rt.block_on(async {
            // Apply the limit to the pool itself. `SqlitePool::connect` would use sqlx's
            // default size and silently ignore `max_connections`.
            let pool = sqlx::sqlite::SqlitePoolOptions::new()
                .max_connections(max_connections)
                .connect(&url)
                .await
                .map_err(|e| {
                    pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                        "Failed to connect to SQLite database: {}",
                        e
                    ))
                })?;

            let db_config = DMSCDatabaseConfig::sqlite(path)
                .max_connections(max_connections)
                .build();

            Ok(Self { pool, config: db_config })
        })
    }

    /// Blocking wrapper around `execute` for Python callers.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        sqlite_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.execute(sql))
    }

    /// Blocking wrapper around `query` for Python callers.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        sqlite_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.query(sql))
    }

    /// Blocking wrapper around `ping` for Python callers.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        sqlite_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.ping())
    }
}

/// Returns the process-wide Tokio runtime used by the blocking Python entry points.
///
/// Building a runtime per call creates a new multi-threaded runtime every time. It also
/// meant the pool built in `from_path` outlived the runtime it was created on. One
/// lazily-created runtime shared by every call avoids both problems.
#[cfg(feature = "pyo3")]
fn sqlite_sync_runtime() -> std::io::Result<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    // `OnceLock::get_or_try_init` is unstable, so build first and then publish. If another
    // thread publishes first, the runtime built here is dropped unused.
    let rt = tokio::runtime::Runtime::new()?;
    Ok(RUNTIME.get_or_init(|| rt))
}