1use async_trait::async_trait;
19use sqlx::sqlite::{SqlitePool, SqliteRow};
20use sqlx::{Transaction, Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26 DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27 DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct SQLiteDatabase {
34 pool: SqlitePool,
35 config: DMSCDatabaseConfig,
36}
37
38impl SQLiteDatabase {
39 pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40 let pool = SqlitePool::connect(database_url)
41 .await
42 .map_err(|e| DMSCError::Other(format!("Failed to connect to SQLite database: {}", e)))?;
43
44 Ok(Self { pool, config })
45 }
46
47 fn row_to_dmsc_row(row: &SqliteRow) -> DMSCDBRow {
48 let columns: Vec<String> = (0..row.len())
49 .map(|i| row.column(i).name().to_string())
50 .collect();
51
52 let values: Vec<Option<serde_json::Value>> = (0..row.len())
53 .map(|idx| Self::value_to_json(row, idx))
54 .collect();
55
56 DMSCDBRow { columns, values }
57 }
58
59 fn value_to_json(row: &SqliteRow, idx: usize) -> Option<serde_json::Value> {
60 match row.try_get::<i64, _>(idx) {
61 Ok(v) => Some(serde_json::json!(v)),
62 Err(_) => {
63 match row.try_get::<f64, _>(idx) {
64 Ok(v) => Some(serde_json::json!(v)),
65 Err(_) => {
66 match row.try_get::<String, _>(idx) {
67 Ok(v) => Some(serde_json::json!(v)),
68 Err(_) => {
69 match row.try_get::<bool, _>(idx) {
70 Ok(v) => Some(serde_json::json!(v)),
71 Err(_) => {
72 match row.try_get::<Vec<u8>, _>(idx) {
73 Ok(v) => Some(serde_json::json!(v)),
74 Err(_) => None,
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82 }
83 }
84 }
85}
86
87#[async_trait]
88impl DMSCDatabase for SQLiteDatabase {
89 fn database_type(&self) -> DatabaseType {
90 DatabaseType::SQLite
91 }
92
93 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94 let result = sqlx::query::<sqlx::Sqlite>(sql)
95 .execute(&self.pool)
96 .await
97 .map_err(|e| DMSCError::Other(format!("SQLite execute error: {}", e)))?;
98 Ok(result.rows_affected())
99 }
100
101 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102 let rows = sqlx::query::<sqlx::Sqlite>(sql)
103 .fetch_all(&self.pool)
104 .await
105 .map_err(|e| DMSCError::Other(format!("SQLite query error: {}", e)))?;
106
107 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108 .map(|row| Self::row_to_dmsc_row(row))
109 .collect();
110
111 Ok(DMSCDBResult::with_rows(dmsc_rows))
112 }
113
114 async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115 let row = sqlx::query::<sqlx::Sqlite>(sql)
116 .fetch_optional(&self.pool)
117 .await
118 .map_err(|e| DMSCError::Other(format!("SQLite query_one error: {}", e)))?;
119
120 Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121 }
122
123 async fn ping(&self) -> DMSCResult<bool> {
124 sqlx::query::<sqlx::Sqlite>("SELECT 1")
125 .fetch_one(&self.pool)
126 .await
127 .map(|_| true)
128 .map_err(|e| DMSCError::Other(format!("SQLite ping error: {}", e)))
129 }
130
131 fn is_connected(&self) -> bool {
132 !self.pool.is_closed()
133 }
134
135 async fn close(&self) -> DMSCResult<()> {
136 self.pool.close().await;
137 Ok(())
138 }
139
140 async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141 let mut results = Vec::with_capacity(params.len());
142 for param_set in params {
143 let result = self.execute_with_params(sql, param_set).await?;
144 results.push(result);
145 }
146 Ok(results)
147 }
148
149 async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150 let mut results = Vec::with_capacity(params.len());
151 for param_set in params {
152 let result = self.query_with_params(sql, param_set).await?;
153 results.push(result);
154 }
155 Ok(results)
156 }
157
158 async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159 let mut query = sqlx::query::<sqlx::Sqlite>(sql);
160
161 for param in params {
162 let param_str = param.to_string();
163 query = query.bind(param_str);
164 }
165
166 let result = query
167 .execute(&self.pool)
168 .await
169 .map_err(|e| DMSCError::Other(format!("SQLite execute_with_params error: {}", e)))?;
170 Ok(result.rows_affected())
171 }
172
173 async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
174 let mut query = sqlx::query::<sqlx::Sqlite>(sql);
175
176 for param in params {
177 let param_str = param.to_string();
178 query = query.bind(param_str);
179 }
180
181 let rows = query
182 .fetch_all(&self.pool)
183 .await
184 .map_err(|e| DMSCError::Other(format!("SQLite query_with_params error: {}", e)))?;
185
186 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
187 .map(|row| Self::row_to_dmsc_row(row))
188 .collect();
189
190 Ok(DMSCDBResult::with_rows(dmsc_rows))
191 }
192
193 async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
194 let tx = self.pool.begin().await
195 .map_err(|e| DMSCError::Other(format!("SQLite transaction begin error: {}", e)))?;
196
197 Ok(Box::new(SQLiteTransaction::new(tx)))
198 }
199}
200
201struct SQLiteTransaction {
202 tx: Arc<Mutex<Option<Transaction<'static, sqlx::Sqlite>>>>,
203}
204
205impl SQLiteTransaction {
206 pub fn new(tx: Transaction<'static, sqlx::Sqlite>) -> Self {
207 Self {
208 tx: Arc::new(Mutex::new(Some(tx))),
209 }
210 }
211}
212
213#[async_trait::async_trait]
214impl crate::database::DMSCDatabaseTransaction for SQLiteTransaction {
215 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
216 let mut guard = self.tx.lock().await;
217 let tx = guard.as_mut()
218 .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
219
220 let result = sqlx::query::<sqlx::Sqlite>(sql)
221 .execute(&mut **tx)
222 .await
223 .map_err(|e| DMSCError::Other(format!("SQLite transaction execute error: {}", e)))?;
224 Ok(result.rows_affected())
225 }
226
227 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
228 let mut guard = self.tx.lock().await;
229 let tx = guard.as_mut()
230 .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
231
232 let rows = sqlx::query::<sqlx::Sqlite>(sql)
233 .fetch_all(&mut **tx)
234 .await
235 .map_err(|e| DMSCError::Other(format!("SQLite transaction query error: {}", e)))?;
236
237 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
238 .map(|row| SQLiteDatabase::row_to_dmsc_row(row))
239 .collect();
240
241 Ok(DMSCDBResult::with_rows(dmsc_rows))
242 }
243
244 async fn commit(&self) -> DMSCResult<()> {
245 let mut guard = self.tx.lock().await;
246 let tx = guard.take()
247 .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
248
249 tx.commit().await
250 .map_err(|e| DMSCError::Other(format!("SQLite transaction commit error: {}", e)))
251 }
252
253 async fn rollback(&self) -> DMSCResult<()> {
254 let mut guard = self.tx.lock().await;
255 let tx = guard.take()
256 .ok_or_else(|| DMSCError::Other("SQLite transaction already closed".to_string()))?;
257
258 tx.rollback().await
259 .map_err(|e| DMSCError::Other(format!("SQLite transaction rollback error: {}", e)))
260 }
261
262 async fn close(&self) -> DMSCResult<()> {
263 self.rollback().await
264 }
265}
266
/// Returns the process-wide Tokio runtime used by the synchronous Python-facing wrappers.
///
/// The pool is created on this runtime, and every later `*_sync` call also runs on it.
/// `sqlx` pools may spawn background tasks on the runtime they run under. Building and
/// dropping a runtime per call could cancel those tasks, and it rebuilds a thread pool
/// every time.
///
/// # Errors
/// Returns the I/O error from `Runtime::new` if the runtime cannot be created.
#[cfg(feature = "pyo3")]
fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, std::io::Error> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    let rt = tokio::runtime::Runtime::new()?;
    // If another thread initialised the cell first, its runtime is kept and `rt` is dropped.
    Ok(RUNTIME.get_or_init(|| rt))
}

