Skip to main content

sqlmodel_postgres/protocol/
messages.rs

//! Message definitions for the PostgreSQL wire protocol (version 3.0).
2
3use std::fmt;
4
/// Protocol version 3.0: major version 3 in the upper 16 bits, minor
/// version 0 in the lower 16 bits.
pub const PROTOCOL_VERSION: i32 = 3 << 16; // 196_608

/// Cancel request code.
pub const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678; // 80_877_102

/// SSL request code.
pub const SSL_REQUEST_CODE: i32 = (1234 << 16) | 5679; // 80_877_103
13// ==================== Frontend Messages (Client -> Server) ====================
14
15/// Messages sent from the client to the PostgreSQL server.
16#[derive(Debug, Clone, PartialEq)]
17pub enum FrontendMessage {
18    /// Startup message (no type byte) - first message sent after connecting
19    Startup {
20        /// Protocol version (196608 for 3.0)
21        version: i32,
22        /// Connection parameters (user, database, etc.)
23        params: Vec<(String, String)>,
24    },
25
26    /// Password response for authentication
27    PasswordMessage(String),
28
29    /// SASL initial response (mechanism selection and initial data)
30    SASLInitialResponse {
31        /// SASL mechanism name (e.g., "SCRAM-SHA-256")
32        mechanism: String,
33        /// Initial response data
34        data: Vec<u8>,
35    },
36
37    /// SASL response (continuation data)
38    SASLResponse(Vec<u8>),
39
40    /// Simple query (single SQL string, returns text format)
41    Query(String),
42
43    /// Parse a prepared statement (extended query protocol)
44    Parse {
45        /// Statement name ("" for unnamed)
46        name: String,
47        /// SQL query with $1, $2, etc. placeholders
48        query: String,
49        /// Parameter type OIDs (0 for server to infer)
50        param_types: Vec<u32>,
51    },
52
53    /// Bind parameters to a prepared statement
54    Bind {
55        /// Portal name ("" for unnamed)
56        portal: String,
57        /// Statement name to bind to
58        statement: String,
59        /// Parameter format codes (0=text, 1=binary)
60        param_formats: Vec<i16>,
61        /// Parameter values (None for NULL)
62        params: Vec<Option<Vec<u8>>>,
63        /// Result format codes (0=text, 1=binary)
64        result_formats: Vec<i16>,
65    },
66
67    /// Describe a prepared statement or portal
68    Describe {
69        /// 'S' for statement, 'P' for portal
70        kind: DescribeKind,
71        /// Name of statement/portal
72        name: String,
73    },
74
75    /// Execute a bound portal
76    Execute {
77        /// Portal name
78        portal: String,
79        /// Maximum rows to return (0 for all)
80        max_rows: i32,
81    },
82
83    /// Close a prepared statement or portal
84    Close {
85        /// 'S' for statement, 'P' for portal
86        kind: DescribeKind,
87        /// Name of statement/portal
88        name: String,
89    },
90
91    /// Sync - marks end of extended query, requests ReadyForQuery
92    Sync,
93
94    /// Flush - request server to send all pending output
95    Flush,
96
97    /// COPY data chunk
98    CopyData(Vec<u8>),
99
100    /// COPY operation complete
101    CopyDone,
102
103    /// COPY operation failed
104    CopyFail(String),
105
106    /// Terminate the connection
107    Terminate,
108
109    /// Cancel a running query (sent on a separate connection)
110    CancelRequest {
111        /// Backend process ID
112        process_id: i32,
113        /// Secret key from BackendKeyData
114        secret_key: i32,
115    },
116
117    /// SSL negotiation request
118    SSLRequest,
119}
120
/// Kind for Describe/Close messages: selects whether the accompanying name
/// refers to a prepared statement or a portal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DescribeKind {
    /// Statement ('S')
    Statement,
    /// Portal ('P')
    Portal,
}

