// sqlmodel_postgres/config.rs

//! PostgreSQL connection configuration.
//!
//! Provides connection parameters for establishing PostgreSQL connections,
//! including authentication, SSL, and connection options.

use std::collections::HashMap;
use std::fmt;
use std::time::Duration;

/// SSL negotiation policy for a PostgreSQL connection.
///
/// [`SslMode::Disable`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Never use SSL.
    #[default]
    Disable,
    /// Attempt SSL, falling back to a non-SSL connection if it is unavailable.
    Prefer,
    /// Insist on an SSL connection.
    Require,
    /// Insist on SSL and verify the server certificate.
    VerifyCa,
    /// Insist on SSL, verify the server certificate, and check that it
    /// matches the hostname.
    VerifyFull,
}

25impl SslMode {
26    /// Check if SSL should be attempted.
27    pub const fn should_try_ssl(self) -> bool {
28        !matches!(self, SslMode::Disable)
29    }
30
31    /// Check if SSL is required.
32    pub const fn is_required(self) -> bool {
33        matches!(
34            self,
35            SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull
36        )
37    }
38}
39
40/// PostgreSQL connection configuration.
41#[derive(Debug, Clone)]
42pub struct PgConfig {
43    /// Hostname or IP address
44    pub host: String,
45    /// Port number (default: 5432)
46    pub port: u16,
47    /// Username for authentication
48    pub user: String,
49    /// Password for authentication (optional for trust auth)
50    pub password: Option<String>,
51    /// Database name to connect to
52    pub database: String,
53    /// Application name (visible in pg_stat_activity)
54    pub application_name: Option<String>,
55    /// Connection timeout
56    pub connect_timeout: Duration,
57    /// SSL mode
58    pub ssl_mode: SslMode,
59    /// Additional connection parameters
60    pub options: HashMap<String, String>,
61}
62
63impl Default for PgConfig {
64    fn default() -> Self {
65        Self {
66            host: "localhost".to_string(),
67            port: 5432,
68            user: String::new(),
69            password: None,
70            database: String::new(),
71            application_name: None,
72            connect_timeout: Duration::from_secs(30),
73            ssl_mode: SslMode::default(),
74            options: HashMap::new(),
75        }
76    }
77}
78
79impl PgConfig {
80    /// Create a new configuration with the given connection string components.
81    pub fn new(
82        host: impl Into<String>,
83        user: impl Into<String>,
84        database: impl Into<String>,
85    ) -> Self {
86        Self {
87            host: host.into(),
88            user: user.into(),
89            database: database.into(),
90            ..Default::default()
91        }
92    }
93
94    /// Set the port.
95    pub fn port(mut self, port: u16) -> Self {
96        self.port = port;
97        self
98    }
99
100    /// Set the password.
101    pub fn password(mut self, password: impl Into<String>) -> Self {
102        self.password = Some(password.into());
103        self
104    }
105
106    /// Set the application name.
107    pub fn application_name(mut self, name: impl Into<String>) -> Self {
108        self.application_name = Some(name.into());
109        self
110    }
111
112    /// Set the connection timeout.
113    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
114        self.connect_timeout = timeout;
115        self
116    }
117
118    /// Set the SSL mode.
119    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
120        self.ssl_mode = mode;
121        self
122    }
123
124    /// Set an additional connection option.
125    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
126        self.options.insert(key.into(), value.into());
127        self
128    }
129
130    /// Build the startup parameters to send to the server.
131    pub fn startup_params(&self) -> Vec<(String, String)> {
132        let mut params = vec![
133            ("user".to_string(), self.user.clone()),
134            ("database".to_string(), self.database.clone()),
135            ("client_encoding".to_string(), "UTF8".to_string()),
136        ];
137
138        if let Some(app_name) = &self.application_name {
139            params.push(("application_name".to_string(), app_name.clone()));
140        }
141
142        for (k, v) in &self.options {
143            params.push((k.clone(), v.clone()));
144        }
145
146        params
147    }
148
149    /// Get the socket address string for connection.
150    pub fn socket_addr(&self) -> String {
151        format!("{}:{}", self.host, self.port)
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `params` contains the exact `key`/`value` pair.
    fn has_param(params: &[(String, String)], key: &str, value: &str) -> bool {
        params.iter().any(|(k, v)| k == key && v == value)
    }

    #[test]
    fn test_config_builder() {
        let config = PgConfig::new("localhost", "postgres", "testdb")
            .port(5433)
            .password("secret")
            .application_name("myapp")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Prefer)
            .option("timezone", "UTC");

        // Identity fields.
        assert_eq!(config.host, "localhost");
        assert_eq!(config.user, "postgres");
        assert_eq!(config.database, "testdb");

        // Values set through the builder.
        assert_eq!(config.port, 5433);
        assert_eq!(config.password.as_deref(), Some("secret"));
        assert_eq!(config.application_name.as_deref(), Some("myapp"));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.ssl_mode, SslMode::Prefer);
        assert_eq!(
            config.options.get("timezone").map(String::as_str),
            Some("UTC")
        );
    }

    #[test]
    fn test_startup_params() {
        let params = PgConfig::new("localhost", "postgres", "testdb")
            .application_name("myapp")
            .option("timezone", "UTC")
            .startup_params();

        let expected = [
            ("user", "postgres"),
            ("database", "testdb"),
            ("client_encoding", "UTF8"),
            ("application_name", "myapp"),
            ("timezone", "UTC"),
        ];
        for (key, value) in expected {
            assert!(has_param(&params, key, value));
        }
    }

    #[test]
    fn test_socket_addr() {
        let addr = PgConfig::new("db.example.com", "user", "db")
            .port(5433)
            .socket_addr();
        assert_eq!(addr, "db.example.com:5433");
    }

    #[test]
    fn test_ssl_mode_properties() {
        // (mode, should_try_ssl, is_required)
        let cases = [
            (SslMode::Disable, false, false),
            (SslMode::Prefer, true, false),
            (SslMode::Require, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyFull, true, true),
        ];
        for (mode, tries, required) in cases {
            assert_eq!(mode.should_try_ssl(), tries);
            assert_eq!(mode.is_required(), required);
        }
    }
}