dmsc/database/
postgres.rs

1//! Copyright © 2025-2026 Wenze Wei. All Rights Reserved.
2//!
3//! This file is part of DMSC.
4//! The DMSC project belongs to the Dunimd Team.
5//!
6//! Licensed under the Apache License, Version 2.0 (the "License");
7//! You may not use this file except in compliance with the License.
8//! You may obtain a copy of the License at
9//!
10//!     http://www.apache.org/licenses/LICENSE-2.0
11//!
12//! Unless required by applicable law or agreed to in writing, software
13//! distributed under the License is distributed on an "AS IS" BASIS,
14//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15//! See the License for the specific language governing permissions and
16//! limitations under the License.
17
18use async_trait::async_trait;
19use sqlx::postgres::{PgPool, PgRow};
20use sqlx::{Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26    DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27    DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct PostgresDatabase {
34    pool: PgPool,
35    config: DMSCDatabaseConfig,
36}
37
38impl PostgresDatabase {
39    pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40        let pool = PgPool::connect(database_url)
41            .await
42            .map_err(|e| DMSCError::Other(format!("Failed to connect to PostgreSQL: {}", e)))?;
43        
44        Ok(Self { pool, config })
45    }
46
47    fn row_to_dmsc_row(row: &PgRow) -> DMSCDBRow {
48        let columns: Vec<String> = (0..row.len())
49            .map(|i| row.column(i).name().to_string())
50            .collect();
51
52        let values: Vec<Option<serde_json::Value>> = (0..row.len())
53            .map(|idx| Self::value_to_json(row, idx))
54            .collect();
55
56        DMSCDBRow { columns, values }
57    }
58
59    fn value_to_json(row: &PgRow, idx: usize) -> Option<serde_json::Value> {
60        match row.try_get::<i64, _>(idx) {
61            Ok(v) => Some(serde_json::json!(v)),
62            Err(_) => {
63                match row.try_get::<f64, _>(idx) {
64                    Ok(v) => Some(serde_json::json!(v)),
65                    Err(_) => {
66                        match row.try_get::<String, _>(idx) {
67                            Ok(v) => Some(serde_json::json!(v)),
68                            Err(_) => {
69                                match row.try_get::<bool, _>(idx) {
70                                    Ok(v) => Some(serde_json::json!(v)),
71                                    Err(_) => {
72                                        match row.try_get::<Vec<u8>, _>(idx) {
73                                            Ok(v) => Some(serde_json::json!(v)),
74                                            Err(_) => None,
75                                        }
76                                    }
77                                }
78                            }
79                        }
80                    }
81                }
82            }
83        }
84    }
85}
86
87#[async_trait]
88impl DMSCDatabase for PostgresDatabase {
89    fn database_type(&self) -> DatabaseType {
90        DatabaseType::Postgres
91    }
92
93    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94        let result = sqlx::query::<sqlx::Postgres>(sql)
95            .execute(&self.pool)
96            .await
97            .map_err(|e| DMSCError::Other(format!("PostgreSQL execute error: {}", e)))?;
98        Ok(result.rows_affected())
99    }
100
101    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102        let rows = sqlx::query::<sqlx::Postgres>(sql)
103            .fetch_all(&self.pool)
104            .await
105            .map_err(|e| DMSCError::Other(format!("PostgreSQL query error: {}", e)))?;
106
107        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108            .map(|row| Self::row_to_dmsc_row(row))
109            .collect();
110
111        Ok(DMSCDBResult::with_rows(dmsc_rows))
112    }
113
114    async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115        let row = sqlx::query::<sqlx::Postgres>(sql)
116            .fetch_optional(&self.pool)
117            .await
118            .map_err(|e| DMSCError::Other(format!("PostgreSQL query_one error: {}", e)))?;
119
120        Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121    }
122
123    async fn ping(&self) -> DMSCResult<bool> {
124        sqlx::query::<sqlx::Postgres>("SELECT 1")
125            .fetch_one(&self.pool)
126            .await
127            .map(|_| true)
128            .map_err(|e| DMSCError::Other(format!("PostgreSQL ping error: {}", e)))
129    }
130
131    fn is_connected(&self) -> bool {
132        !self.pool.is_closed()
133    }
134
135    async fn close(&self) -> DMSCResult<()> {
136        self.pool.close().await;
137        Ok(())
138    }
139
140    async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141        let mut results = Vec::with_capacity(params.len());
142        for param_set in params {
143            let result = self.execute_with_params(sql, param_set).await?;
144            results.push(result);
145        }
146        Ok(results)
147    }
148
149    async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150        let mut results = Vec::with_capacity(params.len());
151        for param_set in params {
152            let result = self.query_with_params(sql, param_set).await?;
153            results.push(result);
154        }
155        Ok(results)
156    }
157
158    async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159        let mut query = sqlx::query::<sqlx::Postgres>(sql);
160        
161        for param in params {
162            query = query.bind(param.clone());
163        }
164        
165        let result = query
166            .execute(&self.pool)
167            .await
168            .map_err(|e| DMSCError::Other(format!("PostgreSQL execute_with_params error: {}", e)))?;
169        Ok(result.rows_affected())
170    }
171
172    async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
173        let mut query = sqlx::query::<sqlx::Postgres>(sql);
174        
175        for param in params {
176            query = query.bind(param.clone());
177        }
178        
179        let rows = query
180            .fetch_all(&self.pool)
181            .await
182            .map_err(|e| DMSCError::Other(format!("PostgreSQL query_with_params error: {}", e)))?;
183
184        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
185            .map(|row| Self::row_to_dmsc_row(row))
186            .collect();
187
188        Ok(DMSCDBResult::with_rows(dmsc_rows))
189    }
190
191    async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
192        let tx = self.pool.begin().await
193            .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction begin error: {}", e)))?;
194
195        Ok(Box::new(PostgresTransaction::new(tx)))
196    }
197}
198
199struct PostgresTransaction {
200    tx: Arc<Mutex<Option<sqlx::Transaction<'static, sqlx::Postgres>>>>,
201}
202
203impl PostgresTransaction {
204    pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
205        Self {
206            tx: Arc::new(Mutex::new(Some(tx))),
207        }
208    }
209}
210
211#[async_trait::async_trait]
212impl crate::database::DMSCDatabaseTransaction for PostgresTransaction {
213    async fn execute(&self, sql: &str) -> DMSCResult<u64> {
214        let mut guard = self.tx.lock().await;
215        let tx = guard.as_mut()
216            .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
217        
218        let result = sqlx::query::<sqlx::Postgres>(sql)
219            .execute(&mut **tx)
220            .await
221            .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction execute error: {}", e)))?;
222        Ok(result.rows_affected())
223    }
224
225    async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
226        let mut guard = self.tx.lock().await;
227        let tx = guard.as_mut()
228            .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
229        
230        let rows = sqlx::query::<sqlx::Postgres>(sql)
231            .fetch_all(&mut **tx)
232            .await
233            .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction query error: {}", e)))?;
234
235        let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
236            .map(|row| PostgresDatabase::row_to_dmsc_row(row))
237            .collect();
238
239        Ok(DMSCDBResult::with_rows(dmsc_rows))
240    }
241
242    async fn commit(&self) -> DMSCResult<()> {
243        let mut guard = self.tx.lock().await;
244        let tx = guard.take()
245            .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
246        
247        tx.commit().await
248            .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction commit error: {}", e)))
249    }
250
251    async fn rollback(&self) -> DMSCResult<()> {
252        let mut guard = self.tx.lock().await;
253        let tx = guard.take()
254            .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
255        
256        tx.rollback().await
257            .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction rollback error: {}", e)))
258    }
259
260    async fn close(&self) -> DMSCResult<()> {
261        self.rollback().await
262    }
263}
264
/// Returns the process-wide Tokio runtime used by the synchronous Python
/// bindings, creating it on first use.
///
/// One runtime is shared on purpose: sqlx pools register their sockets and
/// background tasks with the runtime that is current when a connection is
/// opened. Creating and dropping a runtime per call (the previous behaviour)
/// leaves the pool's connections tied to an I/O driver that has already shut
/// down, and pays the runtime start-up cost on every call.
#[cfg(feature = "pyo3")]
fn pyo3_shared_runtime() -> std::io::Result<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(rt) = RUNTIME.get() {
        return Ok(rt);
    }
    let rt = tokio::runtime::Runtime::new()?;
    // If another thread initialised the slot first, its runtime is kept and
    // `rt` is simply dropped here (outside any async context).
    Ok(RUNTIME.get_or_init(|| rt))
}

