dmsc/database/
postgres.rs1use async_trait::async_trait;
19use sqlx::postgres::{PgPool, PgRow};
20use sqlx::{Row, Column};
21use std::sync::Arc;
22use tokio::sync::Mutex;
23
24use crate::core::{DMSCResult, DMSCError};
25use crate::database::{
26 DMSCDatabase, DMSCDatabaseConfig, DatabaseType,
27 DMSCDBResult, DMSCDBRow
28};
29
30#[derive(Clone)]
31#[cfg_attr(feature = "pyo3", pyo3::prelude::pyclass)]
32#[allow(dead_code)]
33pub struct PostgresDatabase {
34 pool: PgPool,
35 config: DMSCDatabaseConfig,
36}
37
38impl PostgresDatabase {
39 pub async fn new(database_url: &str, config: DMSCDatabaseConfig) -> Result<Self, DMSCError> {
40 let pool = PgPool::connect(database_url)
41 .await
42 .map_err(|e| DMSCError::Other(format!("Failed to connect to PostgreSQL: {}", e)))?;
43
44 Ok(Self { pool, config })
45 }
46
47 fn row_to_dmsc_row(row: &PgRow) -> DMSCDBRow {
48 let columns: Vec<String> = (0..row.len())
49 .map(|i| row.column(i).name().to_string())
50 .collect();
51
52 let values: Vec<Option<serde_json::Value>> = (0..row.len())
53 .map(|idx| Self::value_to_json(row, idx))
54 .collect();
55
56 DMSCDBRow { columns, values }
57 }
58
59 fn value_to_json(row: &PgRow, idx: usize) -> Option<serde_json::Value> {
60 match row.try_get::<i64, _>(idx) {
61 Ok(v) => Some(serde_json::json!(v)),
62 Err(_) => {
63 match row.try_get::<f64, _>(idx) {
64 Ok(v) => Some(serde_json::json!(v)),
65 Err(_) => {
66 match row.try_get::<String, _>(idx) {
67 Ok(v) => Some(serde_json::json!(v)),
68 Err(_) => {
69 match row.try_get::<bool, _>(idx) {
70 Ok(v) => Some(serde_json::json!(v)),
71 Err(_) => {
72 match row.try_get::<Vec<u8>, _>(idx) {
73 Ok(v) => Some(serde_json::json!(v)),
74 Err(_) => None,
75 }
76 }
77 }
78 }
79 }
80 }
81 }
82 }
83 }
84 }
85}
86
87#[async_trait]
88impl DMSCDatabase for PostgresDatabase {
89 fn database_type(&self) -> DatabaseType {
90 DatabaseType::Postgres
91 }
92
93 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
94 let result = sqlx::query::<sqlx::Postgres>(sql)
95 .execute(&self.pool)
96 .await
97 .map_err(|e| DMSCError::Other(format!("PostgreSQL execute error: {}", e)))?;
98 Ok(result.rows_affected())
99 }
100
101 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
102 let rows = sqlx::query::<sqlx::Postgres>(sql)
103 .fetch_all(&self.pool)
104 .await
105 .map_err(|e| DMSCError::Other(format!("PostgreSQL query error: {}", e)))?;
106
107 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
108 .map(|row| Self::row_to_dmsc_row(row))
109 .collect();
110
111 Ok(DMSCDBResult::with_rows(dmsc_rows))
112 }
113
114 async fn query_one(&self, sql: &str) -> DMSCResult<Option<DMSCDBRow>> {
115 let row = sqlx::query::<sqlx::Postgres>(sql)
116 .fetch_optional(&self.pool)
117 .await
118 .map_err(|e| DMSCError::Other(format!("PostgreSQL query_one error: {}", e)))?;
119
120 Ok(row.map(|r| Self::row_to_dmsc_row(&r)))
121 }
122
123 async fn ping(&self) -> DMSCResult<bool> {
124 sqlx::query::<sqlx::Postgres>("SELECT 1")
125 .fetch_one(&self.pool)
126 .await
127 .map(|_| true)
128 .map_err(|e| DMSCError::Other(format!("PostgreSQL ping error: {}", e)))
129 }
130
131 fn is_connected(&self) -> bool {
132 !self.pool.is_closed()
133 }
134
135 async fn close(&self) -> DMSCResult<()> {
136 self.pool.close().await;
137 Ok(())
138 }
139
140 async fn batch_execute(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<u64>> {
141 let mut results = Vec::with_capacity(params.len());
142 for param_set in params {
143 let result = self.execute_with_params(sql, param_set).await?;
144 results.push(result);
145 }
146 Ok(results)
147 }
148
149 async fn batch_query(&self, sql: &str, params: &[Vec<serde_json::Value>]) -> DMSCResult<Vec<DMSCDBResult>> {
150 let mut results = Vec::with_capacity(params.len());
151 for param_set in params {
152 let result = self.query_with_params(sql, param_set).await?;
153 results.push(result);
154 }
155 Ok(results)
156 }
157
158 async fn execute_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<u64> {
159 let mut query = sqlx::query::<sqlx::Postgres>(sql);
160
161 for param in params {
162 query = query.bind(param.clone());
163 }
164
165 let result = query
166 .execute(&self.pool)
167 .await
168 .map_err(|e| DMSCError::Other(format!("PostgreSQL execute_with_params error: {}", e)))?;
169 Ok(result.rows_affected())
170 }
171
172 async fn query_with_params(&self, sql: &str, params: &[serde_json::Value]) -> DMSCResult<DMSCDBResult> {
173 let mut query = sqlx::query::<sqlx::Postgres>(sql);
174
175 for param in params {
176 query = query.bind(param.clone());
177 }
178
179 let rows = query
180 .fetch_all(&self.pool)
181 .await
182 .map_err(|e| DMSCError::Other(format!("PostgreSQL query_with_params error: {}", e)))?;
183
184 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
185 .map(|row| Self::row_to_dmsc_row(row))
186 .collect();
187
188 Ok(DMSCDBResult::with_rows(dmsc_rows))
189 }
190
191 async fn transaction(&self) -> DMSCResult<Box<dyn crate::database::DMSCDatabaseTransaction>> {
192 let tx = self.pool.begin().await
193 .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction begin error: {}", e)))?;
194
195 Ok(Box::new(PostgresTransaction::new(tx)))
196 }
197}
198
199struct PostgresTransaction {
200 tx: Arc<Mutex<Option<sqlx::Transaction<'static, sqlx::Postgres>>>>,
201}
202
203impl PostgresTransaction {
204 pub fn new(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
205 Self {
206 tx: Arc::new(Mutex::new(Some(tx))),
207 }
208 }
209}
210
211#[async_trait::async_trait]
212impl crate::database::DMSCDatabaseTransaction for PostgresTransaction {
213 async fn execute(&self, sql: &str) -> DMSCResult<u64> {
214 let mut guard = self.tx.lock().await;
215 let tx = guard.as_mut()
216 .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
217
218 let result = sqlx::query::<sqlx::Postgres>(sql)
219 .execute(&mut **tx)
220 .await
221 .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction execute error: {}", e)))?;
222 Ok(result.rows_affected())
223 }
224
225 async fn query(&self, sql: &str) -> DMSCResult<DMSCDBResult> {
226 let mut guard = self.tx.lock().await;
227 let tx = guard.as_mut()
228 .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
229
230 let rows = sqlx::query::<sqlx::Postgres>(sql)
231 .fetch_all(&mut **tx)
232 .await
233 .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction query error: {}", e)))?;
234
235 let dmsc_rows: Vec<DMSCDBRow> = rows.iter()
236 .map(|row| PostgresDatabase::row_to_dmsc_row(row))
237 .collect();
238
239 Ok(DMSCDBResult::with_rows(dmsc_rows))
240 }
241
242 async fn commit(&self) -> DMSCResult<()> {
243 let mut guard = self.tx.lock().await;
244 let tx = guard.take()
245 .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
246
247 tx.commit().await
248 .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction commit error: {}", e)))
249 }
250
251 async fn rollback(&self) -> DMSCResult<()> {
252 let mut guard = self.tx.lock().await;
253 let tx = guard.take()
254 .ok_or_else(|| DMSCError::Other("PostgreSQL transaction already closed".to_string()))?;
255
256 tx.rollback().await
257 .map_err(|e| DMSCError::Other(format!("PostgreSQL transaction rollback error: {}", e)))
258 }
259
260 async fn close(&self) -> DMSCResult<()> {
261 self.rollback().await
262 }
263}
264
/// Returns the process-wide Tokio runtime that drives the synchronous Python wrappers.
///
/// One long-lived runtime is shared by all calls. The pool's connections are
/// opened on whichever runtime is driving the call, so creating and dropping a
/// runtime per call would leave every pooled connection attached to a runtime
/// that no longer exists by the next call.
#[cfg(feature = "pyo3")]
fn shared_sync_runtime() -> std::io::Result<&'static tokio::runtime::Runtime> {
    static RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();

    if let Some(runtime) = RUNTIME.get() {
        return Ok(runtime);
    }
    // If two threads race here, the loser's runtime is simply dropped.
    let runtime = tokio::runtime::Runtime::new()?;
    Ok(RUNTIME.get_or_init(|| runtime))
}

