Skip to content

Commit 20fe435

Browse files
authored
Merge pull request http-rs#879 from marcoslopes/main
Allow cors origin to be dynamically matched
2 parents 9cca13f + 9927041 commit 20fe435

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pin-project-lite = "0.2.0"
4949
serde = "1.0.117"
5050
serde_json = "1.0.59"
5151
routefinder = "0.5.0"
52+
regex = "1.5.5"
5253

5354
[dev-dependencies]
5455
async-std = { version = "1.6.5", features = ["unstable", "attributes"] }

src/security/cors.rs

+81-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use http_types::headers::{HeaderValue, HeaderValues};
22
use http_types::{headers, Method, StatusCode};
3+
use regex::Regex;
4+
use std::hash::Hash;
35

46
use crate::middleware::{Middleware, Next};
57
use crate::{Request, Result};
@@ -128,6 +130,7 @@ impl CorsMiddleware {
128130
Origin::Any => true,
129131
Origin::Exact(s) => s == &origin,
130132
Origin::List(list) => list.contains(&origin),
133+
Origin::Match(regex) => regex.is_match(&origin),
131134
}
132135
}
133136
}
@@ -187,14 +190,16 @@ impl Default for CorsMiddleware {
187190
}
188191

189192
/// `allow_origin` enum
190-
#[derive(Clone, Debug, Hash, PartialEq)]
193+
#[derive(Clone, Debug)]
191194
pub enum Origin {
192195
/// Wildcard. Accept all origin requests
193196
Any,
194197
/// Set a single allow_origin target
195198
Exact(String),
196199
/// Set multiple allow_origin targets
197200
List(Vec<String>),
201+
/// Set a regex allow_origin targets
202+
Match(Regex),
198203
}
199204

200205
impl From<String> for Origin {
@@ -222,6 +227,12 @@ impl From<Vec<String>> for Origin {
222227
}
223228
}
224229

230+
impl From<Regex> for Origin {
231+
fn from(regex: Regex) -> Self {
232+
Self::Match(regex)
233+
}
234+
}
235+
225236
impl From<Vec<&str>> for Origin {
226237
fn from(list: Vec<&str>) -> Self {
227238
Self::from(
@@ -232,6 +243,28 @@ impl From<Vec<&str>> for Origin {
232243
}
233244
}
234245

246+
impl PartialEq for Origin {
247+
fn eq(&self, other: &Self) -> bool {
248+
match (self, other) {
249+
(Self::Exact(this), Self::Exact(other)) => this == other,
250+
(Self::List(this), Self::List(other)) => this == other,
251+
(Self::Match(this), Self::Match(other)) => this.to_string() == other.to_string(),
252+
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
253+
}
254+
}
255+
}
256+
257+
impl Hash for Origin {
258+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
259+
match self {
260+
Self::Any => core::mem::discriminant(self).hash(state),
261+
Self::Exact(s) => s.hash(state),
262+
Self::List(list) => list.hash(state),
263+
Self::Match(regex) => regex.to_string().hash(state),
264+
}
265+
}
266+
}
267+
235268
#[cfg(test)]
236269
mod test {
237270
use super::*;
@@ -313,6 +346,23 @@ mod test {
313346
assert_eq!(res[headers::ACCESS_CONTROL_ALLOW_ORIGIN], ALLOW_ORIGIN);
314347
}
315348

349+
#[async_std::test]
350+
async fn regex_cors_middleware() {
351+
let regex = Regex::new(r"e[xzs]a.*le.com*").unwrap();
352+
let mut app = app();
353+
app.with(
354+
CorsMiddleware::new()
355+
.allow_origin(Origin::from(regex))
356+
.allow_credentials(false)
357+
.allow_methods(ALLOW_METHODS.parse::<HeaderValue>().unwrap())
358+
.expose_headers(EXPOSE_HEADER.parse::<HeaderValue>().unwrap()),
359+
);
360+
let res: crate::http::Response = app.respond(request()).await.unwrap();
361+
362+
assert_eq!(res.status(), 200);
363+
assert_eq!(res[headers::ACCESS_CONTROL_ALLOW_ORIGIN], ALLOW_ORIGIN);
364+
}
365+
316366
#[async_std::test]
317367
async fn credentials_true() {
318368
let mut app = app();
@@ -396,4 +446,34 @@ mod test {
396446
assert_eq!(res.status(), 400);
397447
assert_eq!(res[headers::ACCESS_CONTROL_ALLOW_ORIGIN], ALLOW_ORIGIN);
398448
}
449+
450+
#[cfg(test)]
451+
mod origin {
452+
use super::super::Origin;
453+
use regex::Regex;
454+
455+
#[test]
456+
fn transitive() {
457+
let regex = Regex::new(r"e[xzs]a.*le.com*").unwrap();
458+
let x = Origin::from(regex.clone());
459+
let y = Origin::from(regex.clone());
460+
let z = Origin::from(regex);
461+
assert!(x == y && y == z && x == z);
462+
}
463+
464+
#[test]
465+
fn symetrical() {
466+
let regex = Regex::new(r"e[xzs]a.*le.com*").unwrap();
467+
let x = Origin::from(regex.clone());
468+
let y = Origin::from(regex);
469+
assert!(x == y && y == x);
470+
}
471+
472+
#[test]
473+
fn reflexive() {
474+
let regex = Regex::new(r"e[xzs]a.*le.com*").unwrap();
475+
let x = Origin::from(regex);
476+
assert!(x == x);
477+
}
478+
}
399479
}

0 commit comments

Comments
 (0)