1use async_trait::async_trait;
19use sqlx::mysql::{MySqlPool, MySqlRow};
20use sqlx::{Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26 DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27 DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct MySQLDatabase {
34 pool: MySqlPool,
35 config: DMSCDatabaseConfig,
36}
37
38impl MySQLDatabase {
39 pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40 let pool = MySqlPool::connect(database_url)
41 .await
42 .map_err(|e| DMSCError::Other(format!("Failed to connect to MySQL: {}", e)))?;
43
44 Ok(Self { pool, config })
45 }
46
47 fn row_to_dmsc_row(row: &MySqlRow) -> DMSCDBRow {
48 let columns: Vec<String> = (0..row.len())
49 .map(|i| row.column(i).name().to_string())
50 .collect();
51
52 let values: Vec<Option<serde_json::Value>> = (0..row.len())
53 .map(|idx| Self::value_to_json(row, idx))
54 .collect();
55
56 DMSCDBRow { columns, values }
57 }
58
59 fn value_to_json(row: &MySqlRow, idx: usize) -> Option<serde_json::Value> {
60 match row.try_get::<i64, _>(idx) {
61 Ok(v) => Some(serde_json::json!(v)),
62 Err(_) => {
63 match row.try_get::<f64, _>(idx) {
64 Ok(v) => Some(serde_json::json!(v)),
65 Err(_) => {
66 match row.try_get::<String, _>(idx) {
67 Ok(v) => Some(serde_json::json!(v)),
68 Err(_) => {
69 match row.try_get::<bool, _>(idx) {
70 Ok(v) => Some(serde_json::json!(v)),
71 Err(_) => {
72 match row.try_get::<Vec<u8>, _>(idx) {
73 Ok(v) => Some(serde_json::json!(v)),
74 Err(_) => None,
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82 }
83 }
84 }
85}
86
87#[async_trait]
88impl DMSCDatabase for MySQLDatabase {
89 fn database_type(&self) -> DatabaseType {
90 DatabaseType::MySQL
91 }
92
93 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94 let result = sqlx::query::<sqlx::MySql>(sql)
95 .execute(&self.pool)
96 .await
97 .map_err(|e| DMSCError::Other(format!("MySQL execute error: {}", e)))?;
98 Ok(result.rows_affected())
99 }
100
101 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102 let rows = sqlx::query::<sqlx::MySql>(sql)
103 .fetch_all(&self.pool)
104 .await
105 .map_err(|e| DMSCError::Other(format!("MySQL query error: {}", e)))?;
106
107 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108 .map(|row| Self::row_to_dmsc_row(row))
109 .collect();
110
111 Ok(DMSCDBResult::with_rows(dmsc_rows))
112 }
113
114 async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115 let row = sqlx::query::<sqlx::MySql>(sql)
116 .fetch_optional(&self.pool)
117 .await
118 .map_err(|e| DMSCError::Other(format!("MySQL query_one error: {}", e)))?;
119
120 Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121 }
122
123 async fn ping(&self) -> DMSCResult<bool> {
124 sqlx::query::<sqlx::MySql>("SELECT 1")
125 .fetch_one(&self.pool)
126 .await
127 .map(|_| true)
128 .map_err(|e| DMSCError::Other(format!("MySQL ping error: {}", e)))
129 }
130
131 fn is_connected(&self) -> bool {
132 !self.pool.is_closed()
133 }
134
135 async fn close(&self) -> DMSCResult<()> {
136 self.pool.close().await;
137 Ok(())
138 }
139
140 async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141 let mut results = Vec::with_capacity(params.len());
142 for param_set in params {
143 let result = self.execute_with_params(sql, param_set).await?;
144 results.push(result);
145 }
146 Ok(results)
147 }
148
149 async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150 let mut results = Vec::with_capacity(params.len());
151 for param_set in params {
152 let result = self.query_with_params(sql, param_set).await?;
153 results.push(result);
154 }
155 Ok(results)
156 }
157
158 async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159 let mut query = sqlx::query::<sqlx::MySql>(sql);
160
161 for param in params {
162 let param_str = param.to_string();
163 query = query.bind(param_str);
164 }
165
166 let result = query
167 .execute(&self.pool)
168 .await
169 .map_err(|e| DMSCError::Other(format!("MySQL execute_with_params error: {}", e)))?;
170 Ok(result.rows_affected())
171 }
172
173 async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
174 let mut query = sqlx::query::<sqlx::MySql>(sql);
175
176 for param in params {
177 let param_str = param.to_string();
178 query = query.bind(param_str);
179 }
180
181 let rows = query
182 .fetch_all(&self.pool)
183 .await
184 .map_err(|e| DMSCError::Other(format!("MySQL query_with_params error: {}", e)))?;
185
186 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
187 .map(|row| Self::row_to_dmsc_row(row))
188 .collect();
189
190 Ok(DMSCDBResult::with_rows(dmsc_rows))
191 }
192
193 async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
194 let tx = self.pool.begin().await
195 .map_err(|e| DMSCError::Other(format!("MySQL transaction begin error: {}", e)))?;
196
197 Ok(Box::new(MySQLTransaction::new(tx)))
198 }
199}
200
201struct MySQLTransaction {
202 tx: Arc<Mutex<Option<sqlx::Transaction<'static, sqlx::MySql>>>>,
203}
204
205impl MySQLTransaction {
206 pub fn new(tx: sqlx::Transaction<'static, sqlx::MySql>) -> Self {
207 Self {
208 tx: Arc::new(Mutex::new(Some(tx))),
209 }
210 }
211}
212
213#[async_trait::async_trait]
214impl crate::database::DMSCDatabaseTransaction for MySQLTransaction {
215 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
216 let mut guard = self.tx.lock().await;
217 let tx = guard.as_mut()
218 .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
219
220 let result = sqlx::query::<sqlx::MySql>(sql)
221 .execute(&mut **tx)
222 .await
223 .map_err(|e| DMSCError::Other(format!("MySQL transaction execute error: {}", e)))?;
224 Ok(result.rows_affected())
225 }
226
227 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
228 let mut guard = self.tx.lock().await;
229 let tx = guard.as_mut()
230 .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
231
232 let rows = sqlx::query::<sqlx::MySql>(sql)
233 .fetch_all(&mut **tx)
234 .await
235 .map_err(|e| DMSCError::Other(format!("MySQL transaction query error: {}", e)))?;
236
237 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
238 .map(|row| MySQLDatabase::row_to_dmsc_row(row))
239 .collect();
240
241 Ok(DMSCDBResult::with_rows(dmsc_rows))
242 }
243
244 async fn commit(&self) -> DMSCResult<()> {
245 let mut guard = self.tx.lock().await;
246 let tx = guard.take()
247 .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
248
249 tx.commit().await
250 .map_err(|e| DMSCError::Other(format!("MySQL transaction commit error: {}", e)))
251 }
252
253 async fn rollback(&self) -> DMSCResult<()> {
254 let mut guard = self.tx.lock().await;
255 let tx = guard.take()
256 .ok_or_else(|| DMSCError::Other("MySQL transaction already closed".to_string()))?;
257
258 tx.rollback().await
259 .map_err(|e| DMSCError::Other(format!("MySQL transaction rollback error: {}", e)))
260 }
261
262 async fn close(&self) -> DMSCResult<()> {
263 self.rollback().await
264 }
265}
266
#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl MySQLDatabase {
    /// Connects synchronously (for Python callers) and returns a pooled MySQL handle.
    ///
    /// `max_connections` caps the sqlx pool size, and values below 1 are raised to 1.
    /// The pool keeps at least one connection open, matching the recorded
    /// `min_idle_connections(1)`.
    ///
    /// # Errors
    /// Raises `RuntimeError` if the shared Tokio runtime cannot be created or if
    /// connecting fails.
    #[staticmethod]
    pub fn from_connection_string(conn_string: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        // A zero-sized pool could never hand out a connection, and the minimum of 1
        // used below must not exceed the maximum.
        let max_connections = max_connections.max(1);

        let rt = mysql_python_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to create Tokio runtime: {}",
                e
            ))
        })?;

        rt.block_on(async {
            let pool = sqlx::mysql::MySqlPoolOptions::new()
                .max_connections(max_connections)
                .min_connections(1)
                .connect(conn_string)
                .await
                .map_err(|e| {
                    pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                        "Failed to connect to MySQL: {}",
                        e
                    ))
                })?;

            // NOTE(review): host/port/database are fixed placeholders, not parsed from
            // `conn_string`; only the pool limits reflect the caller's input.
            let db_config = DMSCDatabaseConfig::mysql()
                .host("localhost")
                .port(3306)
                .database("mysql")
                .max_connections(max_connections)
                .min_idle_connections(1)
                .build();

            Ok(Self { pool, config: db_config })
        })
    }

    /// Blocking wrapper around `execute`, run on the shared Python-facing runtime.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        mysql_python_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.execute(sql))
    }

    /// Blocking wrapper around `query`, run on the shared Python-facing runtime.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        mysql_python_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.query(sql))
    }

    /// Blocking wrapper around `ping`, run on the shared Python-facing runtime.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        mysql_python_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?
            .block_on(self.ping())
    }
}

/// Returns the process-wide Tokio runtime used by the synchronous Python entry points.
///
/// One runtime is shared so that pooled connections opened in `from_connection_string`
/// stay registered with a live I/O driver when later `*_sync` calls use them. A runtime
/// created per call would be dropped, along with its driver, as soon as the call
/// returned.
#[cfg(feature = "pyo3")]
fn mysql_python_runtime() -> std::io::Result<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    // `OnceLock::get_or_try_init` is unstable, so build the runtime first and let
    // `get_or_init` keep whichever one wins a concurrent first call. The loser is
    // dropped here.
    let fresh = tokio::runtime::Runtime::new()?;
    Ok(RUNTIME.get_or_init(|| fresh))
}
319}