impl DescribeKind {
    /// Get the wire protocol byte for this kind.
    ///
    /// Returns `b'S'` for a statement and `b'P'` for a portal.
    pub const fn as_byte(self) -> u8 {
        match self {
            Self::Statement => b'S',
            Self::Portal => b'P',
        }
    }

    /// Parse from wire protocol byte.
    ///
    /// Returns `None` for any byte other than `b'S'` or `b'P'`. This is a
    /// `const fn`, like [`DescribeKind::as_byte`], so it can be evaluated in
    /// constant contexts.
    pub const fn from_byte(b: u8) -> Option<Self> {
        match b {
            b'S' => Some(Self::Statement),
            b'P' => Some(Self::Portal),
            _ => None,
        }
    }
}
148
149// ==================== Backend Messages (Server -> Client) ====================
150
151/// Messages sent from the PostgreSQL server to the client.
152#[derive(Debug, Clone, PartialEq)]
153pub enum BackendMessage {
154    // Authentication messages
155    /// Authentication successful
156    AuthenticationOk,
157    /// Server requests cleartext password
158    AuthenticationCleartextPassword,
159    /// Server requests MD5-hashed password with salt
160    AuthenticationMD5Password([u8; 4]),
161    /// Server requests SASL authentication (lists mechanisms)
162    AuthenticationSASL(Vec<String>),
163    /// SASL authentication continuation data
164    AuthenticationSASLContinue(Vec<u8>),
165    /// SASL authentication final data
166    AuthenticationSASLFinal(Vec<u8>),
167
168    // Connection info
169    /// Backend process ID and secret key for cancellation
170    BackendKeyData {
171        /// Process ID
172        process_id: i32,
173        /// Secret key
174        secret_key: i32,
175    },
176    /// Server parameter status (e.g., server_encoding, TimeZone)
177    ParameterStatus {
178        /// Parameter name
179        name: String,
180        /// Parameter value
181        value: String,
182    },
183    /// Server is ready for a new query
184    ReadyForQuery(TransactionStatus),
185
186    // Query results
187    /// Describes the columns of a result set
188    RowDescription(Vec<FieldDescription>),
189    /// A single data row
190    DataRow(Vec<Option<Vec<u8>>>),
191    /// Query completed successfully
192    CommandComplete(String),
193    /// Empty query response
194    EmptyQueryResponse,
195
196    // Extended query protocol responses
197    /// Parse completed successfully
198    ParseComplete,
199    /// Bind completed successfully
200    BindComplete,
201    /// Close completed successfully
202    CloseComplete,
203    /// Describes parameter types for a prepared statement
204    ParameterDescription(Vec<u32>),
205    /// No data will be returned
206    NoData,
207    /// Portal execution suspended (reached max_rows)
208    PortalSuspended,
209
210    // Errors and notices
211    /// Error response with details
212    ErrorResponse(ErrorFields),
213    /// Notice (warning) with details
214    NoticeResponse(ErrorFields),
215
216    // COPY protocol
217    /// Server is ready to receive COPY data
218    CopyInResponse {
219        /// Overall COPY format (0=text, 1=binary)
220        format: i8,
221        /// Per-column format codes
222        column_formats: Vec<i16>,
223    },
224    /// Server is sending COPY data
225    CopyOutResponse {
226        /// Overall COPY format (0=text, 1=binary)
227        format: i8,
228        /// Per-column format codes
229        column_formats: Vec<i16>,
230    },
231    /// COPY data chunk
232    CopyData(Vec<u8>),
233    /// COPY operation complete
234    CopyDone,
235    /// COPY data format information for both directions
236    CopyBothResponse {
237        /// Overall COPY format (0=text, 1=binary)
238        format: i8,
239        /// Per-column format codes
240        column_formats: Vec<i16>,
241    },
242
243    // Notifications
244    /// Asynchronous notification (from LISTEN/NOTIFY)
245    NotificationResponse {
246        /// Backend process ID that sent the notification
247        process_id: i32,
248        /// Channel name
249        channel: String,
250        /// Payload string
251        payload: String,
252    },
253
254    // Function call (legacy, rarely used)
255    /// Function call result
256    FunctionCallResponse(Option<Vec<u8>>),
257
258    // Negotiate protocol version
259    /// Server doesn't support requested protocol features
260    NegotiateProtocolVersion {
261        /// Server's newest supported minor version
262        newest_minor: i32,
263        /// Unrecognized options
264        unrecognized: Vec<String>,
265    },
266}
267
/// Transaction status indicator from ReadyForQuery
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransactionStatus {
    /// Idle - not in a transaction block
    #[default]
    Idle,
    /// In a transaction block
    Transaction,
    /// In a failed transaction block
    Error,
}

