From bc07d0b067dd22734bf6b1c21cbf3be1a25145a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Sat, 24 Feb 2024 14:17:09 +0100 Subject: [PATCH] feat: postgres type conversions --- connector_arrow/src/params.rs | 42 ++-- connector_arrow/src/postgres/mod.rs | 21 +- .../src/postgres/protocol_extended.rs | 185 ++++++++++++------ .../src/postgres/protocol_simple.rs | 6 +- connector_arrow/src/postgres/types.rs | 9 +- .../tests/it/test_postgres_common.rs | 96 +++++++-- .../tests/it/test_postgres_extended.rs | 5 + .../tests/it/test_postgres_simple.rs | 5 + connector_arrow/tests/it/util.rs | 16 +- 9 files changed, 275 insertions(+), 110 deletions(-) diff --git a/connector_arrow/src/params.rs b/connector_arrow/src/params.rs index 75e421e..1d276fc 100644 --- a/connector_arrow/src/params.rs +++ b/connector_arrow/src/params.rs @@ -88,6 +88,16 @@ macro_rules! impl_arrow_value_tuple { }; } +impl_arrow_value_tuple!( + i32, + ( + Date32Type, + Time32SecondType, + Time32MillisecondType, + IntervalYearMonthType, + ) +); + impl_arrow_value_tuple!( i64, ( @@ -95,33 +105,23 @@ impl_arrow_value_tuple!( TimestampMillisecondType, TimestampMicrosecondType, TimestampNanosecondType, - ) -); - -impl_produce_unsupported!( - &'r dyn ArrowValue, - ( - NullType, - Float16Type, - // BinaryType, - // Utf8Type, - Date32Type, Date64Type, - Time32SecondType, - Time32MillisecondType, Time64MicrosecondType, Time64NanosecondType, - IntervalYearMonthType, - IntervalDayTimeType, - IntervalMonthDayNanoType, DurationSecondType, DurationMillisecondType, DurationMicrosecondType, DurationNanosecondType, - LargeBinaryType, - FixedSizeBinaryType, - LargeUtf8Type, - Decimal128Type, - Decimal256Type, + IntervalDayTimeType, ) ); + +impl_arrow_value_tuple!(i128, (IntervalMonthDayNanoType, Decimal128Type,)); + +impl_arrow_value_tuple!(i256, (Decimal256Type,)); + +impl_arrow_value_tuple!(String, (LargeUtf8Type,)); + +impl_arrow_value_tuple!(Vec, (LargeBinaryType, FixedSizeBinaryType,)); + +impl_produce_unsupported!(&'r dyn ArrowValue, (NullType, Float16Type,)); diff --git a/connector_arrow/src/postgres/mod.rs b/connector_arrow/src/postgres/mod.rs index ea420f2..5d2d26d 100644 --- a/connector_arrow/src/postgres/mod.rs +++ b/connector_arrow/src/postgres/mod.rs @@ -19,7 +19,7 @@ mod protocol_simple; mod schema; mod types; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; use postgres::Client; use std::marker::PhantomData; use thiserror::Error; @@ -130,9 +130,9 @@ where DataType::Time64(_) => Some(DataType::Int64), DataType::Duration(_) => Some(DataType::Int64), - DataType::Utf8 => Some(DataType::LargeUtf8), - DataType::Binary => Some(DataType::LargeBinary), - DataType::FixedSizeBinary(_) => Some(DataType::LargeBinary), + DataType::LargeUtf8 => Some(DataType::Utf8), + DataType::LargeBinary => Some(DataType::Binary), + DataType::FixedSizeBinary(_) => Some(DataType::Binary), DataType::Decimal128(_, _) => Some(DataType::Utf8), DataType::Decimal256(_, _) => Some(DataType::Utf8), @@ -156,6 +156,19 @@ where "timestamptz" | "timestamp with time zone" => { DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())) } + "date" => DataType::Date32, + "time" | "time without time zone" => DataType::Time64(TimeUnit::Microsecond), + "interval" => DataType::Interval(IntervalUnit::MonthDayNano), + + "bytea" => DataType::Binary, + "bit" | "bit varying" | "varbit" => DataType::Binary, + _ if ty.starts_with("bit") => DataType::Binary, + + "text" | "varchar" | "char" | "bpchar" => DataType::Utf8, + _ if ty.starts_with("varchar") | ty.starts_with("char") | ty.starts_with("bpchar") => { + DataType::Utf8 + } + _ => return None, }) } diff --git a/connector_arrow/src/postgres/protocol_extended.rs b/connector_arrow/src/postgres/protocol_extended.rs index bb2cafd..1000d43 100644 --- a/connector_arrow/src/postgres/protocol_extended.rs +++ b/connector_arrow/src/postgres/protocol_extended.rs @@ -2,7 +2,7 @@ use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use itertools::Itertools; use postgres::fallible_iterator::FallibleIterator; -use postgres::types::FromSql; +use postgres::types::{FromSql, Type}; use postgres::{Row, RowIter}; use crate::api::{ArrowValue, ResultReader, Statement}; @@ -153,27 +153,21 @@ impl_produce!(Int32Type, i32, Result::Ok); impl_produce!(Int64Type, i64, Result::Ok); impl_produce!(Float32Type, f32, Result::Ok); impl_produce!(Float64Type, f64, Result::Ok); -impl_produce!(LargeBinaryType, Vec, Result::Ok); +impl_produce!(BinaryType, Binary, Binary::into_arrow); +impl_produce!(LargeBinaryType, Binary, Binary::into_arrow); +impl_produce!(Utf8Type, StrOrNum, StrOrNum::into_arrow); impl_produce!(LargeUtf8Type, String, Result::Ok); -impl_produce!( - TimestampSecondType, - TimestampY2000, - TimestampY2000::into_second -); -impl_produce!( - TimestampMillisecondType, - TimestampY2000, - TimestampY2000::into_millisecond -); impl_produce!( TimestampMicrosecondType, TimestampY2000, TimestampY2000::into_microsecond ); +impl_produce!(Time64MicrosecondType, Time64, Time64::into_microsecond); +impl_produce!(Date32Type, DaysSinceY2000, DaysSinceY2000::into_date32); impl_produce!( - TimestampNanosecondType, - TimestampY2000, - TimestampY2000::into_nanosecond + IntervalMonthDayNanoType, + IntervalMonthDayMicros, + IntervalMonthDayMicros::into_arrow ); crate::impl_produce_unsupported!( @@ -184,96 +178,173 @@ crate::impl_produce_unsupported!( UInt32Type, UInt64Type, Float16Type, - // TimestampSecondType, - // TimestampMillisecondType, - // TimestampMicrosecondType, - // TimestampNanosecondType, - Date32Type, + TimestampSecondType, + TimestampMillisecondType, + TimestampNanosecondType, Date64Type, Time32SecondType, Time32MillisecondType, - Time64MicrosecondType, Time64NanosecondType, IntervalYearMonthType, IntervalDayTimeType, - IntervalMonthDayNanoType, DurationSecondType, DurationMillisecondType, DurationMicrosecondType, DurationNanosecondType, - BinaryType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, ) ); -struct Numeric(String); +struct StrOrNum(String); + +impl StrOrNum { + fn into_arrow(self) -> Result { + Ok(self.0) + } +} -impl<'a> FromSql<'a> for Numeric { +impl<'a> FromSql<'a> for StrOrNum { fn from_sql( - _ty: &postgres::types::Type, + ty: &Type, raw: &'a [u8], ) -> Result> { - Ok(super::decimal::from_sql(raw).map(Numeric)?) + if matches!(ty, &Type::NUMERIC) { + Ok(super::decimal::from_sql(raw).map(StrOrNum)?) + } else { + let slice = postgres_protocol::types::text_from_sql(raw)?; + Ok(StrOrNum(slice.to_string())) + } } - fn accepts(_ty: &postgres::types::Type) -> bool { + fn accepts(_ty: &Type) -> bool { true } } -impl<'c> transport::ProduceTy<'c, Utf8Type> for CellRef<'c> { - fn produce(self) -> Result { - Ok(self.0.get::<_, Numeric>(self.1).0) +const DUR_1970_TO_2000_DAYS: i32 = 10957; +const DUR_1970_TO_2000_SEC: i64 = DUR_1970_TO_2000_DAYS as i64 * 24 * 60 * 60; + +struct TimestampY2000(i64); + +impl<'a> FromSql<'a> for TimestampY2000 { + fn from_sql( + _ty: &Type, + raw: &'a [u8], + ) -> Result> { + postgres_protocol::types::timestamp_from_sql(raw).map(TimestampY2000) } - fn produce_opt(self) -> Result, ConnectorError> { - Ok(self.0.get::<_, Option>(self.1).map(|n| n.0)) + fn accepts(_ty: &Type) -> bool { + true } } -struct TimestampY2000(i64); - -const DUR_1970_TO_2000_SEC: i64 = 10957 * 24 * 60 * 60; - impl TimestampY2000 { - fn into_nanosecond(self) -> Result { - self.0 - .checked_add(DUR_1970_TO_2000_SEC * 1000 * 1000) - .and_then(|micros_y1970| micros_y1970.mul_checked(1000).ok()) - .ok_or(ConnectorError::DataOutOfRange) - } fn into_microsecond(self) -> Result { self.0 .checked_add(DUR_1970_TO_2000_SEC * 1000 * 1000) .ok_or(ConnectorError::DataOutOfRange) } - fn into_millisecond(self) -> Result { - self.0 - .div_checked(1000) - .ok() - .and_then(|millis_y2000| millis_y2000.checked_add(DUR_1970_TO_2000_SEC * 1000)) - .ok_or(ConnectorError::DataOutOfRange) +} + +struct DaysSinceY2000(i32); + +impl<'a> FromSql<'a> for DaysSinceY2000 { + fn from_sql( + _ty: &Type, + raw: &'a [u8], + ) -> Result> { + postgres_protocol::types::date_from_sql(raw).map(DaysSinceY2000) + } + + fn accepts(_ty: &Type) -> bool { + true } - fn into_second(self) -> Result { +} + +impl DaysSinceY2000 { + fn into_date32(self) -> Result { self.0 - .div_checked(1_000_000) - .ok() - .and_then(|sec_y2000| sec_y2000.checked_add(DUR_1970_TO_2000_SEC)) + .checked_add(DUR_1970_TO_2000_DAYS) .ok_or(ConnectorError::DataOutOfRange) } } -impl<'a> FromSql<'a> for TimestampY2000 { +struct Time64(i64); + +impl<'a> FromSql<'a> for Time64 { fn from_sql( - _ty: &postgres::types::Type, + _ty: &Type, raw: &'a [u8], ) -> Result> { - postgres_protocol::types::timestamp_from_sql(raw).map(TimestampY2000) + postgres_protocol::types::time_from_sql(raw).map(Time64) } + fn accepts(_ty: &Type) -> bool { + true + } +} - fn accepts(_ty: &postgres::types::Type) -> bool { +impl Time64 { + fn into_microsecond(self) -> Result { + Ok(self.0) + } +} + +struct IntervalMonthDayMicros(i32, i32, i64); + +impl<'a> FromSql<'a> for IntervalMonthDayMicros { + fn from_sql( + _ty: &Type, + raw: &'a [u8], + ) -> Result> { + let micros = postgres_protocol::types::time_from_sql(&raw[0..8])?; + let days = postgres_protocol::types::int4_from_sql(&raw[8..12])?; + let months = postgres_protocol::types::int4_from_sql(&raw[12..16])?; + Ok(IntervalMonthDayMicros(months, days, micros)) + } + fn accepts(_ty: &Type) -> bool { true } } + +impl IntervalMonthDayMicros { + fn into_arrow(self) -> Result { + let nanos = (self.2.checked_mul(1000)).ok_or(ConnectorError::DataOutOfRange)?; + + let mut bytes = [0; 16]; + bytes[0..4].copy_from_slice(&self.0.to_be_bytes()); + bytes[4..8].copy_from_slice(&self.1.to_be_bytes()); + bytes[8..16].copy_from_slice(&nanos.to_be_bytes()); + Ok(i128::from_be_bytes(bytes)) + } +} + +struct Binary<'a>(&'a [u8]); + +impl<'a> FromSql<'a> for Binary<'a> { + fn from_sql( + ty: &Type, + raw: &'a [u8], + ) -> Result> { + Ok(if matches!(ty, &Type::VARBIT | &Type::BIT) { + let varbit = postgres_protocol::types::varbit_from_sql(raw)?; + dbg!(varbit.len()); + dbg!(varbit.bytes()); + Binary(varbit.bytes()) + } else { + Binary(postgres_protocol::types::bytea_from_sql(raw)) + }) + } + fn accepts(_ty: &Type) -> bool { + true + } +} + +impl Binary<'_> { + fn into_arrow(self) -> Result, ConnectorError> { + // this is a clone, that is needed because Produce requires Vec + Ok(self.0.to_vec()) + } +} diff --git a/connector_arrow/src/postgres/protocol_simple.rs b/connector_arrow/src/postgres/protocol_simple.rs index a0d913a..cb194bc 100644 --- a/connector_arrow/src/postgres/protocol_simple.rs +++ b/connector_arrow/src/postgres/protocol_simple.rs @@ -152,7 +152,7 @@ crate::impl_produce_unsupported!( DurationMillisecondType, DurationMicrosecondType, DurationNanosecondType, - BinaryType, + LargeBinaryType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, @@ -181,9 +181,9 @@ impl<'r> transport::ProduceTy<'r, LargeUtf8Type> for CellRef<'r> { } } -impl<'r> transport::ProduceTy<'r, LargeBinaryType> for CellRef<'r> { +impl<'r> transport::ProduceTy<'r, BinaryType> for CellRef<'r> { fn produce(self) -> Result, ConnectorError> { - transport::ProduceTy::::produce_opt(self)?.ok_or_else(err_null) + transport::ProduceTy::::produce_opt(self)?.ok_or_else(err_null) } fn produce_opt(self) -> Result>, ConnectorError> { diff --git a/connector_arrow/src/postgres/types.rs b/connector_arrow/src/postgres/types.rs index fc50ccd..cbd2cbf 100644 --- a/connector_arrow/src/postgres/types.rs +++ b/connector_arrow/src/postgres/types.rs @@ -19,13 +19,8 @@ pub fn pg_stmt_to_arrow( } pub fn pg_ty_to_arrow(ty: &PgType) -> ArrowType { - match ty.name() { - "text" => ArrowType::LargeUtf8, - "varchar" => ArrowType::LargeUtf8, - "bytea" => ArrowType::LargeBinary, - name => PostgresConnection::::type_db_into_arrow(name) - .unwrap_or_else(|| unimplemented!("{name}")), - } + PostgresConnection::::type_db_into_arrow(ty.name()) + .unwrap_or_else(|| unimplemented!("{}", ty.name())) } pub(crate) fn arrow_ty_to_pg(data_type: &ArrowType) -> String { diff --git a/connector_arrow/tests/it/test_postgres_common.rs b/connector_arrow/tests/it/test_postgres_common.rs index 57e9a04..699b671 100644 --- a/connector_arrow/tests/it/test_postgres_common.rs +++ b/connector_arrow/tests/it/test_postgres_common.rs @@ -35,6 +35,8 @@ fn ident_escaping() { } pub mod literals_cases { + use arrow::datatypes::{DataType, TimeUnit}; + use crate::util::QueryOfSingleLiteral; pub fn bool() -> Vec { @@ -126,22 +128,85 @@ pub mod literals_cases { ] } + pub fn date() -> Vec { + vec![ + ("date", "DATE '2024-02-23'", (DataType::Date32, 19776_i32)).into(), + ("date", "'4713-01-01 BC'", (DataType::Date32, -2440550_i32)).into(), + ("date", "'5874897-1-1'", (DataType::Date32, 2145042541_i32)).into(), + ] + } + + pub fn time() -> Vec { + vec![ + ( + "time", + "'17:18:36'", + (DataType::Time64(TimeUnit::Microsecond), 62316000000_i64), + ) + .into(), + ( + "time without time zone", + "'17:18:36.789'", + (DataType::Time64(TimeUnit::Microsecond), 62316789000_i64), + ) + .into(), + // timetz not supported by postgres crate + // ( + // "time with time zone", + // "'17:18:36.789+01:00'", + // (DataType::Time64(TimeUnit::Microsecond), 65916789000_i64), + // ) + // .into(), + // ( + // "time with time zone", + // "'17:18:36.789 CEST'", + // (DataType::Time64(TimeUnit::Microsecond), 65916789000_i64), + // ) + // .into(), + ] + } + + pub fn interval() -> Vec { + vec![ + ( + "interval", + "'P12M3DT4H5M6S'", + ( + DataType::Duration(TimeUnit::Microsecond), + 0x0000000C_00000003_00000d6001e7f400_i128, + ), + ) + .into(), + ( + "interval", + "'P-1Y-2M3DT-4H-5M-6S'", + ( + DataType::Duration(TimeUnit::Microsecond), + 0xfffffff2_00000003_fffff29ffe180c00_u128 as i128, + ), + ) + .into(), + ] + } + + pub fn binary() -> Vec { + vec![ + ("bytea", "'\\xDEADBEEF'", vec![0xDE, 0xAD, 0xBE, 0xEF]).into(), + ("bit(4)", "B'1011'", vec![0b10110000]).into(), + ("bit varying(6)", "B'1011'", vec![0b10110000]).into(), + ] + } + + pub fn text() -> Vec { + vec![ + ("text", "'ok'", "ok".to_string()).into(), + ("char(6)", "'hello'", "hello ".to_string()).into(), + ("varchar(6)", "'world'", "world".to_string()).into(), + ("bpchar", "' nope '", " nope ".to_string()).into(), + ] + } + // TODO: - // bit [ (n) ] - // bit varying [ (n) ] - // bytea - // - // character [ (n) ] - // character varying [ (n) ] - // text - // - // timestamp [ (p) ] [ without time zone ] - // timestamp [ (p) ] with time zone - // date - // time [ (p) ] [ without time zone ] - // time [ (p) ] with time zone - // interval [ fields ] [ (p) ] - // // point // box // circle @@ -163,7 +228,6 @@ pub mod literals_cases { // // tsquery // tsvector - // // pg_lsn // pg_snapshot // txid_snapshot diff --git a/connector_arrow/tests/it/test_postgres_extended.rs b/connector_arrow/tests/it/test_postgres_extended.rs index a0fb208..b19b2cc 100644 --- a/connector_arrow/tests/it/test_postgres_extended.rs +++ b/connector_arrow/tests/it/test_postgres_extended.rs @@ -50,6 +50,11 @@ fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { #[case::float(literals_cases::float())] #[case::decimal(literals_cases::decimal())] #[case::timestamp(literals_cases::timestamp())] +#[case::date(literals_cases::date())] +#[case::time(literals_cases::time())] +#[case::interval(literals_cases::interval())] +#[case::binary(literals_cases::binary())] +#[case::text(literals_cases::text())] fn query_literals(#[case] queries: Vec) { let mut conn = init(); crate::util::query_literals(&mut conn, queries) diff --git a/connector_arrow/tests/it/test_postgres_simple.rs b/connector_arrow/tests/it/test_postgres_simple.rs index b210e32..9467c5b 100644 --- a/connector_arrow/tests/it/test_postgres_simple.rs +++ b/connector_arrow/tests/it/test_postgres_simple.rs @@ -50,6 +50,11 @@ fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { #[case::float(literals_cases::float())] #[case::decimal(literals_cases::decimal())] // #[case::timestamp(literals_cases::timestamp())] +// #[case::date(literals_cases::date())] +// #[case::time(literals_cases::time())] +// #[case::interval(literals_cases::interval())] +#[case::text(literals_cases::text())] +// #[case::binary(literals_cases::binary())] fn query_literals(#[case] queries: Vec) { let mut conn = init(); crate::util::query_literals(&mut conn, queries) diff --git a/connector_arrow/tests/it/util.rs b/connector_arrow/tests/it/util.rs index 2d1c9a7..e732690 100644 --- a/connector_arrow/tests/it/util.rs +++ b/connector_arrow/tests/it/util.rs @@ -68,7 +68,12 @@ pub fn query_literals(conn: &mut C, queries: Vec(conn: &mut C, queries: Vec) -> ArrayRef {