diff --git a/src/postgres/common.rs b/src/postgres/common.rs index 485a259..e0fd9d9 100644 --- a/src/postgres/common.rs +++ b/src/postgres/common.rs @@ -31,14 +31,14 @@ pub(in crate::postgres) struct PostgresDecimal { } impl Decimal { - pub(in crate::postgres) fn from_postgres>( + pub(in crate::postgres) fn checked_from_postgres>( PostgresDecimal { neg, scale, digits, weight, }: PostgresDecimal, - ) -> Self { + ) -> Option { let mut digits = digits.into_iter().collect::>(); let fractionals_part_count = digits.len() as i32 + (-weight as i32) - 1; @@ -54,23 +54,26 @@ impl Decimal { }; let integers: Vec<_> = digits.drain(..last as usize).collect(); for digit in integers { - result *= Decimal::from_i128_with_scale(10i128.pow(4), 0); - result += Decimal::new(digit as i64, 0); + result = result.checked_mul(Decimal::from_i128_with_scale(10i128.pow(4), 0))?; + result = result.checked_add(Decimal::new(digit as i64, 0))?; } - result *= Decimal::from_i128_with_scale(10i128.pow(4 * start_integers as u32), 0); + result = result.checked_mul(Decimal::from_i128_with_scale(10i128.pow(4 * start_integers as u32), 0))?; } // adding fractional part if fractionals_part_count > 0 { let start_fractionals = if weight < 0 { (-weight as u32) - 1 } else { 0 }; for (i, digit) in digits.into_iter().enumerate() { - let fract_pow = 4 * (i as u32 + 1 + start_fractionals); + let fract_pow = 4_u32.checked_mul(i as u32 + 1 + start_fractionals)?; if fract_pow <= MAX_PRECISION_U32 { - result += Decimal::new(digit as i64, 0) / Decimal::from_i128_with_scale(10i128.pow(fract_pow), 0); + result = result.checked_add( + Decimal::new(digit as i64, 0) / Decimal::from_i128_with_scale(10i128.pow(fract_pow), 0), + )?; } else if fract_pow == MAX_PRECISION_U32 + 4 { // rounding last digit if digit >= 5000 { - result += - Decimal::new(1_i64, 0) / Decimal::from_i128_with_scale(10i128.pow(MAX_PRECISION_U32), 0); + result = result.checked_add( + Decimal::new(1_i64, 0) / Decimal::from_i128_with_scale(10i128.pow(MAX_PRECISION_U32), 0), + )?; } } } @@ -79,7 +82,7 @@ impl Decimal { result.set_sign_negative(neg); // Rescale to the postgres value, automatically rounding as needed. result.rescale((scale as u32).min(MAX_PRECISION_U32)); - result + Some(result) } pub(in crate::postgres) fn to_postgres(self) -> PostgresDecimal> { diff --git a/src/postgres/diesel.rs b/src/postgres/diesel.rs index 26cd3b3..d523ad1 100644 --- a/src/postgres/diesel.rs +++ b/src/postgres/diesel.rs @@ -27,12 +27,15 @@ impl<'a> TryFrom<&'a PgNumeric> for Decimal { PgNumeric::NaN => return Err(Box::from("NaN is not supported in Decimal")), }; - Ok(Self::from_postgres(PostgresDecimal { + let Some(result) = Self::checked_from_postgres(PostgresDecimal { neg, weight, scale, digits: digits.iter().copied().map(|v| v.try_into().unwrap()), - })) + }) else { + return Err(Box::new(crate::error::Error::ExceedsMaximumPossibleValue)); + }; + Ok(result) } } diff --git a/src/postgres/driver.rs b/src/postgres/driver.rs index d8e789f..9e761bf 100644 --- a/src/postgres/driver.rs +++ b/src/postgres/driver.rs @@ -79,12 +79,15 @@ impl<'a> FromSql<'a> for Decimal { groups.push(u16::from_be_bytes(read_two_bytes(&mut raw)?)); } - Ok(Self::from_postgres(PostgresDecimal { + let Some(result) = Self::checked_from_postgres(PostgresDecimal { neg: sign == 0x4000, weight, scale, digits: groups.into_iter(), - })) + }) else { + return Err(Box::new(crate::error::Error::ExceedsMaximumPossibleValue)); + }; + Ok(result) } fn accepts(ty: &Type) -> bool { @@ -422,4 +425,23 @@ mod test { } } } + + #[test] + fn numeric_overflow_from_sql() { + let close_to_overflow = Decimal::from_sql( + &Type::NUMERIC, + &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], + ); + assert!(close_to_overflow.is_ok()); + assert_eq!(close_to_overflow.unwrap().to_string(), "10000000000000000000000000000"); + let overflow = Decimal::from_sql( + &Type::NUMERIC, + &[0x00, 0x01, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a], + ); + assert!(overflow.is_err()); + assert_eq!( + overflow.unwrap_err().to_string(), + crate::error::Error::ExceedsMaximumPossibleValue.to_string() + ); + } }