Skip to content

Commit 88b625f

Browse files
committed
feat(spin_sdk::pg): Improve pg type's conversion in Rust SDK
- Add conversion support for: int16, int32, int64, floating32, floating64, binary, boolean; - Add support for nullable variants; Signed-off-by: Konstantin Shabanov <[email protected]>
1 parent 58e98c9 commit 88b625f

File tree

6 files changed

+396
-57
lines changed

6 files changed

+396
-57
lines changed

crates/outbound-pg/src/lib.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ fn convert_data_type(pg_type: &Type) -> DbDataType {
139139
Type::INT2 => DbDataType::Int16,
140140
Type::INT4 => DbDataType::Int32,
141141
Type::INT8 => DbDataType::Int64,
142-
Type::TEXT => DbDataType::Str,
143-
Type::VARCHAR => DbDataType::Str,
142+
Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str,
144143
_ => {
145144
tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
146145
DbDataType::Other
@@ -208,17 +207,10 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
208207
None => DbValue::DbNull,
209208
}
210209
}
211-
&Type::TEXT => {
212-
let value: Option<&str> = row.try_get(index)?;
210+
&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => {
211+
let value: Option<String> = row.try_get(index)?;
213212
match value {
214-
Some(v) => DbValue::Str(v.to_owned()),
215-
None => DbValue::DbNull,
216-
}
217-
}
218-
&Type::VARCHAR => {
219-
let value: Option<&str> = row.try_get(index)?;
220-
match value {
221-
Some(v) => DbValue::Str(v.to_owned()),
213+
Some(v) => DbValue::Str(v),
222214
None => DbValue::DbNull,
223215
}
224216
}

examples/rust-outbound-pg/db/testdata.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ CREATE TABLE articletest (
22
id integer GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
33
title varchar(40) NOT NULL,
44
content text NOT NULL,
5-
authorname varchar(40) NOT NULL
5+
authorname varchar(40) NOT NULL,
6+
coauthor text
67
);
78

89
INSERT INTO articletest (title, content, authorname) VALUES

examples/rust-outbound-pg/src/lib.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#![allow(dead_code)]
22
use anyhow::{anyhow, Result};
3+
use spin_sdk::pg::Decode;
34
use spin_sdk::{
45
http::{Request, Response},
5-
http_component,
6-
pg,
6+
http_component, pg,
77
};
88

99
// The environment variable set in `spin.toml` that points to the
@@ -16,22 +16,25 @@ struct Article {
1616
title: String,
1717
content: String,
1818
authorname: String,
19+
coauthor: Option<String>,
1920
}
2021

2122
impl TryFrom<&pg::Row> for Article {
2223
type Error = anyhow::Error;
2324

2425
fn try_from(row: &pg::Row) -> Result<Self, Self::Error> {
25-
let id: i32 = (&row[0]).try_into()?;
26-
let title: String = (&row[1]).try_into()?;
27-
let content: String = (&row[2]).try_into()?;
28-
let authorname: String = (&row[3]).try_into()?;
26+
let id = i32::decode(&row[0])?;
27+
let title = String::decode(&row[1])?;
28+
let content = String::decode(&row[2])?;
29+
let authorname = String::decode(&row[3])?;
30+
let coauthor = Option::<String>::decode(&row[4])?;
2931

3032
Ok(Self {
3133
id,
3234
title,
3335
content,
3436
authorname,
37+
coauthor,
3538
})
3639
}
3740
}
@@ -51,7 +54,7 @@ fn process(req: Request) -> Result<Response> {
5154
fn read(_req: Request) -> Result<Response> {
5255
let address = std::env::var(DB_URL_ENV)?;
5356

54-
let sql = "SELECT id, title, content, authorname FROM articletest";
57+
let sql = "SELECT id, title, content, authorname, coauthor FROM articletest";
5558
let rowset = pg::query(&address, sql, &[])
5659
.map_err(|e| anyhow!("Error executing Postgres query: {:?}", e))?;
5760

@@ -98,7 +101,7 @@ fn write(_req: Request) -> Result<Response> {
98101
let rowset = pg::query(&address, sql, &[])
99102
.map_err(|e| anyhow!("Error executing Postgres query: {:?}", e))?;
100103
let row = &rowset.rows[0];
101-
let count: i64 = (&row[0]).try_into()?;
104+
let count = i64::decode(&row[0])?;
102105
let response = format!("Count: {}\n", count);
103106

104107
Ok(http::Response::builder()
@@ -116,7 +119,7 @@ fn pg_backend_pid(_req: Request) -> Result<Response> {
116119

117120
let row = &rowset.rows[0];
118121

119-
i32::try_from(&row[0])
122+
i32::decode(&row[0])
120123
};
121124

122125
assert_eq!(get_pid()?, get_pid()?);

sdk/rust/src/pg.rs

Lines changed: 153 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,185 @@
1-
21
wit_bindgen_rust::import!("../../wit/ephemeral/outbound-pg.wit");
32

43
/// Exports the generated outbound Pg items.
54
pub use outbound_pg::*;
65

7-
impl TryFrom<&DbValue> for i32 {
8-
type Error = anyhow::Error;
6+
pub trait Decode: Sized {
7+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error>;
8+
}
9+
10+
impl<T> Decode for Option<T>
11+
where
12+
T: Decode,
13+
{
14+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
15+
match value {
16+
DbValue::DbNull => Ok(None),
17+
v => Ok(Some(T::decode(v)?)),
18+
}
19+
}
20+
}
21+
22+
impl Decode for bool {
23+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
24+
match value {
25+
DbValue::Boolean(boolean) => Ok(*boolean),
26+
_ => Err(anyhow::anyhow!(
27+
"Expected BOOL from the DB but got {:?}",
28+
value
29+
)),
30+
}
31+
}
32+
}
33+
34+
impl Decode for i16 {
35+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
36+
match value {
37+
DbValue::Int16(n) => Ok(*n),
38+
_ => Err(anyhow::anyhow!(
39+
"Expected SMALLINT from the DB but got {:?}",
40+
value
41+
)),
42+
}
43+
}
44+
}
945

10-
fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
46+
impl Decode for i32 {
47+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
1148
match value {
1249
DbValue::Int32(n) => Ok(*n),
1350
_ => Err(anyhow::anyhow!(
14-
"Expected integer from database but got {:?}",
51+
"Expected INT from the DB but got {:?}",
1552
value
1653
)),
1754
}
1855
}
1956
}
2057

21-
impl TryFrom<&DbValue> for String {
22-
type Error = anyhow::Error;
58+
impl Decode for i64 {
59+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
60+
match value {
61+
DbValue::Int64(n) => Ok(*n),
62+
_ => Err(anyhow::anyhow!(
63+
"Expected BIGINT from the DB but got {:?}",
64+
value
65+
)),
66+
}
67+
}
68+
}
2369

24-
fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
70+
impl Decode for f32 {
71+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
2572
match value {
26-
DbValue::Str(s) => Ok(s.to_owned()),
73+
DbValue::Floating32(n) => Ok(*n),
74+
_ => Err(anyhow::anyhow!(
75+
"Expected REAL from the DB but got {:?}",
76+
value
77+
)),
78+
}
79+
}
80+
}
81+
82+
impl Decode for f64 {
83+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
84+
match value {
85+
DbValue::Floating64(n) => Ok(*n),
2786
_ => Err(anyhow::anyhow!(
28-
"Expected string from the DB but got {:?}",
87+
"Expected DOUBLE PRECISION from the DB but got {:?}",
2988
value
3089
)),
3190
}
3291
}
3392
}
3493

