sqlmodel_postgres/
config.rs1use std::collections::HashMap;
7use std::time::Duration;
8
/// TLS negotiation policy for a PostgreSQL connection.
///
/// The variant names mirror libpq's `sslmode` values. [`SslMode::Disable`] is
/// the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Do not attempt TLS at all (the default).
    #[default]
    Disable,
    /// Attempt TLS, but do not treat it as mandatory.
    Prefer,
    /// TLS is mandatory.
    Require,
    /// TLS is mandatory. NOTE(review): presumably the server certificate is
    /// validated against a trusted CA; the validation itself is not in this
    /// file, so confirm against the connection code.
    VerifyCa,
    /// TLS is mandatory. NOTE(review): presumably adds host-name verification
    /// on top of `VerifyCa`; confirm against the connection code.
    VerifyFull,
}
24
25impl SslMode {
26 pub const fn should_try_ssl(self) -> bool {
28 !matches!(self, SslMode::Disable)
29 }
30
31 pub const fn is_required(self) -> bool {
33 matches!(
34 self,
35 SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull
36 )
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct PgConfig {
43 pub host: String,
45 pub port: u16,
47 pub user: String,
49 pub password: Option<String>,
51 pub database: String,
53 pub application_name: Option<String>,
55 pub connect_timeout: Duration,
57 pub ssl_mode: SslMode,
59 pub options: HashMap<String, String>,
61}
62
63impl Default for PgConfig {
64 fn default() -> Self {
65 Self {
66 host: "localhost".to_string(),
67 port: 5432,
68 user: String::new(),
69 password: None,
70 database: String::new(),
71 application_name: None,
72 connect_timeout: Duration::from_secs(30),
73 ssl_mode: SslMode::default(),
74 options: HashMap::new(),
75 }
76 }
77}
78
79impl PgConfig {
80 pub fn new(
82 host: impl Into<String>,
83 user: impl Into<String>,
84 database: impl Into<String>,
85 ) -> Self {
86 Self {
87 host: host.into(),
88 user: user.into(),
89 database: database.into(),
90 ..Default::default()
91 }
92 }
93
94 pub fn port(mut self, port: u16) -> Self {
96 self.port = port;
97 self
98 }
99
100 pub fn password(mut self, password: impl Into<String>) -> Self {
102 self.password = Some(password.into());
103 self
104 }
105
106 pub fn application_name(mut self, name: impl Into<String>) -> Self {
108 self.application_name = Some(name.into());
109 self
110 }
111
112 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
114 self.connect_timeout = timeout;
115 self
116 }
117
118 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
120 self.ssl_mode = mode;
121 self
122 }
123
124 pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
126 self.options.insert(key.into(), value.into());
127 self
128 }
129
130 pub fn startup_params(&self) -> Vec<(String, String)> {
132 let mut params = vec![
133 ("user".to_string(), self.user.clone()),
134 ("database".to_string(), self.database.clone()),
135 ("client_encoding".to_string(), "UTF8".to_string()),
136 ];
137
138 if let Some(app_name) = &self.application_name {
139 params.push(("application_name".to_string(), app_name.clone()));
140 }
141
142 for (k, v) in &self.options {
143 params.push((k.clone(), v.clone()));
144 }
145
146 params
147 }
148
149 pub fn socket_addr(&self) -> String {
151 format!("{}:{}", self.host, self.port)
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `params` holds the exact `(key, value)` pair.
    fn has_param(params: &[(String, String)], key: &str, value: &str) -> bool {
        params.iter().any(|(k, v)| k == key && v == value)
    }

    /// Every builder method must store its argument in the matching field.
    #[test]
    fn test_config_builder() {
        let timeout = Duration::from_secs(10);
        let built = PgConfig::new("localhost", "postgres", "testdb")
            .port(5433)
            .password("secret")
            .application_name("myapp")
            .connect_timeout(timeout)
            .ssl_mode(SslMode::Prefer)
            .option("timezone", "UTC");

        assert_eq!(built.host, "localhost");
        assert_eq!(built.port, 5433);
        assert_eq!(built.user, "postgres");
        assert_eq!(built.database, "testdb");
        assert_eq!(built.password.as_deref(), Some("secret"));
        assert_eq!(built.application_name.as_deref(), Some("myapp"));
        assert_eq!(built.connect_timeout, timeout);
        assert_eq!(built.ssl_mode, SslMode::Prefer);
        assert_eq!(
            built.options.get("timezone").map(String::as_str),
            Some("UTC")
        );
    }

    /// The startup list carries the fixed keys, the application name and
    /// the extra options.
    #[test]
    fn test_startup_params() {
        let params = PgConfig::new("localhost", "postgres", "testdb")
            .application_name("myapp")
            .option("timezone", "UTC")
            .startup_params();

        assert!(has_param(&params, "user", "postgres"));
        assert!(has_param(&params, "database", "testdb"));
        assert!(has_param(&params, "client_encoding", "UTF8"));
        assert!(has_param(&params, "application_name", "myapp"));
        assert!(has_param(&params, "timezone", "UTC"));
    }

    /// Host and port are joined with a colon.
    #[test]
    fn test_socket_addr() {
        let addr = PgConfig::new("db.example.com", "user", "db")
            .port(5433)
            .socket_addr();
        assert_eq!(addr, "db.example.com:5433");
    }

    /// Checks both predicates for every `SslMode` variant.
    #[test]
    fn test_ssl_mode_properties() {
        // (mode, expected should_try_ssl, expected is_required)
        let expectations = [
            (SslMode::Disable, false, false),
            (SslMode::Prefer, true, false),
            (SslMode::Require, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyFull, true, true),
        ];

        for (mode, tries_ssl, required) in expectations {
            assert_eq!(mode.should_try_ssl(), tries_ssl, "should_try_ssl for {mode:?}");
            assert_eq!(mode.is_required(), required, "is_required for {mode:?}");
        }
    }
}