Skip to content

Commit 327af64

Browse files
committed
Fix support for Postgres array of custom types
This commit fixes the array decoder to support custom types. The core of the issue was that the array decoder did not use the type info retrieved from the database. It means that it only supported native types. This commit fixes the issue by using the element type info fetched from the database. A new internal helper method is added to the `PgType` struct: it returns the type info for the inner array element, if available. Closes #1477
1 parent 9abe9b3 commit 327af64

File tree

4 files changed

+219
-4
lines changed

4 files changed

+219
-4
lines changed

sqlx-core/src/postgres/connection/describe.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::sync::Arc;
1616
/// Describes the type of the `pg_type.typtype` column
1717
///
1818
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
19+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
1920
enum TypType {
2021
Base,
2122
Composite,
@@ -45,6 +46,7 @@ impl TryFrom<u8> for TypType {
4546
/// Describes the type of the `pg_type.typcategory` column
4647
///
4748
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
49+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
4850
enum TypCategory {
4951
Array,
5052
Boolean,
@@ -198,7 +200,9 @@ impl PgConnection {
198200

199201
(Ok(TypType::Base), Ok(TypCategory::Array)) => {
200202
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
201-
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
203+
kind: PgTypeKind::Array(
204+
self.maybe_fetch_type_info_by_oid(element, true).await?,
205+
),
202206
name: name.into(),
203207
oid,
204208
}))))

sqlx-core/src/postgres/type_info.rs

+120
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(dead_code)]
22

3+
use std::borrow::Cow;
34
use std::fmt::{self, Display, Formatter};
45
use std::ops::Deref;
56
use std::sync::Arc;
@@ -750,6 +751,125 @@ impl PgType {
750751
}
751752
}
752753
}
754+
755+
/// If `self` is an array type, return the type info for its element.
756+
///
757+
/// This method should only be called on resolved types: calling it on
758+
/// a type that is merely declared (DeclareWithOid/Name) is a bug.
759+
pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
760+
// We explicitly match on all the `None` cases to ensure an exhaustive match.
761+
match self {
762+
PgType::Bool => None,
763+
PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))),
764+
PgType::Bytea => None,
765+
PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))),
766+
PgType::Char => None,
767+
PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))),
768+
PgType::Name => None,
769+
PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))),
770+
PgType::Int8 => None,
771+
PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))),
772+
PgType::Int2 => None,
773+
PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))),
774+
PgType::Int4 => None,
775+
PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))),
776+
PgType::Text => None,
777+
PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))),
778+
PgType::Oid => None,
779+
PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))),
780+
PgType::Json => None,
781+
PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))),
782+
PgType::Point => None,
783+
PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))),
784+
PgType::Lseg => None,
785+
PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))),
786+
PgType::Path => None,
787+
PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))),
788+
PgType::Box => None,
789+
PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))),
790+
PgType::Polygon => None,
791+
PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))),
792+
PgType::Line => None,
793+
PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))),
794+
PgType::Cidr => None,
795+
PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))),
796+
PgType::Float4 => None,
797+
PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))),
798+
PgType::Float8 => None,
799+
PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))),
800+
PgType::Circle => None,
801+
PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))),
802+
PgType::Macaddr8 => None,
803+
PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))),
804+
PgType::Money => None,
805+
PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))),
806+
PgType::Macaddr => None,
807+
PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))),
808+
PgType::Inet => None,
809+
PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))),
810+
PgType::Bpchar => None,
811+
PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))),
812+
PgType::Varchar => None,
813+
PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))),
814+
PgType::Date => None,
815+
PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))),
816+
PgType::Time => None,
817+
PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))),
818+
PgType::Timestamp => None,
819+
PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))),
820+
PgType::Timestamptz => None,
821+
PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))),
822+
PgType::Interval => None,
823+
PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))),
824+
PgType::Timetz => None,
825+
PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))),
826+
PgType::Bit => None,
827+
PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))),
828+
PgType::Varbit => None,
829+
PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))),
830+
PgType::Numeric => None,
831+
PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))),
832+
PgType::Record => None,
833+
PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))),
834+
PgType::Uuid => None,
835+
PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))),
836+
PgType::Jsonb => None,
837+
PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))),
838+
PgType::Int4Range => None,
839+
PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))),
840+
PgType::NumRange => None,
841+
PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))),
842+
PgType::TsRange => None,
843+
PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))),
844+
PgType::TstzRange => None,
845+
PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))),
846+
PgType::DateRange => None,
847+
PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))),
848+
PgType::Int8Range => None,
849+
PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))),
850+
PgType::Jsonpath => None,
851+
PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))),
852+
// There is no `UnknownArray`
853+
PgType::Unknown => None,
854+
// There is no `VoidArray`
855+
PgType::Void => None,
856+
PgType::Custom(ty) => match &ty.kind {
857+
PgTypeKind::Simple => None,
858+
PgTypeKind::Pseudo => None,
859+
PgTypeKind::Domain(_) => None,
860+
PgTypeKind::Composite(_) => None,
861+
PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)),
862+
PgTypeKind::Enum(_) => None,
863+
PgTypeKind::Range(_) => None,
864+
},
865+
PgType::DeclareWithOid(oid) => {
866+
unreachable!("(bug) use of unresolved type declaration [oid={}]", oid);
867+
}
868+
PgType::DeclareWithName(name) => {
869+
unreachable!("(bug) use of unresolved type declaration [name={}]", name);
870+
}
871+
}
872+
}
753873
}
754874

