diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs index 314fda0b17..68a4fcfeff 100644 --- a/sqlx-postgres/src/type_checking.rs +++ b/sqlx-postgres/src/type_checking.rs @@ -38,6 +38,8 @@ impl_type_checking!( sqlx::postgres::types::PgLSeg, + sqlx::postgres::types::PgBox, + #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-postgres/src/types/geometry/box.rs b/sqlx-postgres/src/types/geometry/box.rs new file mode 100644 index 0000000000..988c028ed4 --- /dev/null +++ b/sqlx-postgres/src/types/geometry/box.rs @@ -0,0 +1,321 @@ +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use sqlx_core::bytes::Buf; +use std::str::FromStr; + +const ERROR: &str = "error decoding BOX"; + +/// ## Postgres Geometric Box type +/// +/// Description: Rectangular box +/// Representation: `((upper_right_x,upper_right_y),(lower_left_x,lower_left_y))` +/// +/// Boxes are represented by pairs of points that are opposite corners of the box. Values of type box are specified using any of the following syntaxes: +/// +/// ```text +/// ( ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) ) +/// ( upper_right_x , upper_right_y ) , ( lower_left_x , lower_left_y ) +/// upper_right_x , upper_right_y , lower_left_x , lower_left_y +/// ``` +/// where `(upper_right_x,upper_right_y) and (lower_left_x,lower_left_y)` are any two opposite corners of the box. +/// Any two opposite corners can be supplied on input, but the values will be reordered as needed to store the upper right and lower left corners, in that order. +/// +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-BOXES +#[derive(Debug, Clone, PartialEq)] +pub struct PgBox { + pub upper_right_x: f64, + pub upper_right_y: f64, + pub lower_left_x: f64, + pub lower_left_y: f64, +} + +impl Type for PgBox { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("box") + } +} + +impl PgHasArrayType for PgBox { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_box") + } +} + +impl<'r> Decode<'r, Postgres> for PgBox { + fn decode(value: PgValueRef<'r>) -> Result> { + match value.format() { + PgValueFormat::Text => Ok(PgBox::from_str(value.as_str()?)?), + PgValueFormat::Binary => Ok(PgBox::from_bytes(value.as_bytes()?)?), + } + } +} + +impl<'q> Encode<'q, Postgres> for PgBox { + fn produces(&self) -> Option { + Some(PgTypeInfo::with_name("box")) + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + self.serialize(buf)?; + Ok(IsNull::No) + } +} + +impl FromStr for PgBox { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); + let mut parts = sanitised.split(','); + + let upper_right_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_x from {}", ERROR, s))?; + + let upper_right_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get upper_right_y from {}", ERROR, s))?; + + let lower_left_x = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_x from {}", ERROR, s))?; + + let lower_left_y = parts + .next() + .and_then(|s| s.parse::().ok()) + .ok_or_else(|| format!("{}: could not get lower_left_y from {}", ERROR, s))?; + + if parts.next().is_some() { + return Err(format!("{}: too many numbers inputted in {}", ERROR, s).into()); + } + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } +} + +impl PgBox { + fn from_bytes(mut bytes: &[u8]) -> Result { + let upper_right_x = bytes.get_f64(); + let upper_right_y = bytes.get_f64(); + let lower_left_x = bytes.get_f64(); + let lower_left_y = bytes.get_f64(); + + Ok(PgBox { + upper_right_x, + upper_right_y, + lower_left_x, + lower_left_y, + }) + } + + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { + let min_x = &self.upper_right_x.min(self.lower_left_x); + let min_y = &self.upper_right_y.min(self.lower_left_y); + let max_x = &self.upper_right_x.max(self.lower_left_x); + let max_y = &self.upper_right_y.max(self.lower_left_y); + + buff.extend_from_slice(&max_x.to_be_bytes()); + buff.extend_from_slice(&max_y.to_be_bytes()); + buff.extend_from_slice(&min_x.to_be_bytes()); + buff.extend_from_slice(&min_y.to_be_bytes()); + + Ok(()) + } + + #[cfg(test)] + fn serialize_to_vec(&self) -> Vec { + let mut buff = PgArgumentBuffer::default(); + self.serialize(&mut buff).unwrap(); + buff.to_vec() + } +} + +#[cfg(test)] +mod box_tests { + + use std::str::FromStr; + + use super::PgBox; + + const BOX_BYTES: &[u8] = &[ + 64, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, 0, 0, 0, 0, 192, 0, 0, 0, + 0, 0, 0, 0, + ]; + + #[test] + fn can_deserialise_box_type_bytes_in_order() { + let pg_box = PgBox::from_bytes(BOX_BYTES).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } + + #[test] + fn can_deserialise_box_type_str_first_syntax() { + let pg_box = PgBox::from_str("[( 1, 2), (3, 4 )]").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + #[test] + fn can_deserialise_box_type_str_second_syntax() { + let pg_box = PgBox::from_str("(( 1, 2), (3, 4 ))").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_third_syntax() { + let pg_box = PgBox::from_str("(1, 2), (3, 4 )").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn can_deserialise_box_type_str_fourth_syntax() { + let pg_box = PgBox::from_str("1, 2, 3, 4").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1., + upper_right_y: 2., + lower_left_x: 3., + lower_left_y: 4. + } + ); + } + + #[test] + fn cannot_deserialise_too_many_numbers() { + let input_str = "1, 2, 3, 4, 5"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: too many numbers inputted in {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_too_few_numbers() { + let input_str = "1, 2, 3 "; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn cannot_deserialise_invalid_numbers() { + let input_str = "1, 2, 3, FOUR"; + let pg_box = PgBox::from_str(input_str); + assert!(pg_box.is_err()); + if let Err(err) = pg_box { + assert_eq!( + err.to_string(), + format!("error decoding BOX: could not get lower_left_y from {input_str}") + ) + } + } + + #[test] + fn can_deserialise_box_type_str_float() { + let pg_box = PgBox::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 1.1, + upper_right_y: 2.2, + lower_left_x: 3.3, + lower_left_y: 4.4 + } + ); + } + + #[test] + fn can_serialise_box_type_in_order() { + let pg_box = PgBox { + upper_right_x: 2., + lower_left_x: -2., + upper_right_y: -2., + lower_left_y: 2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_serialise_box_type_out_of_order() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + assert_eq!(pg_box.serialize_to_vec(), BOX_BYTES,) + } + + #[test] + fn can_order_box() { + let pg_box = PgBox { + upper_right_x: -2., + lower_left_x: 2., + upper_right_y: 2., + lower_left_y: -2., + }; + let bytes = pg_box.serialize_to_vec(); + + let pg_box = PgBox::from_bytes(&bytes).unwrap(); + assert_eq!( + pg_box, + PgBox { + upper_right_x: 2., + upper_right_y: 2., + lower_left_x: -2., + lower_left_y: -2. + } + ) + } +} diff --git a/sqlx-postgres/src/types/geometry/mod.rs b/sqlx-postgres/src/types/geometry/mod.rs index 0da73fef08..7fe2898fcd 100644 --- a/sqlx-postgres/src/types/geometry/mod.rs +++ b/sqlx-postgres/src/types/geometry/mod.rs @@ -1,3 +1,4 @@ +pub mod r#box; pub mod line; pub mod line_segment; pub mod point; diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index b5b3266cbc..e53f88ffec 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -24,6 +24,7 @@ //! | [`PgPoint] | POINT | //! | [`PgLine] | LINE | //! | [`PgLSeg] | LSEG | +//! | [`PgBox] | BOX | //! | [`PgHstore`] | HSTORE | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., @@ -262,6 +263,7 @@ pub use cube::PgCube; pub use geometry::line::PgLine; pub use geometry::line_segment::PgLSeg; pub use geometry::point::PgPoint; +pub use geometry::r#box::PgBox; pub use hstore::PgHstore; pub use interval::PgInterval; pub use lquery::PgLQuery; diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c1cf87983c..ccf88b1099 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -514,6 +514,16 @@ test_type!(lseg(Postgres, "lseg('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgLSeg { start_x: 1., start_y: 2., end_x: 3. , end_y: 4.}, )); +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(box(Postgres, + "box('((1.0, 2.0), (3.0,4.0))')" == sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1. , lower_left_y: 2.}, +)); + +#[cfg(any(postgres_12, postgres_13, postgres_14, postgres_15))] +test_type!(_box>(Postgres, + "array[box('1,2,3,4'),box('((1.1, 2.2), (3.3, 4.4))')]" @= vec![sqlx::postgres::types::PgBox { upper_right_x: 3., upper_right_y: 4., lower_left_x: 1., lower_left_y: 2. }, sqlx::postgres::types::PgBox { upper_right_x: 3.3, upper_right_y: 4.4, lower_left_x: 1.1, lower_left_y: 2.2 }], +)); + #[cfg(feature = "rust_decimal")] test_type!(decimal(Postgres, "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),