Skip to content

Commit 32f1273

Browse files
authored
Fix support for Postgres array of custom types (#1483)
This commit fixes the array decoder to support custom types. The core of the issue was that the array decoder did not use the type info retrieved from the database. It means that it only supported native types. This commit fixes the issue by using the element type info fetched from the database. A new internal helper method is added to the `PgType` struct: it returns the type info for the inner array element, if available. Closes #1477
1 parent dee5147 commit 32f1273

File tree

4 files changed

+216
-4
lines changed

4 files changed

+216
-4
lines changed

sqlx-core/src/postgres/connection/describe.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::sync::Arc;
1616
/// Describes the type of the `pg_type.typtype` column
1717
///
1818
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
19+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
1920
enum TypType {
2021
Base,
2122
Composite,
@@ -45,6 +46,7 @@ impl TryFrom<u8> for TypType {
4546
/// Describes the type of the `pg_type.typcategory` column
4647
///
4748
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
49+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
4850
enum TypCategory {
4951
Array,
5052
Boolean,
@@ -198,7 +200,9 @@ impl PgConnection {
198200

199201
(Ok(TypType::Base), Ok(TypCategory::Array)) => {
200202
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
201-
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
203+
kind: PgTypeKind::Array(
204+
self.maybe_fetch_type_info_by_oid(element, true).await?,
205+
),
202206
name: name.into(),
203207
oid,
204208
}))))

sqlx-core/src/postgres/type_info.rs

+120
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(dead_code)]
22

3+
use std::borrow::Cow;
34
use std::fmt::{self, Display, Formatter};
45
use std::ops::Deref;
56
use std::sync::Arc;
@@ -750,6 +751,125 @@ impl PgType {
750751
}
751752
}
752753
}
754+
755+
/// If `self` is an array type, return the type info for its element.
756+
///
757+
/// This method should only be called on resolved types: calling it on
758+
/// a type that is merely declared (DeclareWithOid/Name) is a bug.
759+
pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
760+
// We explicitly match on all the `None` cases to ensure an exhaustive match.
761+
match self {
762+
PgType::Bool => None,
763+
PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))),
764+
PgType::Bytea => None,
765+
PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))),
766+
PgType::Char => None,
767+
PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))),
768+
PgType::Name => None,
769+
PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))),
770+
PgType::Int8 => None,
771+
PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))),
772+
PgType::Int2 => None,
773+
PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))),
774+
PgType::Int4 => None,
775+
PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))),
776+
PgType::Text => None,
777+
PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))),
778+
PgType::Oid => None,
779+
PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))),
780+
PgType::Json => None,
781+
PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))),
782+
PgType::Point => None,
783+
PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))),
784+
PgType::Lseg => None,
785+
PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))),
786+
PgType::Path => None,
787+
PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))),
788+
PgType::Box => None,
789+
PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))),
790+
PgType::Polygon => None,
791+
PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))),
792+
PgType::Line => None,
793+
PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))),
794+
PgType::Cidr => None,
795+
PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))),
796+
PgType::Float4 => None,
797+
PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))),
798+
PgType::Float8 => None,
799+
PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))),
800+
PgType::Circle => None,
801+
PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))),
802+
PgType::Macaddr8 => None,
803+
PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))),
804+
PgType::Money => None,
805+
PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))),
806+
PgType::Macaddr => None,
807+
PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))),
808+
PgType::Inet => None,
809+
PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))),
810+
PgType::Bpchar => None,
811+
PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))),
812+
PgType::Varchar => None,
813+
PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))),
814+
PgType::Date => None,
815+
PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))),
816+
PgType::Time => None,
817+
PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))),
818+
PgType::Timestamp => None,
819+
PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))),
820+
PgType::Timestamptz => None,
821+
PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))),
822+
PgType::Interval => None,
823+
PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))),
824+
PgType::Timetz => None,
825+
PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))),
826+
PgType::Bit => None,
827+
PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))),
828+
PgType::Varbit => None,
829+
PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))),
830+
PgType::Numeric => None,
831+
PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))),
832+
PgType::Record => None,
833+
PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))),
834+
PgType::Uuid => None,
835+
PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))),
836+
PgType::Jsonb => None,
837+
PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))),
838+
PgType::Int4Range => None,
839+
PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))),
840+
PgType::NumRange => None,
841+
PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))),
842+
PgType::TsRange => None,
843+
PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))),
844+
PgType::TstzRange => None,
845+
PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))),
846+
PgType::DateRange => None,
847+
PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))),
848+
PgType::Int8Range => None,
849+
PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))),
850+
PgType::Jsonpath => None,
851+
PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))),
852+
// There is no `UnknownArray`
853+
PgType::Unknown => None,
854+
// There is no `VoidArray`
855+
PgType::Void => None,
856+
PgType::Custom(ty) => match &ty.kind {
857+
PgTypeKind::Simple => None,
858+
PgTypeKind::Pseudo => None,
859+
PgTypeKind::Domain(_) => None,
860+
PgTypeKind::Composite(_) => None,
861+
PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)),
862+
PgTypeKind::Enum(_) => None,
863+
PgTypeKind::Range(_) => None,
864+
},
865+
PgType::DeclareWithOid(oid) => {
866+
unreachable!("(bug) use of unresolved type declaration [oid={}]", oid);
867+
}
868+
PgType::DeclareWithName(name) => {
869+
unreachable!("(bug) use of unresolved type declaration [name={}]", name);
870+
}
871+
}
872+
}
753873
}
754874

