1
1
use http_types:: headers:: { HeaderValue , HeaderValues } ;
2
2
use http_types:: { headers, Method , StatusCode } ;
3
+ use regex:: Regex ;
4
+ use std:: hash:: Hash ;
3
5
4
6
use crate :: middleware:: { Middleware , Next } ;
5
7
use crate :: { Request , Result } ;
@@ -128,6 +130,7 @@ impl CorsMiddleware {
128
130
Origin :: Any => true ,
129
131
Origin :: Exact ( s) => s == & origin,
130
132
Origin :: List ( list) => list. contains ( & origin) ,
133
+ Origin :: Match ( regex) => regex. is_match ( & origin) ,
131
134
}
132
135
}
133
136
}
@@ -187,14 +190,16 @@ impl Default for CorsMiddleware {
187
190
}
188
191
189
192
/// `allow_origin` enum
190
- #[ derive( Clone , Debug , Hash , PartialEq ) ]
193
+ #[ derive( Clone , Debug ) ]
191
194
pub enum Origin {
192
195
/// Wildcard. Accept all origin requests
193
196
Any ,
194
197
/// Set a single allow_origin target
195
198
Exact ( String ) ,
196
199
/// Set multiple allow_origin targets
197
200
List ( Vec < String > ) ,
201
+ /// Set a regex allow_origin targets
202
+ Match ( Regex ) ,
198
203
}
199
204
200
205
impl From < String > for Origin {
@@ -222,6 +227,12 @@ impl From<Vec<String>> for Origin {
222
227
}
223
228
}
224
229
230
+ impl From < Regex > for Origin {
231
+ fn from ( regex : Regex ) -> Self {
232
+ Self :: Match ( regex)
233
+ }
234
+ }
235
+
225
236
impl From < Vec < & str > > for Origin {
226
237
fn from ( list : Vec < & str > ) -> Self {
227
238
Self :: from (
@@ -232,6 +243,28 @@ impl From<Vec<&str>> for Origin {
232
243
}
233
244
}
234
245
246
+ impl PartialEq for Origin {
247
+ fn eq ( & self , other : & Self ) -> bool {
248
+ match ( self , other) {
249
+ ( Self :: Exact ( this) , Self :: Exact ( other) ) => this == other,
250
+ ( Self :: List ( this) , Self :: List ( other) ) => this == other,
251
+ ( Self :: Match ( this) , Self :: Match ( other) ) => this. to_string ( ) == other. to_string ( ) ,
252
+ _ => core:: mem:: discriminant ( self ) == core:: mem:: discriminant ( other) ,
253
+ }
254
+ }
255
+ }
256
+
257
+ impl Hash for Origin {
258
+ fn hash < H : std:: hash:: Hasher > ( & self , state : & mut H ) {
259
+ match self {
260
+ Self :: Any => core:: mem:: discriminant ( self ) . hash ( state) ,
261
+ Self :: Exact ( s) => s. hash ( state) ,
262
+ Self :: List ( list) => list. hash ( state) ,
263
+ Self :: Match ( regex) => regex. to_string ( ) . hash ( state) ,
264
+ }
265
+ }
266
+ }
267
+
235
268
#[ cfg( test) ]
236
269
mod test {
237
270
use super :: * ;
@@ -313,6 +346,23 @@ mod test {
313
346
assert_eq ! ( res[ headers:: ACCESS_CONTROL_ALLOW_ORIGIN ] , ALLOW_ORIGIN ) ;
314
347
}
315
348
349
+ #[ async_std:: test]
350
+ async fn regex_cors_middleware ( ) {
351
+ let regex = Regex :: new ( r"e[xzs]a.*le.com*" ) . unwrap ( ) ;
352
+ let mut app = app ( ) ;
353
+ app. with (
354
+ CorsMiddleware :: new ( )
355
+ . allow_origin ( Origin :: from ( regex) )
356
+ . allow_credentials ( false )
357
+ . allow_methods ( ALLOW_METHODS . parse :: < HeaderValue > ( ) . unwrap ( ) )
358
+ . expose_headers ( EXPOSE_HEADER . parse :: < HeaderValue > ( ) . unwrap ( ) ) ,
359
+ ) ;
360
+ let res: crate :: http:: Response = app. respond ( request ( ) ) . await . unwrap ( ) ;
361
+
362
+ assert_eq ! ( res. status( ) , 200 ) ;
363
+ assert_eq ! ( res[ headers:: ACCESS_CONTROL_ALLOW_ORIGIN ] , ALLOW_ORIGIN ) ;
364
+ }
365
+
316
366
#[ async_std:: test]
317
367
async fn credentials_true ( ) {
318
368
let mut app = app ( ) ;
@@ -396,4 +446,34 @@ mod test {
396
446
assert_eq ! ( res. status( ) , 400 ) ;
397
447
assert_eq ! ( res[ headers:: ACCESS_CONTROL_ALLOW_ORIGIN ] , ALLOW_ORIGIN ) ;
398
448
}
449
+
450
+ #[ cfg( test) ]
451
+ mod origin {
452
+ use super :: super :: Origin ;
453
+ use regex:: Regex ;
454
+
455
+ #[ test]
456
+ fn transitive ( ) {
457
+ let regex = Regex :: new ( r"e[xzs]a.*le.com*" ) . unwrap ( ) ;
458
+ let x = Origin :: from ( regex. clone ( ) ) ;
459
+ let y = Origin :: from ( regex. clone ( ) ) ;
460
+ let z = Origin :: from ( regex) ;
461
+ assert ! ( x == y && y == z && x == z) ;
462
+ }
463
+
464
+ #[ test]
465
+ fn symetrical ( ) {
466
+ let regex = Regex :: new ( r"e[xzs]a.*le.com*" ) . unwrap ( ) ;
467
+ let x = Origin :: from ( regex. clone ( ) ) ;
468
+ let y = Origin :: from ( regex) ;
469
+ assert ! ( x == y && y == x) ;
470
+ }
471
+
472
+ #[ test]
473
+ fn reflexive ( ) {
474
+ let regex = Regex :: new ( r"e[xzs]a.*le.com*" ) . unwrap ( ) ;
475
+ let x = Origin :: from ( regex) ;
476
+ assert ! ( x == x) ;
477
+ }
478
+ }
399
479
}
0 commit comments