dmsc/database/
mysql.rs

1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18use async_trait::async_trait;
19use sqlx::mysql::{MySqlPool, MySqlRow};
20use sqlx::{Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26    DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27    DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct MySQLDatabase {
34    pool: MySqlPool,
35    config: DMSCDatabaseConfig,
36}
37
38impl MySQLDatabase {
39    pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40        let pool = MySqlPool::connect(database_url)
41            .await
42            .map_err(|e| DMSCError::Other(format!("Failed to connect to MySQL: {}", e)))?;
43        
44        Ok(Self { pool, config })
45    }
46
47    fn row_to_dmsc_row(row: &MySqlRow) -> DMSCDBRow {
48        let columns: Vec<String> = (0..row.len())
49            .map(|i| row.column(i).name().to_string())
50            .collect();
51
52        let values: Vec<Option<serde_json::Value>> = (0..row.len())
53            .map(|idx| Self::value_to_json(row, idx))
54            .collect();
55
56        DMSCDBRow { columns, values }
57    }
58
59    fn value_to_json(row: &MySqlRow, idx: usize) -> Option<serde_json::Value> {
60        match row.try_get::<i64, _>(idx) {
61            Ok(v) => Some(serde_json::json!(v)),
62            Err(_) => {
63                match row.try_get::<f64, _>(idx) {
64                    Ok(v) => Some(serde_json::json!(v)),
65                    Err(_) => {
66                        match row.try_get::<String, _>(idx) {
67                            Ok(v) => Some(serde_json::json!(v)),
68                            Err(_) => {
69                                match row.try_get::<bool, _>(idx) {
70                                    Ok(v) => Some(serde_json::json!(v)),
71                                    Err(_) => {
72                                        match row.try_get::<Vec<u8>, _>(idx) {
73                                            Ok(v) => Some(serde_json::json!(v)),
74                                            Err(_) => None,
75                                        }
76                                    }
77                                }
78                            }
79                        }
80                    }
81                }
82            }
83        }
84    }
85}
86
87#[async_trait]
88impl DMSCDatabase for MySQLDatabase {
89    fn database_type(&self) -> DatabaseType {
90        DatabaseType::MySQL
91    }
92
93    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94        let result = sqlx::query::<sqlx::MySql>(sql)
95            .execute(&self.pool)
96            .await
97            .map_err(|e| DMSCError::Other(format!("MySQL execute error: {}", e)))?;
98        Ok(result.rows_affected())
99    }
100
101    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102        let rows = sqlx::query::<sqlx::MySql>(sql)
103            .fetch_all(&self.pool)
104            .await
105            .map_err(|e| DMSCError::Other(format!("MySQL query error: {}", e)))?;
106
107        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108            .map(|row| Self::row_to_dmsc_row(row))
109            .collect();
110
111        Ok(DMSCDBResult::with_rows(dmsc_rows))
112    }
113
114    async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115        let row = sqlx::query::<sqlx::MySql>(sql)
116            .fetch_optional(&self.pool)
117            .await
118            .map_err(|e| DMSCError::Other(format!("MySQL query_one error: {}", e)))?;
119
120        Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121    }
122
123    async fn ping(&self) -> DMSCResult<bool> {
124        sqlx::query::<sqlx::MySql>("SELECT 1")
125            .fetch_one(&self.pool)
126            .await
127            .map(|_| true)
128            .map_err(|e| DMSCError::Other(format!("MySQL ping error: {}", e)))
129    }
130
131    fn is_connected(&self) -> bool {
132        !self.pool.is_closed()
133    }
134
135    async fn close(&self) -> DMSCResult<()> {
136        self.pool.close().await;
137        Ok(())
138    }
139
140    async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141        let mut results = Vec::with_capacity(params.len());
142        for param_set in params {
143            let result = self.execute_with_params(sql, param_set).await?;
144            results.push(result);
145        }
146        Ok(results)
147    }
148
149    async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150        let mut results = Vec::with_capacity(params.len());
151        for param_set in params {
152            let result = self.query_with_params(sql, param_set).await?;
153            results.push(result);
154        }
155        Ok(results)
156    }
157
158    async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159        let mut query = sqlx::query::<sqlx::MySql>(sql);
160        
161        for param in params {
162            let param_str = param.to_string();
163            query = query.bind(param_str);
164        }
165        
166        let result = query
167            .execute(&self.pool)
168            .await
169            .map_err(|e| DMSCError::Other(format!("MySQL execute_with_params error: {}", e)))?;
170        Ok(result.rows_affected())
171    }
172
173    async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
174        let mut query = sqlx::query::<sqlx::MySql>(sql);
175        
176        for param in params {
177            let param_str = param.to_string();
178            query = query.bind(param_str);
179        }
180        
181        let rows = query
182            .fetch_all(&self.pool)
183            .await
184            .map_err(|e| DMSCError::Other(format!("MySQL query_with_params error: {}", e)))?;
185
186        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
187            .map(|row| Self::row_to_dmsc_row(row))
188            .collect();
189
190        Ok(DMSCDBResult::with_rows(dmsc_rows))
191    }
192
193    async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
194        let tx = self.pool.begin().await
195            .map_err(|e| DMSCError::Other(format!("MySQL transaction begin error: {}", e)))?;
196
197        Ok(Box::new(MySQLTransaction::new(tx)))
198    }
199}
200
201struct MySQLTransaction {
202    tx: Arc<Mutex<Option<sqlx::Transaction<'static, sqlx::MySql>>>>,
203}
204
205impl MySQLTransaction {
206    pub fn new(tx: sqlx::Transaction<'static, sqlx::MySql>) -> Self {
207        Self {
208            tx: Arc::new(Mutex::new(Some(tx))),
209        }
210    }
211}
212
213#[async_trait::async_trait]
214impl crate::database::DMSCDatabaseTransaction for MySQLTransaction {
215    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
216        let mut guard = self.tx.lock().await;
217        let tx = guard.as_mut()
218            .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
219        
220        let result = sqlx::query::<sqlx::MySql>(sql)
221            .execute(&mut **tx)
222            .await
223            .map_err(|e| DMSCError::Other(format!("MySQL transaction execute error: {}", e)))?;
224        Ok(result.rows_affected())
225    }
226
227    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
228        let mut guard = self.tx.lock().await;
229        let tx = guard.as_mut()
230            .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
231        
232        let rows = sqlx::query::<sqlx::MySql>(sql)
233            .fetch_all(&mut **tx)
234            .await
235            .map_err(|e| DMSCError::Other(format!("MySQL transaction query error: {}", e)))?;
236
237        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
238            .map(|row| MySQLDatabase::row_to_dmsc_row(row))
239            .collect();
240
241        Ok(DMSCDBResult::with_rows(dmsc_rows))
242    }
243
244    async fn commit(&self) -> DMSCResult<()> {
245        let mut guard = self.tx.lock().await;
246        let tx = guard.take()
247            .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
248        
249        tx.commit().await
250            .map_err(|e| DMSCError::Other(format!("MySQL transaction commit error: {}", e)))
251    }
252
253    async fn rollback(&self) -> DMSCResult<()> {
254        let mut guard = self.tx.lock().await;
255        let tx = guard.take()
256            .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
257        
258        tx.rollback().await
259            .map_err(|e| DMSCError::Other(format!("MySQL transaction rollback error: {}", e)))
260    }
261
262    async fn close(&self) -> DMSCResult<()> {
263        self.rollback().await
264    }
265}
266
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl MySQLDatabase {
    /// Connects to MySQL from Python using `conn_string`.
    ///
    /// `max_connections` caps the pool size. A value of `0` keeps sqlx's default,
    /// which matches the previous behavior where the argument was ignored.
    ///
    /// # Errors
    ///
    /// Raises `RuntimeError` if the shared runtime cannot be created or the
    /// connection fails.
    #[staticmethod]
    pub fn from_connection_string(conn_string: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        let rt = mysql_sync_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to create Tokio runtime: {}",
                e
            ))
        })?;

        rt.block_on(async {
            let mut options = sqlx::mysql::MySqlPoolOptions::new();
            if max_connections > 0 {
                options = options.max_connections(max_connections);
            }

            let pool = options.connect(conn_string).await.map_err(|e| {
                pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                    "Failed to connect to MySQL: {}",
                    e
                ))
            })?;

            // NOTE(review): host/port/database are placeholders and are not derived
            // from `conn_string` — confirm nothing downstream relies on them.
            let db_config = DMSCDatabaseConfig::mysql()
                .host("localhost")
                .port(3306)
                .database("mysql")
                .max_connections(max_connections)
                .min_idle_connections(1)
                .build();

            Ok(Self { pool, config: db_config })
        })
    }

    /// Blocking wrapper around `DMSCDatabase::execute` for Python callers.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        let rt = mysql_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.execute(sql))
    }

    /// Blocking wrapper around `DMSCDatabase::query` for Python callers.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        let rt = mysql_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.query(sql))
    }

    /// Blocking wrapper around `DMSCDatabase::ping` for Python callers.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        let rt = mysql_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.ping())
    }
}

/// Returns the process-wide Tokio runtime shared by the synchronous Python entry points.
///
/// A single long-lived runtime is required for two reasons:
/// - Pool connections and background tasks live on the runtime that created them.
///   A per-call runtime shut down as soon as the call returned, leaving the pool
///   with connections whose I/O driver was gone.
/// - Building a runtime on every call is expensive.
#[cfg(feature = "pyo3")]
fn mysql_sync_runtime() -> std::io::Result<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    // Two threads may race to initialize; the loser's runtime is simply dropped.
    let rt = tokio::runtime::Runtime::new()?;
    Ok(RUNTIME.get_or_init(|| rt))
}