755875
impl TypeInfo for PgTypeInfo {

sqlx-core/src/postgres/types/array.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use bytes::Buf;
2+
use std::borrow::Cow;
23

34
use crate::decode::Decode;
45
use crate::encode::{Encode, IsNull};
@@ -103,7 +104,6 @@ where
103104
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
104105
{
105106
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
106-
let element_type_info;
107107
let format = value.format();
108108

109109
match format {
@@ -131,7 +131,8 @@ where
131131

132132
// the OID of the element
133133
let element_type_oid = buf.get_u32();
134-
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
134+
let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
135+
.or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
135136
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));
136137

137138
// length of the array axis
@@ -159,7 +160,7 @@ where
159160

160161
PgValueFormat::Text => {
161162
// no type is provided from the database for the element
162-
element_type_info = T::type_info();
163+
let element_type_info = T::type_info();
163164

164165
let s = value.as_str()?;
165166

tests/postgres/postgres.rs

+87
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills (
10941094
Ok(())
10951095
}
10961096

1097+
#[sqlx_macros::test]
1098+
async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> {
1099+
// Only supported in Postgres 11+
1100+
let mut conn = new::<Postgres>().await?;
1101+
if matches!(conn.server_version_num(), Some(version) if version < 110000) {
1102+
return Ok(());
1103+
}
1104+
1105+
// language=PostgreSQL
1106+
conn.execute(
1107+
r#"
1108+
DROP TABLE IF EXISTS pets;
1109+
DROP TYPE IF EXISTS pet_name_and_race;
1110+
1111+
CREATE TYPE pet_name_and_race AS (
1112+
name TEXT,
1113+
race TEXT
1114+
);
1115+
CREATE TABLE pets (
1116+
owner TEXT NOT NULL,
1117+
name TEXT NOT NULL,
1118+
race TEXT NOT NULL,
1119+
PRIMARY KEY (owner, name)
1120+
);
1121+
INSERT INTO pets(owner, name, race)
1122+
VALUES
1123+
('Alice', 'Foo', 'cat');
1124+
INSERT INTO pets(owner, name, race)
1125+
VALUES
1126+
('Alice', 'Bar', 'dog');
1127+
"#,
1128+
)
1129+
.await?;
1130+
1131+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
1132+
struct PetNameAndRace {
1133+
name: String,
1134+
race: String,
1135+
}
1136+
1137+
impl sqlx::Type<Postgres> for PetNameAndRace {
1138+
fn type_info() -> sqlx::postgres::PgTypeInfo {
1139+
sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race")
1140+
}
1141+
}
1142+
1143+
impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace {
1144+
fn decode(
1145+
value: sqlx::postgres::PgValueRef<'r>,
1146+
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
1147+
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
1148+
let name = decoder.try_decode::<String>()?;
1149+
let race = decoder.try_decode::<String>()?;
1150+
Ok(Self { name, race })
1151+
}
1152+
}
1153+
1154+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
1155+
struct PetNameAndRaceArray(Vec<PetNameAndRace>);
1156+
1157+
impl sqlx::Type<Postgres> for PetNameAndRaceArray {
1158+
fn type_info() -> sqlx::postgres::PgTypeInfo {
1159+
// Array type name is the name of the element type prefixed with `_`
1160+
sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race")
1161+
}
1162+
}
1163+
1164+
impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray {
1165+
fn decode(
1166+
value: sqlx::postgres::PgValueRef<'r>,
1167+
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
1168+
Ok(Self(Vec::<PetNameAndRace>::decode(value)?))
1169+
}
1170+
}
1171+
1172+
let mut conn = new::<Postgres>().await?;
1173+
1174+
let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner")
1175+
.fetch_one(&mut conn)
1176+
.await?;
1177+
1178+
let pets: PetNameAndRaceArray = row.get("pets");
1179+
1180+
assert_eq!(pets.0.len(), 2);
1181+
Ok(())
1182+
}
1183+
10971184
#[sqlx_macros::test]
10981185
async fn test_pg_server_num() -> anyhow::Result<()> {
10991186
use sqlx::postgres::PgConnectionInfo;

0 commit comments

Comments
 (0)