From 9f500ffd0f5451d24b03da8b4c1094cc783db0c0 Mon Sep 17 00:00:00 2001 From: alexwl Date: Tue, 25 Feb 2020 21:06:43 +0000 Subject: [PATCH] Add support for deriving FromSql and ToSql for structs with references --- .../src/compile-fail/invalid-types.rs | 5 +++ .../src/compile-fail/invalid-types.stderr | 6 ++++ postgres-derive-test/src/composites.rs | 29 +++++++++++++++++ postgres-derive-test/src/domains.rs | 20 ++++++++++++ postgres-derive/src/fromsql.rs | 31 +++++++++++++++++-- postgres-derive/src/tosql.rs | 4 ++- 6 files changed, 91 insertions(+), 4 deletions(-) diff --git a/postgres-derive-test/src/compile-fail/invalid-types.rs b/postgres-derive-test/src/compile-fail/invalid-types.rs index ef41ac820..1f937d122 100644 --- a/postgres-derive-test/src/compile-fail/invalid-types.rs +++ b/postgres-derive-test/src/compile-fail/invalid-types.rs @@ -22,4 +22,9 @@ enum FromSqlEnum { Foo(i32), } +#[derive(FromSql)] +struct FromSqlTypeParameter { + foo: T, +} + fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-types.stderr b/postgres-derive-test/src/compile-fail/invalid-types.stderr index 9b563d58b..a27ecf1e1 100644 --- a/postgres-derive-test/src/compile-fail/invalid-types.stderr +++ b/postgres-derive-test/src/compile-fail/invalid-types.stderr @@ -33,3 +33,9 @@ error: non-C-like enums are not supported | 22 | Foo(i32), | ^^^^^^^^ + +error: #[derive(FromSql)] does not support type parameters. + --> $DIR/invalid-types.rs:26:28 + | +26 | struct FromSqlTypeParameter { + | ^^^ diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index 5efd3944c..c8024b084 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -215,3 +215,32 @@ fn wrong_type() { .unwrap_err(); assert!(err.source().unwrap().is::()); } + +#[test] +fn struct_with_references() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "item")] + struct Item<'a, 'b: 'a> { + name: &'a str, + data: &'b [u8], + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.item AS ( + name TEXT, + data BYTEA + );", + ) + .unwrap(); + + let item = Item { + name: "foobar", + data: b"12345", + }; + + let row = conn.query_one("SELECT $1::item", &[&item]).unwrap(); + let result: Item<'_, '_> = row.get(0); + assert_eq!(item.name, result.name); + assert_eq!(item.data, result.data); +} diff --git a/postgres-derive-test/src/domains.rs b/postgres-derive-test/src/domains.rs index 25674f75e..c44f819a2 100644 --- a/postgres-derive-test/src/domains.rs +++ b/postgres-derive-test/src/domains.rs @@ -119,3 +119,23 @@ fn domain_in_composite() { )], ); } + +#[test] +fn struct_with_reference() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct SessionId<'b>(&'b [u8]); + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);", + &[], + ) + .unwrap(); + + let session_id = b"0123456789abcdef"; + let row = conn + .query_one("SELECT $1::\"SessionId\"", &[&SessionId(session_id)]) + .unwrap(); + let result: SessionId<'_> = row.get(0); + assert_eq!(session_id, result.0); +} diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index e1ab6ffa7..1d2f828d2 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -8,6 +8,8 @@ use crate::composites::Field; use crate::enums::Variant; use crate::overrides::Overrides; +const DEFAULT_LIFETIME: &str = "de"; + pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let overrides = Overrides::extract(&input.attrs)?; @@ -58,10 +60,33 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { }; let ident = &input.ident; + let mut generics = input.generics; + if generics.type_params().count() > 0 { + return Err(Error::new_spanned( + &generics, + "#[derive(FromSql)] does not support type parameters.", + )); + } + + let generics_clone = &generics.clone(); + let (_, type_generics, _) = generics_clone.split_for_impl(); + + let lifetime = syn::Lifetime::new(&format!("'{}", DEFAULT_LIFETIME), Span::call_site()); + let mut lifetime_def = syn::LifetimeDef::new(lifetime.clone()); + let lifetimes: Vec = generics.lifetimes().map(|l| l.lifetime.clone()).collect(); + lifetime_def.bounds = syn::punctuated::Punctuated::new(); + for l in lifetimes { + lifetime_def.bounds.push(l); + } + generics + .params + .push(syn::GenericParam::Lifetime(lifetime_def)); + let (impl_generics, _, _) = generics.split_for_impl(); + let out = quote! { - impl<'a> postgres_types::FromSql<'a> for #ident { - fn from_sql(_type: &postgres_types::Type, buf: &'a [u8]) - -> std::result::Result<#ident, + impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #type_generics { + fn from_sql(_type: &postgres_types::Type, buf: & #lifetime [u8]) + -> std::result::Result<#ident #type_generics, std::boxed::Box> { diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index a1c87b0ff..214608cc0 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -55,8 +55,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { }; let ident = &input.ident; + let generics = &input.generics; + let (impl_generics, type_generics, where_clause) = generics.split_for_impl(); let out = quote! { - impl postgres_types::ToSql for #ident { + impl #impl_generics postgres_types::ToSql for #ident #type_generics #where_clause { fn to_sql(&self, _type: &postgres_types::Type, buf: &mut postgres_types::private::BytesMut)