1#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8 BackendMessage, ErrorFields, FieldDescription, TransactionStatus, auth_type, backend_type,
9};
10use std::error::Error as StdError;
11use std::fmt;
12
/// Errors produced while framing or decoding backend protocol messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// The supplied frame is shorter than its header or its declared length.
    Incomplete,
    /// The length field of a frame header was smaller than 4.
    InvalidLength {
        /// The rejected value, exactly as read from the header.
        length: i32,
    },
    /// A frame's total size exceeds the reader's configured maximum.
    MessageTooLarge {
        /// Total frame size in bytes (declared length plus the type byte).
        length: usize,
        /// The configured upper bound the frame was compared against.
        max: usize,
    },
    /// The type byte of a frame matched no known backend message.
    UnknownMessageType(u8),
    /// A string field did not contain valid UTF-8.
    Utf8(std::string::FromUtf8Error),
    /// A message payload ended before a field could be read in full.
    UnexpectedEof,
    /// A field held a value that is not acceptable; the text names the problem.
    InvalidField(&'static str),
}
31
32impl fmt::Display for ProtocolError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 ProtocolError::Incomplete => write!(f, "incomplete message"),
36 ProtocolError::InvalidLength { length } => {
37 write!(f, "invalid message length: {}", length)
38 }
39 ProtocolError::MessageTooLarge { length, max } => {
40 write!(f, "message too large: {} > {}", length, max)
41 }
42 ProtocolError::UnknownMessageType(ty) => {
43 write!(f, "unknown message type: 0x{:02x}", ty)
44 }
45 ProtocolError::Utf8(err) => write!(f, "utf-8 error: {}", err),
46 ProtocolError::UnexpectedEof => write!(f, "unexpected end of buffer"),
47 ProtocolError::InvalidField(msg) => write!(f, "invalid field: {}", msg),
48 }
49 }
50}
51
52impl StdError for ProtocolError {}
53
54impl From<std::string::FromUtf8Error> for ProtocolError {
55 fn from(err: std::string::FromUtf8Error) -> Self {
56 ProtocolError::Utf8(err)
57 }
58}
59
/// Incremental reader that accumulates raw bytes and splits them into
/// backend messages.
#[derive(Clone, Debug)]
pub struct MessageReader {
    /// Bytes received so far that have not yet been consumed as a frame.
    buf: Vec<u8>,
    /// Largest total frame size, in bytes, that the reader will accept.
    max_message_size: usize,
}
66
67impl Default for MessageReader {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl MessageReader {
74 pub fn new() -> Self {
76 Self::with_max_size(8 * 1024 * 1024)
77 }
78
79 pub fn with_max_size(max_message_size: usize) -> Self {
81 Self {
82 buf: Vec::new(),
83 max_message_size,
84 }
85 }
86
87 pub fn buffered_len(&self) -> usize {
89 self.buf.len()
90 }
91
92 pub fn feed(&mut self, data: &[u8]) -> Result<Vec<BackendMessage>, ProtocolError> {
94 self.buf.extend_from_slice(data);
95
96 let mut messages = Vec::new();
97 while let Some(msg) = self.next_message()? {
98 messages.push(msg);
99 }
100 Ok(messages)
101 }
102
103 pub fn next_message(&mut self) -> Result<Option<BackendMessage>, ProtocolError> {
105 if self.buf.len() < 5 {
106 return Ok(None);
107 }
108
109 let length = i32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]]);
110 if length < 4 {
111 return Err(ProtocolError::InvalidLength { length });
112 }
113
114 let total_len = length as usize + 1;
115 if total_len > self.max_message_size {
116 return Err(ProtocolError::MessageTooLarge {
117 length: total_len,
118 max: self.max_message_size,
119 });
120 }
121
122 if self.buf.len() < total_len {
123 return Ok(None);
124 }
125
126 let frame = self.buf[..total_len].to_vec();
127 self.buf.drain(..total_len);
128 Ok(Some(Self::parse_message(&frame)?))
129 }
130
131 pub fn parse_message(frame: &[u8]) -> Result<BackendMessage, ProtocolError> {
133 if frame.len() < 5 {
134 return Err(ProtocolError::Incomplete);
135 }
136
137 let ty = frame[0];
138 let length = i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]);
139 if length < 4 {
140 return Err(ProtocolError::InvalidLength { length });
141 }
142
143 let total_len = length as usize + 1;
144 if frame.len() < total_len {
145 return Err(ProtocolError::Incomplete);
146 }
147
148 let payload = &frame[5..total_len];
149 let mut cur = Cursor::new(payload);
150
151 match ty {
152 backend_type::AUTHENTICATION => parse_authentication(&mut cur),
153 backend_type::BACKEND_KEY_DATA => parse_backend_key_data(&mut cur),
154 backend_type::PARAMETER_STATUS => parse_parameter_status(&mut cur),
155 backend_type::READY_FOR_QUERY => parse_ready_for_query(&mut cur),
156 backend_type::ROW_DESCRIPTION => parse_row_description(&mut cur),
157 backend_type::DATA_ROW => parse_data_row(&mut cur),
158 backend_type::COMMAND_COMPLETE => parse_command_complete(&mut cur),
159 backend_type::EMPTY_QUERY => Ok(BackendMessage::EmptyQueryResponse),
160 backend_type::PARSE_COMPLETE => Ok(BackendMessage::ParseComplete),
161 backend_type::BIND_COMPLETE => Ok(BackendMessage::BindComplete),
162 backend_type::CLOSE_COMPLETE => Ok(BackendMessage::CloseComplete),
163 backend_type::PARAMETER_DESCRIPTION => parse_parameter_description(&mut cur),
164 backend_type::NO_DATA => Ok(BackendMessage::NoData),
165 backend_type::PORTAL_SUSPENDED => Ok(BackendMessage::PortalSuspended),
166 backend_type::ERROR_RESPONSE => parse_error_response(&mut cur, true),
167 backend_type::NOTICE_RESPONSE => parse_error_response(&mut cur, false),
168 backend_type::COPY_IN_RESPONSE => parse_copy_in_response(&mut cur),
169 backend_type::COPY_OUT_RESPONSE => parse_copy_out_response(&mut cur),
170 backend_type::COPY_BOTH_RESPONSE => parse_copy_both_response(&mut cur),
171 backend_type::COPY_DATA => Ok(BackendMessage::CopyData(cur.take_remaining())),
172 backend_type::COPY_DONE => Ok(BackendMessage::CopyDone),
173 backend_type::NOTIFICATION_RESPONSE => parse_notification_response(&mut cur),
174 backend_type::FUNCTION_CALL_RESPONSE => parse_function_call_response(&mut cur),
175 backend_type::NEGOTIATE_PROTOCOL_VERSION => parse_negotiate_protocol_version(&mut cur),
176 _ => Err(ProtocolError::UnknownMessageType(ty)),
177 }
178 }
179}
180
181fn parse_authentication(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
182 let auth_type = cur.read_i32()?;
183 match auth_type {
184 auth_type::OK => Ok(BackendMessage::AuthenticationOk),
185 auth_type::CLEARTEXT_PASSWORD => Ok(BackendMessage::AuthenticationCleartextPassword),
186 auth_type::MD5_PASSWORD => {
187 let salt = cur.read_bytes(4)?;
188 let mut buf = [0_u8; 4];
189 buf.copy_from_slice(salt);
190 Ok(BackendMessage::AuthenticationMD5Password(buf))
191 }
192 auth_type::SASL => {
193 let mut mechanisms = Vec::new();
194 loop {
195 let mech = cur.read_cstring()?;
196 if mech.is_empty() {
197 break;
198 }
199 mechanisms.push(mech);
200 }
201 Ok(BackendMessage::AuthenticationSASL(mechanisms))
202 }
203 auth_type::SASL_CONTINUE => Ok(BackendMessage::AuthenticationSASLContinue(
204 cur.take_remaining(),
205 )),
206 auth_type::SASL_FINAL => Ok(BackendMessage::AuthenticationSASLFinal(
207 cur.take_remaining(),
208 )),
209 _ => Err(ProtocolError::InvalidField("unknown auth type")),
210 }
211}
212
213fn parse_backend_key_data(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
214 let process_id = cur.read_i32()?;
215 let secret_key = cur.read_i32()?;
216 Ok(BackendMessage::BackendKeyData {
217 process_id,
218 secret_key,
219 })
220}
221
222fn parse_parameter_status(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
223 let name = cur.read_cstring()?;
224 let value = cur.read_cstring()?;
225 Ok(BackendMessage::ParameterStatus { name, value })
226}
227
228fn parse_ready_for_query(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
229 let status = cur.read_u8()?;
230 let status = TransactionStatus::from_byte(status)
231 .ok_or(ProtocolError::InvalidField("invalid transaction status"))?;
232 Ok(BackendMessage::ReadyForQuery(status))
233}
234
235fn parse_row_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
236 let count = cur.read_i16()?;
237 if count < 0 {
238 return Err(ProtocolError::InvalidField("negative field count"));
239 }
240 let mut fields = Vec::with_capacity(count as usize);
241 for _ in 0..count {
242 let name = cur.read_cstring()?;
243 let table_oid = cur.read_u32()?;
244 let column_id = cur.read_i16()?;
245 let type_oid = cur.read_u32()?;
246 let type_size = cur.read_i16()?;
247 let type_modifier = cur.read_i32()?;
248 let format = cur.read_i16()?;
249 fields.push(FieldDescription {
250 name,
251 table_oid,
252 column_id,
253 type_oid,
254 type_size,
255 type_modifier,
256 format,
257 });
258 }
259 Ok(BackendMessage::RowDescription(fields))
260}
261
262fn parse_data_row(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
263 let count = cur.read_i16()?;
264 if count < 0 {
265 return Err(ProtocolError::InvalidField("negative column count"));
266 }
267 let mut values = Vec::with_capacity(count as usize);
268 for _ in 0..count {
269 let len = cur.read_i32()?;
270 if len == -1 {
271 values.push(None);
272 continue;
273 }
274 if len < 0 {
275 return Err(ProtocolError::InvalidField("negative data length"));
276 }
277 let bytes = cur.read_bytes(len as usize)?.to_vec();
278 values.push(Some(bytes));
279 }
280 Ok(BackendMessage::DataRow(values))
281}
282
283fn parse_command_complete(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
284 let tag = cur.read_cstring()?;
285 Ok(BackendMessage::CommandComplete(tag))
286}
287
288fn parse_parameter_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
289 let count = cur.read_i16()?;
290 if count < 0 {
291 return Err(ProtocolError::InvalidField("negative parameter count"));
292 }
293 let mut oids = Vec::with_capacity(count as usize);
294 for _ in 0..count {
295 oids.push(cur.read_u32()?);
296 }
297 Ok(BackendMessage::ParameterDescription(oids))
298}
299
300fn parse_copy_in_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
301 let format = cur.read_i8()?;
302 let column_formats = read_column_formats(cur)?;
303 Ok(BackendMessage::CopyInResponse {
304 format,
305 column_formats,
306 })
307}
308
309fn parse_copy_out_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
310 let format = cur.read_i8()?;
311 let column_formats = read_column_formats(cur)?;
312 Ok(BackendMessage::CopyOutResponse {
313 format,
314 column_formats,
315 })
316}
317
318fn parse_copy_both_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
319 let format = cur.read_i8()?;
320 let column_formats = read_column_formats(cur)?;
321 Ok(BackendMessage::CopyBothResponse {
322 format,
323 column_formats,
324 })
325}
326
327fn read_column_formats(cur: &mut Cursor<'_>) -> Result<Vec<i16>, ProtocolError> {
328 let count = cur.read_i16()?;
329 if count < 0 {
330 return Err(ProtocolError::InvalidField("negative format count"));
331 }
332 let mut formats = Vec::with_capacity(count as usize);
333 for _ in 0..count {
334 formats.push(cur.read_i16()?);
335 }
336 Ok(formats)
337}
338
339fn parse_notification_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
340 let process_id = cur.read_i32()?;
341 let channel = cur.read_cstring()?;
342 let payload = cur.read_cstring()?;
343 Ok(BackendMessage::NotificationResponse {
344 process_id,
345 channel,
346 payload,
347 })
348}
349
350fn parse_function_call_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
351 let len = cur.read_i32()?;
352 if len == -1 {
353 return Ok(BackendMessage::FunctionCallResponse(None));
354 }
355 if len < 0 {
356 return Err(ProtocolError::InvalidField("negative function length"));
357 }
358 let bytes = cur.read_bytes(len as usize)?.to_vec();
359 Ok(BackendMessage::FunctionCallResponse(Some(bytes)))
360}
361
362fn parse_negotiate_protocol_version(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
363 let newest_minor = cur.read_i32()?;
364 let count = cur.read_i32()?;
365 if count < 0 {
366 return Err(ProtocolError::InvalidField(
367 "negative protocol option count",
368 ));
369 }
370 let mut unrecognized = Vec::with_capacity(count as usize);
371 for _ in 0..count {
372 unrecognized.push(cur.read_cstring()?);
373 }
374 Ok(BackendMessage::NegotiateProtocolVersion {
375 newest_minor,
376 unrecognized,
377 })
378}
379
380fn parse_error_response(
381 cur: &mut Cursor<'_>,
382 is_error: bool,
383) -> Result<BackendMessage, ProtocolError> {
384 let mut fields = ErrorFields::default();
385 loop {
386 let code = cur.read_u8()?;
387 if code == 0 {
388 break;
389 }
390 let value = cur.read_cstring()?;
391 match code {
392 b'S' => fields.severity = value,
393 b'V' => fields.severity_localized = Some(value),
394 b'C' => fields.code = value,
395 b'M' => fields.message = value,
396 b'D' => fields.detail = Some(value),
397 b'H' => fields.hint = Some(value),
398 b'P' => fields.position = value.parse().ok(),
399 b'p' => fields.internal_position = value.parse().ok(),
400 b'q' => fields.internal_query = Some(value),
401 b'W' => fields.where_ = Some(value),
402 b's' => fields.schema = Some(value),
403 b't' => fields.table = Some(value),
404 b'c' => fields.column = Some(value),
405 b'd' => fields.data_type = Some(value),
406 b'n' => fields.constraint = Some(value),
407 b'F' => fields.file = Some(value),
408 b'L' => fields.line = value.parse().ok(),
409 b'R' => fields.routine = Some(value),
410 _ => {
411 }
413 }
414 }
415
416 if is_error {
417 Ok(BackendMessage::ErrorResponse(fields))
418 } else {
419 Ok(BackendMessage::NoticeResponse(fields))
420 }
421}
422
/// Forward-only reader over a borrowed message payload.
#[derive(Debug)]
struct Cursor<'a> {
    /// The payload being read.
    buf: &'a [u8],
    /// Index into `buf` of the next unread byte.
    pos: usize,
}
428
429impl<'a> Cursor<'a> {
430 fn new(buf: &'a [u8]) -> Self {
431 Self { buf, pos: 0 }
432 }
433
434 fn remaining(&self) -> usize {
435 self.buf.len().saturating_sub(self.pos)
436 }
437
438 fn read_u8(&mut self) -> Result<u8, ProtocolError> {
439 if self.remaining() < 1 {
440 return Err(ProtocolError::UnexpectedEof);
441 }
442 let b = self.buf[self.pos];
443 self.pos += 1;
444 Ok(b)
445 }
446
447 fn read_i8(&mut self) -> Result<i8, ProtocolError> {
448 let b = self.read_u8()?;
449 Ok(b as i8)
450 }
451
452 fn read_i16(&mut self) -> Result<i16, ProtocolError> {
453 let bytes = self.read_bytes(2)?;
454 Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
455 }
456
457 fn read_u32(&mut self) -> Result<u32, ProtocolError> {
458 let bytes = self.read_bytes(4)?;
459 Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
460 }
461
462 fn read_i32(&mut self) -> Result<i32, ProtocolError> {
463 let bytes = self.read_bytes(4)?;
464 Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
465 }
466
467 fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ProtocolError> {
468 if self.remaining() < n {
469 return Err(ProtocolError::UnexpectedEof);
470 }
471 let start = self.pos;
472 let end = self.pos + n;
473 self.pos = end;
474 Ok(&self.buf[start..end])
475 }
476
477 fn read_cstring(&mut self) -> Result<String, ProtocolError> {
478 let start = self.pos;
479 while self.pos < self.buf.len() && self.buf[self.pos] != 0 {
480 self.pos += 1;
481 }
482 if self.pos >= self.buf.len() {
483 return Err(ProtocolError::UnexpectedEof);
484 }
485 let bytes = self.buf[start..self.pos].to_vec();
486 self.pos += 1; Ok(String::from_utf8(bytes)?)
488 }
489
490 fn take_remaining(&mut self) -> Vec<u8> {
491 let remaining = self.buf[self.pos..].to_vec();
492 self.pos = self.buf.len();
493 remaining
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames `payload` as a message of type `ty`: the type byte, then the
    /// big-endian length (payload plus the four length bytes), then the payload.
    #[allow(clippy::cast_possible_truncation)]
    fn build_message(ty: u8, payload: &[u8]) -> Vec<u8> {
        let declared = (payload.len() + 4) as i32;
        let mut frame = vec![ty];
        frame.extend_from_slice(&declared.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    /// Frames `payload` and runs it through `MessageReader::parse_message`.
    fn decode(ty: u8, payload: &[u8]) -> Result<BackendMessage, ProtocolError> {
        MessageReader::parse_message(&build_message(ty, payload))
    }

    /// Asserts that a payload holding only an Int16 count of -1 is rejected
    /// with `InvalidField` for message type `ty`.
    fn assert_negative_count_rejected(ty: u8) {
        let payload = (-1_i16).to_be_bytes();
        let result = decode(ty, &payload);
        assert!(matches!(result, Err(ProtocolError::InvalidField(_))));
    }

    #[test]
    fn parse_auth_ok() {
        let payload = auth_type::OK.to_be_bytes();
        let decoded = decode(backend_type::AUTHENTICATION, &payload).unwrap();
        assert!(matches!(decoded, BackendMessage::AuthenticationOk));
    }

    #[test]
    fn parse_ready_for_query() {
        let payload = [TransactionStatus::Idle.as_byte()];
        let decoded = decode(backend_type::READY_FOR_QUERY, &payload).unwrap();
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(TransactionStatus::Idle)
        ));
    }

    #[test]
    fn parse_error_response() {
        // Three (code, C string) pairs followed by the terminating zero byte.
        let payload: &[u8] = b"SERROR\0C12345\0Mbad\0\0";
        let decoded = decode(backend_type::ERROR_RESPONSE, payload).unwrap();
        if let BackendMessage::ErrorResponse(fields) = decoded {
            assert_eq!(fields.severity, "ERROR");
            assert_eq!(fields.code, "12345");
            assert_eq!(fields.message, "bad");
        } else {
            panic!("unexpected message");
        }
    }

    #[test]
    fn parse_data_row() {
        // Two columns: the three bytes "foo", then a NULL (length -1).
        let mut payload = 2_i16.to_be_bytes().to_vec();
        payload.extend_from_slice(&3_i32.to_be_bytes());
        payload.extend_from_slice(b"foo");
        payload.extend_from_slice(&(-1_i32).to_be_bytes());

        let decoded = decode(backend_type::DATA_ROW, &payload).unwrap();
        if let BackendMessage::DataRow(values) = decoded {
            assert_eq!(values, vec![Some(b"foo".to_vec()), None]);
        } else {
            panic!("unexpected message");
        }
    }

    #[test]
    fn reader_buffers_partial_frames() {
        let frame = build_message(
            backend_type::READY_FOR_QUERY,
            &[TransactionStatus::Idle.as_byte()],
        );
        // Split inside the header so the first chunk cannot form a message.
        let (head, tail) = frame.split_at(3);

        let mut reader = MessageReader::new();
        assert!(reader.feed(head).unwrap().is_empty());
        assert_eq!(reader.feed(tail).unwrap().len(), 1);
    }

    #[test]
    fn parse_row_description_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::ROW_DESCRIPTION);
    }

    #[test]
    fn parse_data_row_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::DATA_ROW);
    }

    #[test]
    fn parse_parameter_description_negative_count_rejected() {
        assert_negative_count_rejected(backend_type::PARAMETER_DESCRIPTION);
    }
}