impl TransactionStatus {
    /// Get the wire protocol byte for this status.
    ///
    /// Returns `b'I'`, `b'T'` or `b'E'` for idle, in-transaction and
    /// failed-transaction respectively.
    pub const fn as_byte(self) -> u8 {
        match self {
            Self::Idle => b'I',
            Self::Transaction => b'T',
            Self::Error => b'E',
        }
    }

    /// Parse from wire protocol byte.
    ///
    /// Returns `None` for any byte other than `b'I'`, `b'T'` or `b'E'`.
    /// This is a `const fn`, like [`TransactionStatus::as_byte`], so it can
    /// be evaluated in constant contexts.
    pub const fn from_byte(b: u8) -> Option<Self> {
        match b {
            b'I' => Some(Self::Idle),
            b'T' => Some(Self::Transaction),
            b'E' => Some(Self::Error),
            _ => None,
        }
    }
}

impl fmt::Display for TransactionStatus {
    /// Writes a short lowercase description of the status.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The labels are static, so write them directly instead of going
        // through the formatting machinery.
        let label = match self {
            Self::Idle => "idle",
            Self::Transaction => "in transaction",
            Self::Error => "in failed transaction",
        };
        f.write_str(label)
    }
}
310
/// Metadata for one field (column) of a row description.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Name of the column
    pub name: String,
    /// Table OID (0 if the column does not come from a table)
    pub table_oid: u32,
    /// Attribute number within the table (0 if not from a table)
    pub column_id: i16,
    /// OID of the data type of the column
    pub type_oid: u32,
    /// Size of the data type (-1 for variable-length types)
    pub type_size: i16,
    /// Type modifier (e.g., precision for NUMERIC)
    pub type_modifier: i32,
    /// Format code (0=text, 1=binary)
    pub format: i16,
}

impl FieldDescription {
    /// Format code denoting text format.
    const TEXT_FORMAT: i16 = 0;
    /// Format code denoting binary format.
    const BINARY_FORMAT: i16 = 1;

    /// Returns `true` when the format code is the binary code (1).
    pub const fn is_binary(&self) -> bool {
        self.format == Self::BINARY_FORMAT
    }

    /// Returns `true` when the format code is the text code (0).
    pub const fn is_text(&self) -> bool {
        self.format == Self::TEXT_FORMAT
    }
}
341
/// Error and notice response fields.
///
/// PostgreSQL error responses contain multiple fields identified by single-byte codes.
/// All fields are optional except severity, code, and message.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ErrorFields {
    /// Severity (ERROR, FATAL, PANIC, WARNING, NOTICE, DEBUG, INFO, LOG)
    pub severity: String,
    /// Localized severity (for display)
    pub severity_localized: Option<String>,
    /// SQLSTATE code (e.g., "23505" for unique_violation)
    pub code: String,
    /// Primary error message
    pub message: String,
    /// Optional secondary message with more detail
    pub detail: Option<String>,
    /// Optional suggestion for fixing the problem
    pub hint: Option<String>,
    /// Position in query string (1-based)
    pub position: Option<i32>,
    /// Position in internal query
    pub internal_position: Option<i32>,
    /// Internal query that generated the error
    pub internal_query: Option<String>,
    /// Call stack context
    pub where_: Option<String>,
    /// Schema name
    pub schema: Option<String>,
    /// Table name
    pub table: Option<String>,
    /// Column name
    pub column: Option<String>,
    /// Data type name
    pub data_type: Option<String>,
    /// Constraint name
    pub constraint: Option<String>,
    /// Source file name
    pub file: Option<String>,
    /// Source line number
    pub line: Option<i32>,
    /// Source routine name
    pub routine: Option<String>,
}