#[cfg(feature = "pyo3")]
#[pyo3::prelude::pymethods]
impl PostgresDatabase {
    /// Connects to PostgreSQL using `conn_string` and blocks until the pool is ready.
    ///
    /// `max_connections` caps the size of the connection pool; `0` keeps sqlx's
    /// default cap. The same value is also recorded in the stored configuration.
    ///
    /// NOTE(review): the stored configuration uses fixed host/port/database
    /// values ("localhost", 5432, "postgres") rather than values taken from
    /// `conn_string` — confirm nothing relies on them.
    ///
    /// # Errors
    /// Raises `RuntimeError` if the runtime cannot be created or the connection fails.
    #[staticmethod]
    pub fn from_connection_string(conn_string: &str, max_connections: u32) -> Result<Self, pyo3::PyErr> {
        let rt = shared_sync_runtime().map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to create Tokio runtime: {}",
                e
            ))
        })?;

        // Apply the requested cap to the pool itself, not only to the config record.
        let mut pool_options = sqlx::postgres::PgPoolOptions::new();
        if max_connections > 0 {
            pool_options = pool_options.max_connections(max_connections);
        }

        let pool = rt.block_on(pool_options.connect(conn_string)).map_err(|e| {
            pyo3::PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to connect to PostgreSQL: {}",
                e
            ))
        })?;

        let db_config = DMSCDatabaseConfig::postgres()
            .host("localhost")
            .port(5432)
            .database("postgres")
            .max_connections(max_connections)
            .build();

        Ok(Self { pool, config: db_config })
    }

    /// Blocking wrapper around `execute`: runs `sql` and returns the affected-row count.
    pub fn execute_sync(&self, sql: &str) -> Result<u64, DMSCError> {
        let rt = shared_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.execute(sql))
    }

    /// Blocking wrapper around `query`: runs `sql` and returns all rows.
    pub fn query_sync(&self, sql: &str) -> Result<DMSCDBResult, DMSCError> {
        let rt = shared_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.query(sql))
    }

    /// Blocking wrapper around `ping`: checks that the server answers `SELECT 1`.
    pub fn ping_sync(&self) -> Result<bool, DMSCError> {
        let rt = shared_sync_runtime()
            .map_err(|e| DMSCError::Other(format!("Failed to create Tokio runtime: {}", e)))?;
        rt.block_on(self.ping())
    }
}