#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl SQLiteDatabase {
    /// Opens the SQLite database at `path` with a pool of at most `max_connections` connections.
    ///
    /// A value of 0 is raised to 1, since a pool with no connections could never serve a query.
    ///
    /// # Errors
    /// Raises `RuntimeError` if the runtime cannot be created or the connection fails.
    #[staticmethod]
    pub fn from_path(path: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        let rt = shared_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to create Tokio runtime: {}",
                e
            ))
        })?;

        let url = format!("sqlite:{}", path);
        let pool_size = max_connections.max(1);

        rt.block_on(async {
            // Apply the requested limit to the real pool, not just the stored config.
            let pool = sqlx::sqlite::SqlitePoolOptions::new()
                .max_connections(pool_size)
                .connect(&url)
                .await
                .map_err(|e| {
                    pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                        "Failed to connect to SQLite database: {}",
                        e
                    ))
                })?;

            let db_config = DMSCDatabaseConfig::sqlite(path)
                .max_connections(pool_size)
                .build();

            Ok(Self { pool, config: db_config })
        })
    }

    /// Blocking wrapper around `execute`. Returns the number of affected rows.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        let rt = shared_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.execute(sql))
    }

    /// Blocking wrapper around `query`.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        let rt = shared_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.query(sql))
    }

    /// Blocking wrapper around `ping`.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        let rt = shared_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.ping())
    }
}