/// Drives `fut` to completion on the shared runtime, mapping a runtime start-up
/// failure to [`DMSCError::Other`].
#[cfg(feature = "pyo3")]
fn pyo3_block_on<F, T>(fut: F) -> Result<T, DMSCError>
where
    F: std::future::Future<Output = Result<T, DMSCError>>,
{
    let rt = pyo3_shared_runtime()
        .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
    rt.block_on(fut)
}

#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl PostgresDatabase {
    /// Connects to PostgreSQL using `conn_string` and a pool of at most
    /// `max_connections` connections.
    ///
    /// Passing `0` keeps sqlx's default pool size (the value was previously
    /// ignored altogether, so this preserves that behaviour for such callers).
    ///
    /// Raises `RuntimeError` if the runtime cannot be created or the
    /// connection fails.
    #[staticmethod]
    pub fn from_connection_string(conn_string: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        let rt = pyo3_shared_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                format!("Failed to create Tokio runtime: {}", e),
            )
        })?;

        rt.block_on(async {
            let mut pool_options = sqlx::postgres::PgPoolOptions::new();
            if max_connections > 0 {
                pool_options = pool_options.max_connections(max_connections);
            }

            let pool = pool_options
                .connect(conn_string)
                .await
                .map_err(|e| pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    format!("Failed to connect to PostgreSQL: {}", e),
                ))?;

            // NOTE(review): host/port/database are fixed placeholders and do not
            // reflect `conn_string`; only `max_connections` is meaningful here.
            let db_config = DMSCDatabaseConfig::postgres()
                .host("localhost")
                .port(5432)
                .database("postgres")
                .max_connections(max_connections)
                .build();

            Ok(Self { pool, config: db_config })
        })
    }

    /// Blocking wrapper around `DMSCDatabase::execute`.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        pyo3_block_on(self.execute(sql))
    }

    /// Blocking wrapper around `DMSCDatabase::query`.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        pyo3_block_on(self.query(sql))
    }

    /// Blocking wrapper around `DMSCDatabase::ping`.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        pyo3_block_on(self.ping())
    }
}