755875
impl TypeInfo for PgTypeInfo {

sqlx-core/src/postgres/types/array.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use bytes::Buf;
2+
use std::borrow::Cow;
23

34
use crate::decode::Decode;
45
use crate::encode::{Encode, IsNull};
@@ -77,7 +78,6 @@ where
7778
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
7879
{
7980
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
80-
let element_type_info;
8181
let format = value.format();
8282

8383
match format {
@@ -105,7 +105,11 @@ where
105105

106106
// the OID of the element
107107
let element_type_oid = buf.get_u32();
108-
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
108+
let element_type_info: PgTypeInfo = value
109+
.type_info
110+
.try_array_element()
111+
.map(Cow::into_owned)
112+
.or_else(|| PgTypeInfo::try_from_oid(element_type_oid))
109113
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));
110114

111115
// length of the array axis
@@ -133,7 +137,7 @@ where
133137

134138
PgValueFormat::Text => {
135139
// no type is provided from the database for the element
136-
element_type_info = T::type_info();
140+
let element_type_info = T::type_info();
137141

138142
let s = value.as_str()?;
139143

tests/postgres/postgres.rs

+87
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills (
10941094
Ok(())
10951095
}
10961096

1097+
#[sqlx_macros::test]
1098+
async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> {
1099+
// Only supported in Postgres 11+
1100+
let mut conn = new::<Postgres>().await?;
1101+
if matches!(conn.server_version_num(), Some(version) if version < 110000) {
1102+
return Ok(());
1103+
}
1104+
1105+
// language=PostgreSQL
1106+
conn.execute(
1107+
r#"
1108+
DROP TABLE IF EXISTS pets;
1109+
DROP TYPE IF EXISTS pet_name_and_race;
1110+
1111+
CREATE TYPE pet_name_and_race AS (
1112+
name TEXT,
1113+
race TEXT
1114+
);
1115+
CREATE TABLE pets (
1116+
owner TEXT NOT NULL,
1117+
name TEXT NOT NULL,
1118+
race TEXT NOT NULL,
1119+
PRIMARY KEY (owner, name)
1120+
);
1121+
INSERT INTO pets(owner, name, race)
1122+
VALUES
1123+
('Alice', 'Foo', 'cat');
1124+
INSERT INTO pets(owner, name, race)
1125+
VALUES
1126+
('Alice', 'Bar', 'dog');
1127+
"#,
1128+
)
1129+
.await?;
1130+
1131+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
1132+
struct PetNameAndRace {
1133+
name: String,
1134+
race: String,
1135+
}
1136+
1137+
impl sqlx::Type<Postgres> for PetNameAndRace {
1138+
fn type_info() -> sqlx::postgres::PgTypeInfo {
1139+
sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race")
1140+
}
1141+
}
1142+
1143+
impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace {
1144+
fn decode(
1145+
value: sqlx::postgres::PgValueRef<'r>,
1146+
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
1147+
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
1148+
let name = decoder.try_decode::<String>()?;
1149+
let race = decoder.try_decode::<String>()?;
1150+
Ok(Self { name, race })
1151+
}
1152+
}
1153+
1154+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
1155+
struct PetNameAndRaceArray(Vec<PetNameAndRace>);
1156+
1157+
impl sqlx::Type<Postgres> for PetNameAndRaceArray {
1158+
fn type_info() -> sqlx::postgres::PgTypeInfo {
1159+
// Array type name is the name of the element type prefixed with `_`
1160+
sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race")
1161+
}
1162+
}
1163+
1164+
impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray {
1165+
fn decode(
1166+
value: sqlx::postgres::PgValueRef<'r>,
1167+
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
1168+
Ok(Self(Vec::<PetNameAndRace>::decode(value)?))
1169+
}
1170+
}
1171+
1172+
let mut conn = new::<Postgres>().await?;
1173+
1174+
let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner")
1175+
.fetch_one(&mut conn)
1176+
.await?;
1177+
1178+
let pets: PetNameAndRaceArray = row.get("pets");
1179+
1180+
assert_eq!(pets.0.len(), 2);
1181+
Ok(())
1182+
}
1183+
10971184
#[sqlx_macros::test]
10981185
async fn test_pg_server_num() -> anyhow::Result<()> {
10991186
use sqlx::postgres::PgConnectionInfo;

0 commit comments

Comments
 (0)