impl ErrorFields {
    /// Check if this is a fatal error (severity `FATAL` or `PANIC`).
    pub fn is_fatal(&self) -> bool {
        matches!(self.severity.as_str(), "FATAL" | "PANIC")
    }

    /// Check if this is a regular error (severity `ERROR`).
    pub fn is_error(&self) -> bool {
        self.severity == "ERROR"
    }

    /// Check if this has a non-error severity: `WARNING`, `NOTICE`, `DEBUG`,
    /// `INFO` or `LOG`.
    pub fn is_warning(&self) -> bool {
        matches!(
            self.severity.as_str(),
            "WARNING" | "NOTICE" | "DEBUG" | "INFO" | "LOG"
        )
    }

    /// Get the SQLSTATE error class (first two characters).
    ///
    /// Returns the first two bytes of `code` when they form a valid string
    /// prefix. If `code` is shorter than two bytes, or byte offset 2 falls
    /// inside a multi-byte UTF-8 character (a malformed code), the whole
    /// code is returned instead. This method never panics.
    pub fn error_class(&self) -> &str {
        // `str::get` returns `None` rather than panicking when the range is
        // out of bounds or does not end on a character boundary.
        self.code.get(..2).unwrap_or(self.code.as_str())
    }
}

impl fmt::Display for ErrorFields {
    /// Writes `SEVERITY: message (code)` followed by one line each for the
    /// detail, hint, position and context fields that are present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {} ({})", self.severity, self.message, self.code)?;
        if let Some(detail) = &self.detail {
            write!(f, "\nDETAIL: {detail}")?;
        }
        if let Some(hint) = &self.hint {
            write!(f, "\nHINT: {hint}")?;
        }
        if let Some(pos) = self.position {
            write!(f, "\nPOSITION: {pos}")?;
        }
        if let Some(where_) = &self.where_ {
            write!(f, "\nCONTEXT: {where_}")?;
        }
        Ok(())
    }
}
433
434// ==================== Message Type Bytes ====================
435
/// Message type bytes for frontend messages.
///
/// Each constant is the single-byte tag identifying one kind of
/// client-to-server message.
pub mod frontend_type {
    /// Password response (`'p'`).
    // NOTE(review): no separate constants exist for the SASL response
    // messages; presumably they share this byte - confirm against the encoder.
    pub const PASSWORD: u8 = b'p';
    /// Simple query (`'Q'`).
    pub const QUERY: u8 = b'Q';
    /// Parse (`'P'`).
    pub const PARSE: u8 = b'P';
    /// Bind (`'B'`).
    pub const BIND: u8 = b'B';
    /// Describe (`'D'`).
    pub const DESCRIBE: u8 = b'D';
    /// Execute (`'E'`).
    pub const EXECUTE: u8 = b'E';
    /// Close (`'C'`).
    pub const CLOSE: u8 = b'C';
    /// Sync (`'S'`).
    pub const SYNC: u8 = b'S';
    /// Flush (`'H'`).
    pub const FLUSH: u8 = b'H';
    /// COPY data chunk (`'d'`).
    pub const COPY_DATA: u8 = b'd';
    /// COPY done (`'c'`).
    pub const COPY_DONE: u8 = b'c';
    /// COPY failed (`'f'`).
    pub const COPY_FAIL: u8 = b'f';
    /// Terminate (`'X'`).
    pub const TERMINATE: u8 = b'X';
}
452
/// Message type bytes for backend messages.
///
/// Each constant is the single-byte tag identifying one kind of
/// server-to-client message.
pub mod backend_type {
    /// Authentication request or result (`'R'`).
    pub const AUTHENTICATION: u8 = b'R';
    /// Backend key data (`'K'`).
    pub const BACKEND_KEY_DATA: u8 = b'K';
    /// Parameter status (`'S'`).
    pub const PARAMETER_STATUS: u8 = b'S';
    /// Ready for query (`'Z'`).
    pub const READY_FOR_QUERY: u8 = b'Z';
    /// Row description (`'T'`).
    pub const ROW_DESCRIPTION: u8 = b'T';
    /// Data row (`'D'`).
    pub const DATA_ROW: u8 = b'D';
    /// Command complete (`'C'`).
    pub const COMMAND_COMPLETE: u8 = b'C';
    /// Empty query response (`'I'`).
    pub const EMPTY_QUERY: u8 = b'I';
    /// Parse complete (`'1'`).
    pub const PARSE_COMPLETE: u8 = b'1';
    /// Bind complete (`'2'`).
    pub const BIND_COMPLETE: u8 = b'2';
    /// Close complete (`'3'`).
    pub const CLOSE_COMPLETE: u8 = b'3';
    /// Parameter description (`'t'`).
    pub const PARAMETER_DESCRIPTION: u8 = b't';
    /// No data (`'n'`).
    pub const NO_DATA: u8 = b'n';
    /// Portal suspended (`'s'`).
    pub const PORTAL_SUSPENDED: u8 = b's';
    /// Error response (`'E'`).
    pub const ERROR_RESPONSE: u8 = b'E';
    /// Notice response (`'N'`).
    pub const NOTICE_RESPONSE: u8 = b'N';
    /// COPY-in response (`'G'`).
    pub const COPY_IN_RESPONSE: u8 = b'G';
    /// COPY-out response (`'H'`).
    pub const COPY_OUT_RESPONSE: u8 = b'H';
    /// COPY data chunk (`'d'`).
    pub const COPY_DATA: u8 = b'd';
    /// COPY done (`'c'`).
    pub const COPY_DONE: u8 = b'c';
    /// COPY-both response (`'W'`).
    pub const COPY_BOTH_RESPONSE: u8 = b'W';
    /// Notification response (`'A'`).
    pub const NOTIFICATION_RESPONSE: u8 = b'A';
    /// Function call response (`'V'`).
    pub const FUNCTION_CALL_RESPONSE: u8 = b'V';
    /// Negotiate protocol version (`'v'`).
    pub const NEGOTIATE_PROTOCOL_VERSION: u8 = b'v';
}
480
481// ==================== Authentication Type Codes ====================
482
/// Authentication method codes from the server.
///
/// Each constant is the `i32` code that selects one authentication
/// request or result.
pub mod auth_type {
    /// Authentication succeeded.
    pub const OK: i32 = 0;
    /// Kerberos V5 authentication.
    pub const KERBEROS_V5: i32 = 2;
    /// Cleartext password requested.
    pub const CLEARTEXT_PASSWORD: i32 = 3;
    /// MD5-hashed password requested.
    pub const MD5_PASSWORD: i32 = 5;
    /// SCM credential authentication.
    pub const SCM_CREDENTIAL: i32 = 6;
    /// GSS authentication.
    pub const GSS: i32 = 7;
    /// GSS continuation.
    pub const GSS_CONTINUE: i32 = 8;
    /// SSPI authentication.
    pub const SSPI: i32 = 9;
    /// SASL authentication requested.
    pub const SASL: i32 = 10;
    /// SASL continuation data.
    pub const SASL_CONTINUE: i32 = 11;
    /// SASL final data.
    pub const SASL_FINAL: i32 = 12;
}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `ErrorFields` with the three mandatory fields set and all
    /// optional fields left at their defaults.
    fn error_with(severity: &str, code: &str, message: &str) -> ErrorFields {
        ErrorFields {
            severity: severity.to_string(),
            code: code.to_string(),
            message: message.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_transaction_status_roundtrip() {
        for status in [
            TransactionStatus::Idle,
            TransactionStatus::Transaction,
            TransactionStatus::Error,
        ] {
            let byte = status.as_byte();
            assert_eq!(TransactionStatus::from_byte(byte), Some(status));
        }
    }

    #[test]
    fn test_transaction_status_rejects_unknown_byte() {
        // Matching is case-sensitive and anything outside I/T/E is rejected.
        assert_eq!(TransactionStatus::from_byte(b'X'), None);
        assert_eq!(TransactionStatus::from_byte(b'i'), None);
        assert_eq!(TransactionStatus::from_byte(0), None);
    }

    #[test]
    fn test_transaction_status_display() {
        assert_eq!(TransactionStatus::Idle.to_string(), "idle");
        assert_eq!(TransactionStatus::Transaction.to_string(), "in transaction");
        assert_eq!(
            TransactionStatus::Error.to_string(),
            "in failed transaction"
        );
    }

    #[test]
    fn test_describe_kind_roundtrip() {
        for kind in [DescribeKind::Statement, DescribeKind::Portal] {
            let byte = kind.as_byte();
            assert_eq!(DescribeKind::from_byte(byte), Some(kind));
        }
    }

    #[test]
    fn test_describe_kind_rejects_unknown_byte() {
        assert_eq!(DescribeKind::from_byte(b'X'), None);
        assert_eq!(DescribeKind::from_byte(b's'), None);
    }

    #[test]
    fn test_error_fields_display() {
        let err = ErrorFields {
            detail: Some("Key (id)=(1) already exists.".to_string()),
            hint: None,
            ..error_with(
                "ERROR",
                "23505",
                "duplicate key value violates unique constraint",
            )
        };

        let display = format!("{err}");
        assert!(display.contains("ERROR"));
        assert!(display.contains("23505"));
        assert!(display.contains("duplicate key"));
        assert!(display.contains("Key (id)=(1)"));
        // Absent optional fields must not produce their labels.
        assert!(!display.contains("HINT"));
        assert!(!display.contains("POSITION"));
        assert!(!display.contains("CONTEXT"));
    }

    #[test]
    fn test_error_fields_classification() {
        let fatal = error_with("FATAL", "XX000", "internal error");
        assert!(fatal.is_fatal());
        assert!(!fatal.is_error());
        assert!(!fatal.is_warning());

        let error = error_with("ERROR", "23505", "constraint violation");
        assert!(!error.is_fatal());
        assert!(error.is_error());
        assert!(!error.is_warning());

        let warning = error_with("WARNING", "01000", "deprecated feature");
        assert!(!warning.is_fatal());
        assert!(!warning.is_error());
        assert!(warning.is_warning());
    }

    #[test]
    fn test_error_class() {
        let err = ErrorFields {
            code: "23505".to_string(),
            ..Default::default()
        };
        assert_eq!(err.error_class(), "23");
    }

    #[test]
    fn test_error_class_short_code() {
        // Codes shorter than two characters are returned unchanged.
        assert_eq!(error_with("ERROR", "2", "").error_class(), "2");
        assert_eq!(error_with("ERROR", "", "").error_class(), "");
    }

    #[test]
    fn test_field_description_format() {
        let text_field = FieldDescription {
            name: "id".to_string(),
            table_oid: 0,
            column_id: 0,
            type_oid: 23,
            type_size: 4,
            type_modifier: -1,
            format: 0,
        };
        assert!(text_field.is_text());
        assert!(!text_field.is_binary());

        let binary_field = FieldDescription {
            format: 1,
            ..text_field
        };
        assert!(!binary_field.is_text());
        assert!(binary_field.is_binary());
    }
}