@@ -2,15 +2,18 @@ use mysql_async::consts::ColumnType;
2
2
use mysql_async:: { from_value_opt, prelude:: * } ;
3
3
pub use outbound_mysql:: add_to_linker;
4
4
use spin_core:: HostComponent ;
5
+ use std:: collections:: HashMap ;
5
6
use std:: sync:: Arc ;
6
7
use wit_bindgen_wasmtime:: async_trait;
7
8
8
9
wit_bindgen_wasmtime:: export!( { paths: [ "../../wit/ephemeral/outbound-mysql.wit" ] , async : * } ) ;
9
10
use outbound_mysql:: * ;
10
11
11
12
/// A simple implementation to support outbound mysql connection
12
- #[ derive( Default , Clone ) ]
13
- pub struct OutboundMysql ;
13
+ #[ derive( Default ) ]
14
+ pub struct OutboundMysql {
15
+ pub connections : HashMap < String , mysql_async:: Conn > ,
16
+ }
14
17
15
18
impl HostComponent for OutboundMysql {
16
19
type Data = Self ;
@@ -35,20 +38,17 @@ impl outbound_mysql::OutboundMysql for OutboundMysql {
35
38
statement : & str ,
36
39
params : Vec < ParameterValue < ' _ > > ,
37
40
) -> Result < ( ) , MysqlError > {
38
- let connection_pool = mysql_async:: Pool :: new ( address) ;
39
- let mut connection = connection_pool
40
- . get_conn ( )
41
- . await
42
- . map_err ( |e| MysqlError :: ConnectionFailed ( format ! ( "{:?}" , e) ) ) ?;
43
-
44
41
let db_params = params
45
42
. iter ( )
46
43
. map ( to_sql_parameter)
47
44
. collect :: < anyhow:: Result < Vec < _ > > > ( )
48
45
. map_err ( |e| MysqlError :: QueryFailed ( format ! ( "{:?}" , e) ) ) ?;
49
46
50
47
let parameters = mysql_async:: Params :: Positional ( db_params) ;
51
- connection
48
+
49
+ self . get_conn ( address)
50
+ . await
51
+ . map_err ( |e| MysqlError :: ConnectionFailed ( format ! ( "{:?}" , e) ) ) ?
52
52
. exec_batch ( statement, & [ parameters] )
53
53
. await
54
54
. map_err ( |e| MysqlError :: QueryFailed ( format ! ( "{:?}" , e) ) ) ?;
@@ -62,20 +62,18 @@ impl outbound_mysql::OutboundMysql for OutboundMysql {
62
62
statement : & str ,
63
63
params : Vec < ParameterValue < ' _ > > ,
64
64
) -> Result < RowSet , MysqlError > {
65
- let connection_pool = mysql_async:: Pool :: new ( address) ;
66
- let mut connection = connection_pool
67
- . get_conn ( )
68
- . await
69
- . map_err ( |e| MysqlError :: ConnectionFailed ( format ! ( "{:?}" , e) ) ) ?;
70
-
71
65
let db_params = params
72
66
. iter ( )
73
67
. map ( to_sql_parameter)
74
68
. collect :: < anyhow:: Result < Vec < _ > > > ( )
75
69
. map_err ( |e| MysqlError :: QueryFailed ( format ! ( "{:?}" , e) ) ) ?;
76
70
77
71
let parameters = mysql_async:: Params :: Positional ( db_params) ;
78
- let mut query_result = connection
72
+
73
+ let mut query_result = self
74
+ . get_conn ( address)
75
+ . await
76
+ . map_err ( |e| MysqlError :: ConnectionFailed ( format ! ( "{:?}" , e) ) ) ?
79
77
. exec_iter ( statement, parameters)
80
78
. await
81
79
. map_err ( |e| MysqlError :: QueryFailed ( format ! ( "{:?}" , e) ) ) ?;
@@ -132,24 +130,45 @@ fn convert_column(column: &mysql_async::Column) -> Column {
132
130
}
133
131
134
132
fn convert_data_type ( column : & mysql_async:: Column ) -> DbDataType {
135
- match ( column. column_type ( ) , is_signed ( column) ) {
136
- ( ColumnType :: MYSQL_TYPE_BIT , _) => DbDataType :: Boolean ,
133
+ let column_type = column. column_type ( ) ;
134
+
135
+ if column_type. is_numeric_type ( ) {
136
+ convert_numeric_type ( column)
137
+ } else if column_type. is_character_type ( ) {
138
+ convert_character_type ( column)
139
+ } else {
140
+ DbDataType :: Other
141
+ }
142
+ }
143
+
144
+ fn convert_character_type ( column : & mysql_async:: Column ) -> DbDataType {
145
+ match ( column. column_type ( ) , is_binary ( column) ) {
146
+ ( ColumnType :: MYSQL_TYPE_BLOB , false ) => DbDataType :: Str , // TEXT type
137
147
( ColumnType :: MYSQL_TYPE_BLOB , _) => DbDataType :: Binary ,
148
+ ( ColumnType :: MYSQL_TYPE_LONG_BLOB , _) => DbDataType :: Binary ,
149
+ ( ColumnType :: MYSQL_TYPE_MEDIUM_BLOB , _) => DbDataType :: Binary ,
150
+ ( ColumnType :: MYSQL_TYPE_STRING , true ) => DbDataType :: Binary , // BINARY type
151
+ ( ColumnType :: MYSQL_TYPE_STRING , _) => DbDataType :: Str ,
152
+ ( ColumnType :: MYSQL_TYPE_VAR_STRING , true ) => DbDataType :: Binary , // VARBINARY type
153
+ ( ColumnType :: MYSQL_TYPE_VAR_STRING , _) => DbDataType :: Str ,
154
+ ( _, _) => DbDataType :: Other ,
155
+ }
156
+ }
157
+
158
+ fn convert_numeric_type ( column : & mysql_async:: Column ) -> DbDataType {
159
+ match ( column. column_type ( ) , is_signed ( column) ) {
138
160
( ColumnType :: MYSQL_TYPE_DOUBLE , _) => DbDataType :: Floating64 ,
139
161
( ColumnType :: MYSQL_TYPE_FLOAT , _) => DbDataType :: Floating32 ,
162
+ ( ColumnType :: MYSQL_TYPE_INT24 , true ) => DbDataType :: Int32 ,
163
+ ( ColumnType :: MYSQL_TYPE_INT24 , false ) => DbDataType :: Uint32 ,
140
164
( ColumnType :: MYSQL_TYPE_LONG , true ) => DbDataType :: Int32 ,
141
165
( ColumnType :: MYSQL_TYPE_LONG , false ) => DbDataType :: Uint32 ,
142
166
( ColumnType :: MYSQL_TYPE_LONGLONG , true ) => DbDataType :: Int64 ,
143
167
( ColumnType :: MYSQL_TYPE_LONGLONG , false ) => DbDataType :: Uint64 ,
144
- ( ColumnType :: MYSQL_TYPE_LONG_BLOB , _) => DbDataType :: Binary ,
145
- ( ColumnType :: MYSQL_TYPE_MEDIUM_BLOB , _) => DbDataType :: Binary ,
146
168
( ColumnType :: MYSQL_TYPE_SHORT , true ) => DbDataType :: Int16 ,
147
169
( ColumnType :: MYSQL_TYPE_SHORT , false ) => DbDataType :: Uint16 ,
148
- ( ColumnType :: MYSQL_TYPE_STRING , _) => DbDataType :: Str ,
149
170
( ColumnType :: MYSQL_TYPE_TINY , true ) => DbDataType :: Int8 ,
150
171
( ColumnType :: MYSQL_TYPE_TINY , false ) => DbDataType :: Uint8 ,
151
- ( ColumnType :: MYSQL_TYPE_VARCHAR , _) => DbDataType :: Str ,
152
- ( ColumnType :: MYSQL_TYPE_VAR_STRING , _) => DbDataType :: Str ,
153
172
( _, _) => DbDataType :: Other ,
154
173
}
155
174
}
@@ -160,6 +179,12 @@ fn is_signed(column: &mysql_async::Column) -> bool {
160
179
. contains ( mysql_async:: consts:: ColumnFlags :: UNSIGNED_FLAG )
161
180
}
162
181
182
+ fn is_binary ( column : & mysql_async:: Column ) -> bool {
183
+ column
184
+ . flags ( )
185
+ . contains ( mysql_async:: consts:: ColumnFlags :: BINARY_FLAG )
186
+ }
187
+
163
188
fn convert_row ( mut row : mysql_async:: Row , columns : & [ Column ] ) -> Result < Vec < DbValue > , MysqlError > {
164
189
let mut result = Vec :: with_capacity ( row. len ( ) ) ;
165
190
for index in 0 ..row. len ( ) {
@@ -206,6 +231,24 @@ fn convert_value(value: mysql_async::Value, column: &Column) -> Result<DbValue,
206
231
}
207
232
}
208
233
234
+ impl OutboundMysql {
235
+ async fn get_conn ( & mut self , address : & str ) -> anyhow:: Result < & mut mysql_async:: Conn > {
236
+ let client = match self . connections . entry ( address. to_owned ( ) ) {
237
+ std:: collections:: hash_map:: Entry :: Occupied ( o) => o. into_mut ( ) ,
238
+ std:: collections:: hash_map:: Entry :: Vacant ( v) => v. insert ( build_conn ( address) . await ?) ,
239
+ } ;
240
+ Ok ( client)
241
+ }
242
+ }
243
+
244
+ async fn build_conn ( address : & str ) -> Result < mysql_async:: Conn , mysql_async:: Error > {
245
+ tracing:: log:: debug!( "Build new connection: {}" , address) ;
246
+
247
+ let connection_pool = mysql_async:: Pool :: new ( address) ;
248
+
249
+ connection_pool. get_conn ( ) . await
250
+ }
251
+
209
252
fn convert_value_to < T : FromValue > ( value : mysql_async:: Value ) -> Result < T , MysqlError > {
210
253
from_value_opt :: < T > ( value) . map_err ( |e| MysqlError :: ValueConversionFailed ( format ! ( "{}" , e) ) )
211
254
}
0 commit comments