|
| 1 | +use crate::decode::Decode; |
| 2 | +use crate::encode::{Encode, IsNull}; |
| 3 | +use crate::error::BoxDynError; |
| 4 | +use crate::types::{PgPoint, Type}; |
| 5 | +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; |
| 6 | +use sqlx_core::bytes::Buf; |
| 7 | +use sqlx_core::Error; |
| 8 | +use std::mem; |
| 9 | +use std::str::FromStr; |
| 10 | + |
| 11 | +const BYTE_WIDTH: usize = mem::size_of::<f64>(); |
| 12 | + |
| 13 | +/// ## Postgres Geometric Path type |
| 14 | +/// |
| 15 | +/// Description: Open path or Closed path (similar to polygon) |
| 16 | +/// Representation: Open `[(x1,y1),...]`, Closed `((x1,y1),...)` |
| 17 | +/// |
| 18 | +/// Paths are represented by lists of connected points. Paths can be open, where the first and last points in the list are considered not connected, or closed, where the first and last points are considered connected. |
| 19 | +/// Values of type path are specified using any of the following syntaxes: |
| 20 | +/// ```text |
| 21 | +/// [ ( x1 , y1 ) , ... , ( xn , yn ) ] |
| 22 | +/// ( ( x1 , y1 ) , ... , ( xn , yn ) ) |
| 23 | +/// ( x1 , y1 ) , ... , ( xn , yn ) |
| 24 | +/// ( x1 , y1 , ... , xn , yn ) |
| 25 | +/// x1 , y1 , ... , xn , yn |
| 26 | +/// ``` |
| 27 | +/// where the points are the end points of the line segments comprising the path. Square brackets `([])` indicate an open path, while parentheses `(())` indicate a closed path. |
| 28 | +/// When the outermost parentheses are omitted, as in the third through fifth syntaxes, a closed path is assumed. |
| 29 | +/// |
| 30 | +/// See https://www.postgresql.org/docs/16/datatype-geometric.html#DATATYPE-GEOMETRIC-PATHS |
| 31 | +#[derive(Debug, Clone, PartialEq)] |
| 32 | +pub struct PgPath { |
| 33 | + pub closed: bool, |
| 34 | + pub points: Vec<PgPoint>, |
| 35 | +} |
| 36 | + |
| 37 | +#[derive(Copy, Clone, Debug, PartialEq, Eq)] |
| 38 | +struct Header { |
| 39 | + is_closed: bool, |
| 40 | + length: usize, |
| 41 | +} |
| 42 | + |
| 43 | +impl Type<Postgres> for PgPath { |
| 44 | + fn type_info() -> PgTypeInfo { |
| 45 | + PgTypeInfo::with_name("path") |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +impl PgHasArrayType for PgPath { |
| 50 | + fn array_type_info() -> PgTypeInfo { |
| 51 | + PgTypeInfo::with_name("_path") |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +impl<'r> Decode<'r, Postgres> for PgPath { |
| 56 | + fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { |
| 57 | + match value.format() { |
| 58 | + PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?), |
| 59 | + PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?), |
| 60 | + } |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +impl<'q> Encode<'q, Postgres> for PgPath { |
| 65 | + fn produces(&self) -> Option<PgTypeInfo> { |
| 66 | + Some(PgTypeInfo::with_name("path")) |
| 67 | + } |
| 68 | + |
| 69 | + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { |
| 70 | + self.serialize(buf)?; |
| 71 | + Ok(IsNull::No) |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +impl FromStr for PgPath { |
| 76 | + type Err = Error; |
| 77 | + |
| 78 | + fn from_str(s: &str) -> Result<Self, Self::Err> { |
| 79 | + let closed = !s.contains('['); |
| 80 | + let sanitised = s.replace(['(', ')', '[', ']', ' '], ""); |
| 81 | + let parts = sanitised.split(',').collect::<Vec<_>>(); |
| 82 | + |
| 83 | + let mut points = vec![]; |
| 84 | + |
| 85 | + if parts.len() % 2 != 0 { |
| 86 | + return Err(Error::Decode( |
| 87 | + format!("Unmatched pair in PATH: {}", s).into(), |
| 88 | + )); |
| 89 | + } |
| 90 | + |
| 91 | + for chunk in parts.chunks_exact(2) { |
| 92 | + if let [x_str, y_str] = chunk { |
| 93 | + let x = parse_float_from_str(x_str, "could not get x")?; |
| 94 | + let y = parse_float_from_str(y_str, "could not get y")?; |
| 95 | + |
| 96 | + let point = PgPoint { x, y }; |
| 97 | + points.push(point); |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + if !points.is_empty() { |
| 102 | + return Ok(PgPath { points, closed }); |
| 103 | + } |
| 104 | + |
| 105 | + Err(Error::Decode( |
| 106 | + format!("could not get path from {}", s).into(), |
| 107 | + )) |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +impl PgPath { |
| 112 | + fn header(&self) -> Header { |
| 113 | + Header { |
| 114 | + is_closed: self.closed, |
| 115 | + length: self.points.len(), |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> { |
| 120 | + let header = Header::try_read(&mut bytes)?; |
| 121 | + |
| 122 | + if bytes.len() != header.data_size() { |
| 123 | + return Err(format!( |
| 124 | + "expected {} bytes after header, got {}", |
| 125 | + header.data_size(), |
| 126 | + bytes.len() |
| 127 | + ) |
| 128 | + .into()); |
| 129 | + } |
| 130 | + |
| 131 | + if bytes.len() % BYTE_WIDTH * 2 != 0 { |
| 132 | + return Err(format!( |
| 133 | + "data length not divisible by pairs of {BYTE_WIDTH}: {}", |
| 134 | + bytes.len() |
| 135 | + ) |
| 136 | + .into()); |
| 137 | + } |
| 138 | + |
| 139 | + let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2)); |
| 140 | + |
| 141 | + while bytes.has_remaining() { |
| 142 | + let point = PgPoint { |
| 143 | + x: bytes.get_f64(), |
| 144 | + y: bytes.get_f64(), |
| 145 | + }; |
| 146 | + out_points.push(point) |
| 147 | + } |
| 148 | + Ok(PgPath { |
| 149 | + closed: header.is_closed, |
| 150 | + points: out_points, |
| 151 | + }) |
| 152 | + } |
| 153 | + |
| 154 | + fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> { |
| 155 | + let header = self.header(); |
| 156 | + buff.reserve(header.data_size()); |
| 157 | + header.try_write(buff)?; |
| 158 | + |
| 159 | + for point in &self.points { |
| 160 | + buff.extend_from_slice(&point.x.to_be_bytes()); |
| 161 | + buff.extend_from_slice(&point.y.to_be_bytes()); |
| 162 | + } |
| 163 | + Ok(()) |
| 164 | + } |
| 165 | + |
| 166 | + #[cfg(test)] |
| 167 | + fn serialize_to_vec(&self) -> Vec<u8> { |
| 168 | + let mut buff = PgArgumentBuffer::default(); |
| 169 | + self.serialize(&mut buff).unwrap(); |
| 170 | + buff.to_vec() |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +impl Header { |
| 175 | + const HEADER_WIDTH: usize = mem::size_of::<i8>() + mem::size_of::<i32>(); |
| 176 | + |
| 177 | + fn data_size(&self) -> usize { |
| 178 | + self.length * BYTE_WIDTH * 2 |
| 179 | + } |
| 180 | + |
| 181 | + fn try_read(buf: &mut &[u8]) -> Result<Self, String> { |
| 182 | + if buf.len() < Self::HEADER_WIDTH { |
| 183 | + return Err(format!( |
| 184 | + "expected PATH data to contain at least {} bytes, got {}", |
| 185 | + Self::HEADER_WIDTH, |
| 186 | + buf.len() |
| 187 | + )); |
| 188 | + } |
| 189 | + |
| 190 | + let is_closed = buf.get_i8(); |
| 191 | + let length = buf.get_i32(); |
| 192 | + |
| 193 | + let length = usize::try_from(length).ok().ok_or_else(|| { |
| 194 | + format!( |
| 195 | + "received PATH data length: {length}. Expected length between 0 and {}", |
| 196 | + usize::MAX |
| 197 | + ) |
| 198 | + })?; |
| 199 | + |
| 200 | + Ok(Self { |
| 201 | + is_closed: is_closed != 0, |
| 202 | + length, |
| 203 | + }) |
| 204 | + } |
| 205 | + |
| 206 | + fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> { |
| 207 | + let is_closed = self.is_closed as i8; |
| 208 | + |
| 209 | + let length = i32::try_from(self.length).map_err(|_| { |
| 210 | + format!( |
| 211 | + "PATH length exceeds allowed maximum ({} > {})", |
| 212 | + self.length, |
| 213 | + i32::MAX |
| 214 | + ) |
| 215 | + })?; |
| 216 | + |
| 217 | + buff.extend(is_closed.to_be_bytes()); |
| 218 | + buff.extend(length.to_be_bytes()); |
| 219 | + |
| 220 | + Ok(()) |
| 221 | + } |
| 222 | +} |
| 223 | + |
| 224 | +fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> { |
| 225 | + s.parse().map_err(|_| Error::Decode(error_msg.into())) |
| 226 | +} |
| 227 | + |
| 228 | +#[cfg(test)] |
| 229 | +mod path_tests { |
| 230 | + |
| 231 | + use std::str::FromStr; |
| 232 | + |
| 233 | + use crate::types::PgPoint; |
| 234 | + |
| 235 | + use super::PgPath; |
| 236 | + |
| 237 | + const PATH_CLOSED_BYTES: &[u8] = &[ |
| 238 | + 1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, |
| 239 | + 64, 16, 0, 0, 0, 0, 0, 0, |
| 240 | + ]; |
| 241 | + |
| 242 | + const PATH_OPEN_BYTES: &[u8] = &[ |
| 243 | + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, |
| 244 | + 64, 16, 0, 0, 0, 0, 0, 0, |
| 245 | + ]; |
| 246 | + |
| 247 | + const PATH_UNEVEN_POINTS: &[u8] = &[ |
| 248 | + 0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0, |
| 249 | + 64, 16, 0, 0, |
| 250 | + ]; |
| 251 | + |
| 252 | + #[test] |
| 253 | + fn can_deserialise_path_type_bytes_closed() { |
| 254 | + let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap(); |
| 255 | + assert_eq!( |
| 256 | + path, |
| 257 | + PgPath { |
| 258 | + closed: true, |
| 259 | + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] |
| 260 | + } |
| 261 | + ) |
| 262 | + } |
| 263 | + |
| 264 | + #[test] |
| 265 | + fn cannot_deserialise_path_type_uneven_point_bytes() { |
| 266 | + let path = PgPath::from_bytes(PATH_UNEVEN_POINTS); |
| 267 | + assert!(path.is_err()); |
| 268 | + |
| 269 | + if let Err(err) = path { |
| 270 | + assert_eq!( |
| 271 | + err.to_string(), |
| 272 | + format!("expected 32 bytes after header, got 28") |
| 273 | + ) |
| 274 | + } |
| 275 | + } |
| 276 | + |
| 277 | + #[test] |
| 278 | + fn can_deserialise_path_type_bytes_open() { |
| 279 | + let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap(); |
| 280 | + assert_eq!( |
| 281 | + path, |
| 282 | + PgPath { |
| 283 | + closed: false, |
| 284 | + points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }] |
| 285 | + } |
| 286 | + ) |
| 287 | + } |
| 288 | + |
| 289 | + #[test] |
| 290 | + fn can_deserialise_path_type_str_first_syntax() { |
| 291 | + let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap(); |
| 292 | + assert_eq!( |
| 293 | + path, |
| 294 | + PgPath { |
| 295 | + closed: false, |
| 296 | + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] |
| 297 | + } |
| 298 | + ); |
| 299 | + } |
| 300 | + |
| 301 | + #[test] |
| 302 | + fn cannot_deserialise_path_type_str_uneven_points_first_syntax() { |
| 303 | + let input_str = "[( 1, 2), (3)]"; |
| 304 | + let path = PgPath::from_str(input_str); |
| 305 | + |
| 306 | + assert!(path.is_err()); |
| 307 | + |
| 308 | + if let Err(err) = path { |
| 309 | + assert_eq!( |
| 310 | + err.to_string(), |
| 311 | + format!("error occurred while decoding: Unmatched pair in PATH: {input_str}") |
| 312 | + ) |
| 313 | + } |
| 314 | + } |
| 315 | + |
| 316 | + #[test] |
| 317 | + fn can_deserialise_path_type_str_second_syntax() { |
| 318 | + let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap(); |
| 319 | + assert_eq!( |
| 320 | + path, |
| 321 | + PgPath { |
| 322 | + closed: true, |
| 323 | + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] |
| 324 | + } |
| 325 | + ); |
| 326 | + } |
| 327 | + |
| 328 | + #[test] |
| 329 | + fn can_deserialise_path_type_str_third_syntax() { |
| 330 | + let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap(); |
| 331 | + assert_eq!( |
| 332 | + path, |
| 333 | + PgPath { |
| 334 | + closed: true, |
| 335 | + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] |
| 336 | + } |
| 337 | + ); |
| 338 | + } |
| 339 | + |
| 340 | + #[test] |
| 341 | + fn can_deserialise_path_type_str_fourth_syntax() { |
| 342 | + let path = PgPath::from_str("1, 2, 3, 4").unwrap(); |
| 343 | + assert_eq!( |
| 344 | + path, |
| 345 | + PgPath { |
| 346 | + closed: true, |
| 347 | + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }] |
| 348 | + } |
| 349 | + ); |
| 350 | + } |
| 351 | + |
| 352 | + #[test] |
| 353 | + fn can_deserialise_path_type_str_float() { |
| 354 | + let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap(); |
| 355 | + assert_eq!( |
| 356 | + path, |
| 357 | + PgPath { |
| 358 | + closed: true, |
| 359 | + points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }] |
| 360 | + } |
| 361 | + ); |
| 362 | + } |
| 363 | + |
| 364 | + #[test] |
| 365 | + fn can_serialise_path_type() { |
| 366 | + let path = PgPath { |
| 367 | + closed: true, |
| 368 | + points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }], |
| 369 | + }; |
| 370 | + assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES,) |
| 371 | + } |
| 372 | +} |
0 commit comments