35-
impl TryFrom<&DbValue> for i64 {
36-
type Error = anyhow::Error;
94+
impl Decode for Vec<u8> {
95+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
96+
match value {
97+
DbValue::Binary(n) => Ok(n.to_owned()),
98+
_ => Err(anyhow::anyhow!(
99+
"Expected BYTEA from the DB but got {:?}",
100+
value
101+
)),
102+
}
103+
}
104+
}
37105

38-
fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
106+
impl Decode for String {
107+
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
39108
match value {
40-
DbValue::Int64(n) => Ok(*n),
109+
DbValue::Str(s) => Ok(s.to_owned()),
41110
_ => Err(anyhow::anyhow!(
42-
"Expected integer from the DB but got {:?}",
111+
"Expected CHAR, VARCHAR, TEXT from the DB but got {:?}",
43112
value
44113
)),
45114
}
46115
}
47116
}
117+
118+
#[cfg(test)]
119+
mod tests {
120+
use super::*;
121+
122+
#[test]
123+
fn boolean() {
124+
assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
125+
assert!(bool::decode(&DbValue::Int32(0)).is_err());
126+
assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
127+
}
128+
129+
#[test]
130+
fn int16() {
131+
assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
132+
assert!(i16::decode(&DbValue::Int32(0)).is_err());
133+
assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
134+
}
135+
136+
#[test]
137+
fn int32() {
138+
assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
139+
assert!(i32::decode(&DbValue::Boolean(false)).is_err());
140+
assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
141+
}
142+
143+
#[test]
144+
fn int64() {
145+
assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
146+
assert!(i64::decode(&DbValue::Boolean(false)).is_err());
147+
assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
148+
}
149+
150+
#[test]
151+
fn floating32() {
152+
assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
153+
assert!(f32::decode(&DbValue::Boolean(false)).is_err());
154+
assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
155+
}
156+
157+
#[test]
158+
fn floating64() {
159+
assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
160+
assert!(f64::decode(&DbValue::Boolean(false)).is_err());
161+
assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
162+
}
163+
164+
#[test]
165+
fn str() {
166+
assert_eq!(
167+
String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
168+
String::from("foo")
169+
);
170+
171+
assert!(String::decode(&DbValue::Int32(0)).is_err());
172+
assert!(Option::<String>::decode(&DbValue::DbNull)
173+
.unwrap()
174+
.is_none());
175+
}
176+
177+
#[test]
178+
fn binary() {
179+
assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
180+
assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
181+
assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
182+
.unwrap()
183+
.is_none());
184+
}
185+
}

tests/integration.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,9 @@ mod integration_tests {
632632
)
633633
.await?;
634634

635-
assert_status(&s, "/test_read_types", 200).await?;
635+
assert_status(&s, "/test_numeric_types", 200).await?;
636+
assert_status(&s, "/test_character_types", 200).await?;
637+
assert_status(&s, "/test_general_types", 200).await?;
636638
assert_status(&s, "/pg_backend_pid", 200).await?;
637639

638640
Ok(())

0 commit comments

Comments
 (0)