Skip to main content

sqlmodel_postgres/protocol/
reader.rs

//! PostgreSQL message decoder.
//!
//! This module handles decoding backend messages from the wire protocol format.
4
5#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8    BackendMessage, ErrorFields, FieldDescription, TransactionStatus, auth_type, backend_type,
9};
10use std::error::Error as StdError;
11use std::fmt;
12
/// Failure modes of the PostgreSQL backend-message decoder.
#[derive(Debug)]
pub enum ProtocolError {
    /// The input ends before a complete message frame is available.
    Incomplete,
    /// The frame's length prefix is not a usable value; carries the raw prefix.
    InvalidLength { length: i32 },
    /// The frame is larger than the configured limit (both values in bytes).
    MessageTooLarge { length: usize, max: usize },
    /// The leading type byte does not name a supported backend message.
    UnknownMessageType(u8),
    /// A string field did not contain valid UTF-8.
    Utf8(std::string::FromUtf8Error),
    /// The buffer ran out while a field was still being read.
    UnexpectedEof,
    /// A field held a value the decoder rejects; the text names the problem.
    InvalidField(&'static str),
}
31
32impl fmt::Display for ProtocolError {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            ProtocolError::Incomplete => write!(f, "incomplete message"),
36            ProtocolError::InvalidLength { length } => {
37                write!(f, "invalid message length: {}", length)
38            }
39            ProtocolError::MessageTooLarge { length, max } => {
40                write!(f, "message too large: {} > {}", length, max)
41            }
42            ProtocolError::UnknownMessageType(ty) => {
43                write!(f, "unknown message type: 0x{:02x}", ty)
44            }
45            ProtocolError::Utf8(err) => write!(f, "utf-8 error: {}", err),
46            ProtocolError::UnexpectedEof => write!(f, "unexpected end of buffer"),
47            ProtocolError::InvalidField(msg) => write!(f, "invalid field: {}", msg),
48        }
49    }
50}
51
// No trait methods are overridden, so the `std::error::Error` defaults apply;
// the wrapped UTF-8 error is reported through the `Display` text instead.
impl StdError for ProtocolError {}
53
54impl From<std::string::FromUtf8Error> for ProtocolError {
55    fn from(err: std::string::FromUtf8Error) -> Self {
56        ProtocolError::Utf8(err)
57    }
58}
59
/// Stateful decoder that accumulates raw bytes and yields backend messages.
#[derive(Debug, Clone)]
pub struct MessageReader {
    // Bytes received so far that have not yet been consumed as a full frame.
    buf: Vec<u8>,
    // Upper bound, in bytes, on a whole frame (type byte + length + payload).
    max_message_size: usize,
}
66
67impl Default for MessageReader {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl MessageReader {
74    /// Create a new reader with a default max message size.
75    pub fn new() -> Self {
76        Self::with_max_size(8 * 1024 * 1024)
77    }
78
79    /// Create a new reader with a custom max message size.
80    pub fn with_max_size(max_message_size: usize) -> Self {
81        Self {
82            buf: Vec::new(),
83            max_message_size,
84        }
85    }
86
87    /// Number of bytes currently buffered.
88    pub fn buffered_len(&self) -> usize {
89        self.buf.len()
90    }
91
92    /// Feed bytes into the reader and return any complete messages.
93    pub fn feed(&mut self, data: &[u8]) -> Result<Vec<BackendMessage>, ProtocolError> {
94        self.buf.extend_from_slice(data);
95
96        let mut messages = Vec::new();
97        while let Some(msg) = self.next_message()? {
98            messages.push(msg);
99        }
100        Ok(messages)
101    }
102
103    /// Attempt to parse the next message from the internal buffer.
104    pub fn next_message(&mut self) -> Result<Option<BackendMessage>, ProtocolError> {
105        if self.buf.len() < 5 {
106            return Ok(None);
107        }
108
109        let length = i32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]]);
110        if length < 4 {
111            return Err(ProtocolError::InvalidLength { length });
112        }
113
114        let total_len = length as usize + 1;
115        if total_len > self.max_message_size {
116            return Err(ProtocolError::MessageTooLarge {
117                length: total_len,
118                max: self.max_message_size,
119            });
120        }
121
122        if self.buf.len() < total_len {
123            return Ok(None);
124        }
125
126        let frame = self.buf[..total_len].to_vec();
127        self.buf.drain(..total_len);
128        Ok(Some(Self::parse_message(&frame)?))
129    }
130
131    /// Parse a single full message frame (type + length + payload).
132    pub fn parse_message(frame: &[u8]) -> Result<BackendMessage, ProtocolError> {
133        if frame.len() < 5 {
134            return Err(ProtocolError::Incomplete);
135        }
136
137        let ty = frame[0];
138        let length = i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]);
139        if length < 4 {
140            return Err(ProtocolError::InvalidLength { length });
141        }
142
143        let total_len = length as usize + 1;
144        if frame.len() < total_len {
145            return Err(ProtocolError::Incomplete);
146        }
147
148        let payload = &frame[5..total_len];
149        let mut cur = Cursor::new(payload);
150
151        match ty {
152            backend_type::AUTHENTICATION => parse_authentication(&mut cur),
153            backend_type::BACKEND_KEY_DATA => parse_backend_key_data(&mut cur),
154            backend_type::PARAMETER_STATUS => parse_parameter_status(&mut cur),
155            backend_type::READY_FOR_QUERY => parse_ready_for_query(&mut cur),
156            backend_type::ROW_DESCRIPTION => parse_row_description(&mut cur),
157            backend_type::DATA_ROW => parse_data_row(&mut cur),
158            backend_type::COMMAND_COMPLETE => parse_command_complete(&mut cur),
159            backend_type::EMPTY_QUERY => Ok(BackendMessage::EmptyQueryResponse),
160            backend_type::PARSE_COMPLETE => Ok(BackendMessage::ParseComplete),
161            backend_type::BIND_COMPLETE => Ok(BackendMessage::BindComplete),
162            backend_type::CLOSE_COMPLETE => Ok(BackendMessage::CloseComplete),
163            backend_type::PARAMETER_DESCRIPTION => parse_parameter_description(&mut cur),
164            backend_type::NO_DATA => Ok(BackendMessage::NoData),
165            backend_type::PORTAL_SUSPENDED => Ok(BackendMessage::PortalSuspended),
166            backend_type::ERROR_RESPONSE => parse_error_response(&mut cur, true),
167            backend_type::NOTICE_RESPONSE => parse_error_response(&mut cur, false),
168            backend_type::COPY_IN_RESPONSE => parse_copy_in_response(&mut cur),
169            backend_type::COPY_OUT_RESPONSE => parse_copy_out_response(&mut cur),
170            backend_type::COPY_BOTH_RESPONSE => parse_copy_both_response(&mut cur),
171            backend_type::COPY_DATA => Ok(BackendMessage::CopyData(cur.take_remaining())),
172            backend_type::COPY_DONE => Ok(BackendMessage::CopyDone),
173            backend_type::NOTIFICATION_RESPONSE => parse_notification_response(&mut cur),
174            backend_type::FUNCTION_CALL_RESPONSE => parse_function_call_response(&mut cur),
175            backend_type::NEGOTIATE_PROTOCOL_VERSION => parse_negotiate_protocol_version(&mut cur),
176            _ => Err(ProtocolError::UnknownMessageType(ty)),
177        }
178    }
179}
180
181fn parse_authentication(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
182    let auth_type = cur.read_i32()?;
183    match auth_type {
184        auth_type::OK => Ok(BackendMessage::AuthenticationOk),
185        auth_type::CLEARTEXT_PASSWORD => Ok(BackendMessage::AuthenticationCleartextPassword),
186        auth_type::MD5_PASSWORD => {
187            let salt = cur.read_bytes(4)?;
188            let mut buf = [0_u8; 4];
189            buf.copy_from_slice(salt);
190            Ok(BackendMessage::AuthenticationMD5Password(buf))
191        }
192        auth_type::SASL => {
193            let mut mechanisms = Vec::new();
194            loop {
195                let mech = cur.read_cstring()?;
196                if mech.is_empty() {
197                    break;
198                }
199                mechanisms.push(mech);
200            }
201            Ok(BackendMessage::AuthenticationSASL(mechanisms))
202        }
203        auth_type::SASL_CONTINUE => Ok(BackendMessage::AuthenticationSASLContinue(
204            cur.take_remaining(),
205        )),
206        auth_type::SASL_FINAL => Ok(BackendMessage::AuthenticationSASLFinal(
207            cur.take_remaining(),
208        )),
209        _ => Err(ProtocolError::InvalidField("unknown auth type")),
210    }
211}
212
213fn parse_backend_key_data(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
214    let process_id = cur.read_i32()?;
215    let secret_key = cur.read_i32()?;
216    Ok(BackendMessage::BackendKeyData {
217        process_id,
218        secret_key,
219    })
220}
221
222fn parse_parameter_status(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
223    let name = cur.read_cstring()?;
224    let value = cur.read_cstring()?;
225    Ok(BackendMessage::ParameterStatus { name, value })
226}
227
228fn parse_ready_for_query(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
229    let status = cur.read_u8()?;
230    let status = TransactionStatus::from_byte(status)
231        .ok_or(ProtocolError::InvalidField("invalid transaction status"))?;
232    Ok(BackendMessage::ReadyForQuery(status))
233}
234
235fn parse_row_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
236    let count = cur.read_i16()?;
237    if count < 0 {
238        return Err(ProtocolError::InvalidField("negative field count"));
239    }
240    let mut fields = Vec::with_capacity(count as usize);
241    for _ in 0..count {
242        let name = cur.read_cstring()?;
243        let table_oid = cur.read_u32()?;
244        let column_id = cur.read_i16()?;
245        let type_oid = cur.read_u32()?;
246        let type_size = cur.read_i16()?;
247        let type_modifier = cur.read_i32()?;
248        let format = cur.read_i16()?;
249        fields.push(FieldDescription {
250            name,
251            table_oid,
252            column_id,
253            type_oid,
254            type_size,
255            type_modifier,
256            format,
257        });
258    }
259    Ok(BackendMessage::RowDescription(fields))
260}
261
262fn parse_data_row(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
263    let count = cur.read_i16()?;
264    if count < 0 {
265        return Err(ProtocolError::InvalidField("negative column count"));
266    }
267    let mut values = Vec::with_capacity(count as usize);
268    for _ in 0..count {
269        let len = cur.read_i32()?;
270        if len == -1 {
271            values.push(None);
272            continue;
273        }
274        if len < 0 {
275            return Err(ProtocolError::InvalidField("negative data length"));
276        }
277        let bytes = cur.read_bytes(len as usize)?.to_vec();
278        values.push(Some(bytes));
279    }
280    Ok(BackendMessage::DataRow(values))
281}
282
283fn parse_command_complete(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
284    let tag = cur.read_cstring()?;
285    Ok(BackendMessage::CommandComplete(tag))
286}
287
288fn parse_parameter_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
289    let count = cur.read_i16()?;
290    if count < 0 {
291        return Err(ProtocolError::InvalidField("negative parameter count"));
292    }
293    let mut oids = Vec::with_capacity(count as usize);
294    for _ in 0..count {
295        oids.push(cur.read_u32()?);
296    }
297    Ok(BackendMessage::ParameterDescription(oids))
298}
299
300fn parse_copy_in_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
301    let format = cur.read_i8()?;
302    let column_formats = read_column_formats(cur)?;
303    Ok(BackendMessage::CopyInResponse {
304        format,
305        column_formats,
306    })
307}
308
309fn parse_copy_out_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
310    let format = cur.read_i8()?;
311    let column_formats = read_column_formats(cur)?;
312    Ok(BackendMessage::CopyOutResponse {
313        format,
314        column_formats,
315    })
316}
317
318fn parse_copy_both_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
319    let format = cur.read_i8()?;
320    let column_formats = read_column_formats(cur)?;
321    Ok(BackendMessage::CopyBothResponse {
322        format,
323        column_formats,
324    })
325}
326
327fn read_column_formats(cur: &mut Cursor<'_>) -> Result<Vec<i16>, ProtocolError> {
328    let count = cur.read_i16()?;
329    if count < 0 {
330        return Err(ProtocolError::InvalidField("negative format count"));
331    }
332    let mut formats = Vec::with_capacity(count as usize);
333    for _ in 0..count {
334        formats.push(cur.read_i16()?);
335    }
336    Ok(formats)
337}
338
339fn parse_notification_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
340    let process_id = cur.read_i32()?;
341    let channel = cur.read_cstring()?;
342    let payload = cur.read_cstring()?;
343    Ok(BackendMessage::NotificationResponse {
344        process_id,
345        channel,
346        payload,
347    })
348}
349
350fn parse_function_call_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
351    let len = cur.read_i32()?;
352    if len == -1 {
353        return Ok(BackendMessage::FunctionCallResponse(None));
354    }
355    if len < 0 {
356        return Err(ProtocolError::InvalidField("negative function length"));
357    }
358    let bytes = cur.read_bytes(len as usize)?.to_vec();
359    Ok(BackendMessage::FunctionCallResponse(Some(bytes)))
360}
361
362fn parse_negotiate_protocol_version(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
363    let newest_minor = cur.read_i32()?;
364    let count = cur.read_i32()?;
365    if count < 0 {
366        return Err(ProtocolError::InvalidField(
367            "negative protocol option count",
368        ));
369    }
370    let mut unrecognized = Vec::with_capacity(count as usize);
371    for _ in 0..count {
372        unrecognized.push(cur.read_cstring()?);
373    }
374    Ok(BackendMessage::NegotiateProtocolVersion {
375        newest_minor,
376        unrecognized,
377    })
378}
379
380fn parse_error_response(
381    cur: &mut Cursor<'_>,
382    is_error: bool,
383) -> Result<BackendMessage, ProtocolError> {
384    let mut fields = ErrorFields::default();
385    loop {
386        let code = cur.read_u8()?;
387        if code == 0 {
388            break;
389        }
390        let value = cur.read_cstring()?;
391        match code {
392            b'S' => fields.severity = value,
393            b'V' => fields.severity_localized = Some(value),
394            b'C' => fields.code = value,
395            b'M' => fields.message = value,
396            b'D' => fields.detail = Some(value),
397            b'H' => fields.hint = Some(value),
398            b'P' => fields.position = value.parse().ok(),
399            b'p' => fields.internal_position = value.parse().ok(),
400            b'q' => fields.internal_query = Some(value),
401            b'W' => fields.where_ = Some(value),
402            b's' => fields.schema = Some(value),
403            b't' => fields.table = Some(value),
404            b'c' => fields.column = Some(value),
405            b'd' => fields.data_type = Some(value),
406            b'n' => fields.constraint = Some(value),
407            b'F' => fields.file = Some(value),
408            b'L' => fields.line = value.parse().ok(),
409            b'R' => fields.routine = Some(value),
410            _ => {
411                // Ignore unknown fields.
412            }
413        }
414    }
415
416    if is_error {
417        Ok(BackendMessage::ErrorResponse(fields))
418    } else {
419        Ok(BackendMessage::NoticeResponse(fields))
420    }
421}
422
/// Forward-only read position over a borrowed byte slice.
#[derive(Debug)]
struct Cursor<'a> {
    // The bytes being decoded.
    buf: &'a [u8],
    // Index into `buf` of the next unread byte.
    pos: usize,
}
428
429impl<'a> Cursor<'a> {
430    fn new(buf: &'a [u8]) -> Self {
431        Self { buf, pos: 0 }
432    }
433
434    fn remaining(&self) -> usize {
435        self.buf.len().saturating_sub(self.pos)
436    }
437
438    fn read_u8(&mut self) -> Result<u8, ProtocolError> {
439        if self.remaining() < 1 {
440            return Err(ProtocolError::UnexpectedEof);
441        }
442        let b = self.buf[self.pos];
443        self.pos += 1;
444        Ok(b)
445    }
446
447    fn read_i8(&mut self) -> Result<i8, ProtocolError> {
448        let b = self.read_u8()?;
449        Ok(b as i8)
450    }
451
452    fn read_i16(&mut self) -> Result<i16, ProtocolError> {
453        let bytes = self.read_bytes(2)?;
454        Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
455    }
456
457    fn read_u32(&mut self) -> Result<u32, ProtocolError> {
458        let bytes = self.read_bytes(4)?;
459        Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
460    }
461
462    fn read_i32(&mut self) -> Result<i32, ProtocolError> {
463        let bytes = self.read_bytes(4)?;
464        Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
465    }
466
467    fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ProtocolError> {
468        if self.remaining() < n {
469            return Err(ProtocolError::UnexpectedEof);
470        }
471        let start = self.pos;
472        let end = self.pos + n;
473        self.pos = end;
474        Ok(&self.buf[start..end])
475    }
476
477    fn read_cstring(&mut self) -> Result<String, ProtocolError> {
478        let start = self.pos;
479        while self.pos < self.buf.len() && self.buf[self.pos] != 0 {
480            self.pos += 1;
481        }
482        if self.pos >= self.buf.len() {
483            return Err(ProtocolError::UnexpectedEof);
484        }
485        let bytes = self.buf[start..self.pos].to_vec();
486        self.pos += 1; // consume null terminator
487        Ok(String::from_utf8(bytes)?)
488    }
489
490    fn take_remaining(&mut self) -> Vec<u8> {
491        let remaining = self.buf[self.pos..].to_vec();
492        self.pos = self.buf.len();
493        remaining
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a wire frame: type byte, big-endian length (payload + 4), payload.
    #[allow(clippy::cast_possible_truncation)]
    fn build_message(ty: u8, payload: &[u8]) -> Vec<u8> {
        let declared = (payload.len() + 4) as i32;
        let mut frame = vec![ty];
        frame.extend_from_slice(&declared.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Frame `payload` as a message of type `ty` and run it through the parser.
    fn decode(ty: u8, payload: &[u8]) -> Result<BackendMessage, ProtocolError> {
        MessageReader::parse_message(&build_message(ty, payload))
    }

    /// Assert that a message of type `ty` announcing a count of -1 is rejected.
    fn assert_negative_count_rejected(ty: u8) {
        let outcome = decode(ty, &(-1_i16).to_be_bytes());
        assert!(matches!(outcome, Err(ProtocolError::InvalidField(_))));
    }

    #[test]
    fn parse_auth_ok() {
        let decoded = decode(backend_type::AUTHENTICATION, &auth_type::OK.to_be_bytes()).unwrap();
        assert!(matches!(decoded, BackendMessage::AuthenticationOk));
    }

    #[test]
    fn parse_ready_for_query() {
        let status_byte = TransactionStatus::Idle.as_byte();
        let decoded = decode(backend_type::READY_FOR_QUERY, &[status_byte]).unwrap();
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(TransactionStatus::Idle)
        ));
    }

    #[test]
    fn parse_error_response() {
        // Three (code, value) pairs followed by the terminating zero byte.
        let payload: &[u8] = b"SERROR\0C12345\0Mbad\0\0";

        match decode(backend_type::ERROR_RESPONSE, payload).unwrap() {
            BackendMessage::ErrorResponse(fields) => {
                assert_eq!(fields.severity, "ERROR");
                assert_eq!(fields.code, "12345");
                assert_eq!(fields.message, "bad");
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn parse_data_row() {
        // Two columns: the three bytes "foo", then a NULL (length -1).
        let column_count = 2_i16.to_be_bytes();
        let first_len = 3_i32.to_be_bytes();
        let null_len = (-1_i32).to_be_bytes();
        let parts: [&[u8]; 4] = [&column_count, &first_len, b"foo", &null_len];
        let payload = parts.concat();

        match decode(backend_type::DATA_ROW, &payload).unwrap() {
            BackendMessage::DataRow(values) => {
                assert_eq!(values.len(), 2);
                assert_eq!(values[0].as_deref(), Some(b"foo".as_slice()));
                assert!(values[1].is_none());
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn reader_buffers_partial_frames() {
        let frame = build_message(
            backend_type::READY_FOR_QUERY,
            &[TransactionStatus::Idle.as_byte()],
        );
        // Cut inside the header so the first chunk cannot be framed yet.
        let (head, tail) = frame.split_at(3);

        let mut reader = MessageReader::new();
        assert!(reader.feed(head).unwrap().is_empty());
        assert_eq!(reader.feed(tail).unwrap().len(), 1);
    }

    #[test]
    fn parse_row_description_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::ROW_DESCRIPTION);
    }

    #[test]
    fn parse_data_row_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::DATA_ROW);
    }

    #[test]
    fn parse_parameter_description_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::PARAMETER_DESCRIPTION);
    }
}