From 913dc5bc13dd20f9977f5a5591da0465d4db641a Mon Sep 17 00:00:00 2001 From: Luke Osborne Date: Fri, 29 Mar 2024 18:55:01 -0400 Subject: [PATCH] Fix FromSql for Postgres Numeric NaNs (#656) * Fix FromSql for Postgres Numeric NaNs This fixes a bug where from_sql was converting Numeric::NaN to 0 rather than returning an error. It also returns a more descriptive error message when from_sql is called with Numeric Infinity and -Infinity, which are also not representable in Decimal. Closes #655 * Rename to_from_sql to postgres_to_from_sql * Rename from_sql_special_numeric to postgres_from_sql_special_numeric --------- Co-authored-by: Paul Mason --- src/postgres/driver.rs | 21 +++++++++++++++++++++ tests/decimal_tests.rs | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/postgres/driver.rs b/src/postgres/driver.rs index 9e761bf..6bbab11 100644 --- a/src/postgres/driver.rs +++ b/src/postgres/driver.rs @@ -1,9 +1,16 @@ +use crate::error::Error; use crate::postgres::common::*; use crate::Decimal; use bytes::{BufMut, BytesMut}; use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; use std::io::{Cursor, Read}; +// These are from numeric.c in the PostgreSQL source code +const NUMERIC_NAN: u16 = 0xC000; +const NUMERIC_PINF: u16 = 0xD000; +const NUMERIC_NINF: u16 = 0xF000; +const NUMERIC_SPECIAL: u16 = 0xC000; + fn read_two_bytes(cursor: &mut Cursor<&[u8]>) -> std::io::Result<[u8; 2]> { let mut result = [0; 2]; cursor.read_exact(&mut result)?; @@ -70,6 +77,20 @@ impl<'a> FromSql<'a> for Decimal { let weight = i16::from_be_bytes(read_two_bytes(&mut raw)?); // 10000^weight // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN let sign = u16::from_be_bytes(read_two_bytes(&mut raw)?); + + if (sign & NUMERIC_SPECIAL) == NUMERIC_SPECIAL { + let special = match sign { + NUMERIC_NAN => "NaN", + NUMERIC_PINF => "Infinity", + NUMERIC_NINF => "-Infinity", + // This shouldn't be hit unless postgres adds a new special numeric type in the + // future + _ => "unknown special numeric", + }; + + return Err(Box::new(Error::ConversionTo(special.to_string()))); + } + // Number of digits (in base 10) to print after decimal separator let scale = u16::from_be_bytes(read_two_bytes(&mut raw)?); diff --git a/tests/decimal_tests.rs b/tests/decimal_tests.rs index 6bb7df1..202ef99 100644 --- a/tests/decimal_tests.rs +++ b/tests/decimal_tests.rs @@ -3473,9 +3473,9 @@ fn declarative_ref_dec_sum() { assert_eq!(sum, Decimal::from(45)) } -#[cfg(feature = "postgres")] +#[cfg(feature = "db-postgres")] #[test] -fn to_from_sql() { +fn postgres_to_from_sql() { use bytes::BytesMut; use postgres::types::{FromSql, Kind, ToSql, Type}; @@ -3514,6 +3514,36 @@ fn to_from_sql() { } } +#[cfg(feature = "db-postgres")] +#[test] +fn postgres_from_sql_special_numeric() { + use postgres::types::{FromSql, Kind, Type}; + + // The numbers below are the big-endian equivalent of the NUMERIC_* masks for NAN, PINF, NINF + let tests = &[ + ("NaN", &[0, 0, 0, 0, 192, 0, 0, 0]), + ("Infinity", &[0, 0, 0, 0, 208, 0, 0, 0]), + ("-Infinity", &[0, 0, 0, 0, 240, 0, 0, 0]), + ]; + + let t = Type::new("".into(), 0, Kind::Simple, "".into()); + + for (name, bytes) in tests { + let res = Decimal::from_sql(&t, *bytes); + match &res { + Ok(_) => panic!("Expected error, got Ok"), + Err(e) => { + let error_message = e.to_string(); + assert!( + error_message.contains(name), + "Error message does not contain the expected value: {}", + name + ); + } + } + } +} + fn hash_it(d: Decimal) -> u64 { use core::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher;