Skip to content

Commit 1dbc15e

Browse files
committed
Allow for more CORS configuration
Motivation: We added some level of CORS configuration support in #1583. This change adds further flexibility. Modifications: - Add an 'originBased' mode where the value of the origin header is returned in the response head. - Add a custom fallback where the user can specify a callback which is passed the value of the origin header and returns the value to return in the 'access-control-allow-origin' response header (or nil, if the origin is not allowed). Result: More flexibility for CORS.
1 parent 75b390e commit 1dbc15e

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

Diff for: Sources/GRPC/Server.swift

+51
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ extension Server.Configuration.CORS {
479479
public struct AllowedOrigins: Hashable, Sendable {
480480
enum Wrapped: Hashable, Sendable {
481481
case all
482+
case originBased
482483
case only([String])
484+
case custom(AnyCustomCORSAllowedOrigin)
483485
}
484486

485487
private(set) var wrapped: Wrapped
@@ -490,10 +492,23 @@ extension Server.Configuration.CORS {
490492
/// Allow all origin values.
491493
public static let all = Self(.all)
492494

495+
/// Allow all origin values; similar to `all` but returns the value of the origin header field
496+
/// in the 'access-control-allow-origin' response header (rather than "*").
497+
public static let originBased = Self(.originBased)
498+
493499
/// Allow only the given origin values.
494500
public static func only(_ allowed: [String]) -> Self {
495501
return Self(.only(allowed))
496502
}
503+
504+
/// Provide a custom CORS origin check.
505+
///
506+
/// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
507+
/// and returns the value to use in the 'access-control-allow-origin' response header,
508+
/// or `nil` if the origin is not allowed.
509+
public static func custom<C: GRPCCustomCORSAllowedOrigin>(_ custom: C) -> Self {
510+
return Self(.custom(AnyCustomCORSAllowedOrigin(custom)))
511+
}
497512
}
498513
}
499514

@@ -520,3 +535,39 @@ extension Comparable {
520535
return min(max(self, range.lowerBound), range.upperBound)
521536
}
522537
}
538+
539+
public protocol GRPCCustomCORSAllowedOrigin: Sendable, Hashable {
540+
/// Returns the value to use for the 'access-control-allow-origin' response header for the given
541+
/// value of the 'origin' request header.
542+
///
543+
/// - Parameter origin: The value of the 'origin' request header field.
544+
/// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
545+
/// CORS related headers should be returned.
546+
func check(origin: String) -> String?
547+
}
548+
549+
extension Server.Configuration.CORS.AllowedOrigins {
550+
struct AnyCustomCORSAllowedOrigin: GRPCCustomCORSAllowedOrigin {
551+
private var checkOrigin: @Sendable (String) -> String?
552+
private let hashInto: @Sendable (inout Hasher) -> Void
553+
private let isEqualTo: @Sendable (any GRPCCustomCORSAllowedOrigin) -> Bool
554+
555+
init<W: GRPCCustomCORSAllowedOrigin>(_ wrap: W) {
556+
self.checkOrigin = { wrap.check(origin: $0) }
557+
self.hashInto = { wrap.hash(into: &$0) }
558+
self.isEqualTo = { wrap == ($0 as? W) }
559+
}
560+
561+
func check(origin: String) -> String? {
562+
return self.checkOrigin(origin)
563+
}
564+
565+
func hash(into hasher: inout Hasher) {
566+
self.hashInto(&hasher)
567+
}
568+
569+
static func == (lhs: Self, rhs: Self) -> Bool {
570+
return lhs.isEqualTo(rhs)
571+
}
572+
}
573+
}

Diff for: Sources/GRPC/WebCORSHandler.swift

+4
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,12 @@ extension Server.Configuration.CORS.AllowedOrigins {
198198
switch self.wrapped {
199199
case .all:
200200
return "*"
201+
case .originBased:
202+
return origin
201203
case let .only(allowed):
202204
return allowed.contains(origin) ? origin : nil
205+
case let .custom(custom):
206+
return custom.check(origin: origin)
203207
}
204208
}
205209
}

Diff for: Tests/GRPCTests/WebCORSHandlerTests.swift

+44
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,50 @@ internal final class WebCORSHandlerTests: XCTestCase {
9393
try self.runPreflightRequestTest(spec: spec)
9494
}
9595

96+
func testOptionsPreflightOriginBased() throws {
97+
let spec = PreflightRequestSpec(
98+
configuration: .init(
99+
allowedOrigins: .originBased,
100+
allowedHeaders: ["x-grpc-web"],
101+
allowCredentialedRequests: false,
102+
preflightCacheExpiration: 60
103+
),
104+
requestOrigin: "foo",
105+
expectOrigin: "foo",
106+
expectAllowedHeaders: ["x-grpc-web"],
107+
expectAllowCredentials: false,
108+
expectMaxAge: "60"
109+
)
110+
try self.runPreflightRequestTest(spec: spec)
111+
}
112+
113+
func testOptionsPreflightCustom() throws {
114+
struct Wrapper: GRPCCustomCORSAllowedOrigin {
115+
func check(origin: String) -> String? {
116+
if origin == "foo" {
117+
return "bar"
118+
} else {
119+
return nil
120+
}
121+
}
122+
}
123+
124+
let spec = PreflightRequestSpec(
125+
configuration: .init(
126+
allowedOrigins: .custom(Wrapper()),
127+
allowedHeaders: ["x-grpc-web"],
128+
allowCredentialedRequests: false,
129+
preflightCacheExpiration: 60
130+
),
131+
requestOrigin: "foo",
132+
expectOrigin: "bar",
133+
expectAllowedHeaders: ["x-grpc-web"],
134+
expectAllowCredentials: false,
135+
expectMaxAge: "60"
136+
)
137+
try self.runPreflightRequestTest(spec: spec)
138+
}
139+
96140
func testOptionsPreflightAllowSomeOrigins() throws {
97141
let spec = PreflightRequestSpec(
98142
configuration: .init(

0 commit comments

Comments
 (0)