Skip to content

Commit c9027f8

Browse files
committed
Allow for more CORS configuration
Motivation: We added some level of CORS configuration support in grpc#1583. This change adds further flexibility. Modifications: - Add an 'originBased' mode where the value of the origin header is returned in the response head. - Add a custom fallback where the user can specify a callback which is passed the value of the origin header and returns the value to return in the 'access-control-allow-origin' response header (or nil, if the origin is not allowed). Result: More flexibility for CORS.
1 parent 75b390e commit c9027f8

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

Diff for: Sources/GRPC/Server.swift

+52
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,44 @@ extension Server.Configuration.CORS {
479479
public struct AllowedOrigins: Hashable, Sendable {
480480
enum Wrapped: Hashable, Sendable {
481481
case all
482+
case originBased
482483
case only([String])
484+
case custom(@Sendable (String) -> String?)
485+
486+
// Mistakes were made! CORS and AllowedOrigins were made 'Hashable' before the 'custom'
487+
// option was added. Removing the conformance is an API breaking change so we have to
488+
// provide a best effort implementation for Hashable and Eqatable.
489+
490+
static func == (lhs: Self, rhs: Self) -> Bool {
491+
switch (lhs, rhs) {
492+
case (.all, .all),
493+
(.originBased, .originBased):
494+
return true
495+
case let (.only(onlyLHS), .only(onlyRHS)):
496+
return onlyLHS == onlyRHS
497+
case (.custom, .custom):
498+
return true
499+
case (.all, _),
500+
(.originBased, _),
501+
(.only, _),
502+
(.custom, _):
503+
return false
504+
}
505+
}
506+
507+
func hash(into hasher: inout Hasher) {
508+
switch self {
509+
case .all:
510+
hasher.combine(0)
511+
case .originBased:
512+
hasher.combine(1)
513+
case let .only(only):
514+
hasher.combine(2)
515+
hasher.combine(only)
516+
case .custom:
517+
hasher.combine(4)
518+
}
519+
}
483520
}
484521

485522
private(set) var wrapped: Wrapped
@@ -490,10 +527,25 @@ extension Server.Configuration.CORS {
490527
/// Allow all origin values.
491528
public static let all = Self(.all)
492529

530+
/// Allow all origin values; similar to `all` but returns the value of the origin header field
531+
/// in the 'access-control-allow-origin' response header (rather than "*").
532+
public static let originBased = Self(.originBased)
533+
493534
/// Allow only the given origin values.
494535
public static func only(_ allowed: [String]) -> Self {
495536
return Self(.only(allowed))
496537
}
538+
539+
/// Provide a custom CORS origin check.
540+
///
541+
/// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
542+
/// and returns the value to use in the 'access-control-allow-origin' response header,
543+
/// or `nil` if the origin is not allowed.
544+
public static func custom(
545+
_ checkOrigin: @escaping @Sendable (_ origin: String) -> String?
546+
) -> Self {
547+
return Self(.custom(checkOrigin))
548+
}
497549
}
498550
}
499551

Diff for: Sources/GRPC/WebCORSHandler.swift

+4
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,12 @@ extension Server.Configuration.CORS.AllowedOrigins {
198198
switch self.wrapped {
199199
case .all:
200200
return "*"
201+
case .originBased:
202+
return origin
201203
case let .only(allowed):
202204
return allowed.contains(origin) ? origin : nil
205+
case let .custom(custom):
206+
return custom(origin)
203207
}
204208
}
205209
}

Diff for: Tests/GRPCTests/WebCORSHandlerTests.swift

+40
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,46 @@ internal final class WebCORSHandlerTests: XCTestCase {
9393
try self.runPreflightRequestTest(spec: spec)
9494
}
9595

96+
func testOptionsPreflightOriginBased() throws {
97+
let spec = PreflightRequestSpec(
98+
configuration: .init(
99+
allowedOrigins: .originBased,
100+
allowedHeaders: ["x-grpc-web"],
101+
allowCredentialedRequests: false,
102+
preflightCacheExpiration: 60
103+
),
104+
requestOrigin: "foo",
105+
expectOrigin: "foo",
106+
expectAllowedHeaders: ["x-grpc-web"],
107+
expectAllowCredentials: false,
108+
expectMaxAge: "60"
109+
)
110+
try self.runPreflightRequestTest(spec: spec)
111+
}
112+
113+
func testOptionsPreflightCustom() throws {
114+
let spec = PreflightRequestSpec(
115+
configuration: .init(
116+
allowedOrigins: .custom { origin in
117+
if origin == "foo" {
118+
return "bar"
119+
} else {
120+
return nil
121+
}
122+
},
123+
allowedHeaders: ["x-grpc-web"],
124+
allowCredentialedRequests: false,
125+
preflightCacheExpiration: 60
126+
),
127+
requestOrigin: "foo",
128+
expectOrigin: "bar",
129+
expectAllowedHeaders: ["x-grpc-web"],
130+
expectAllowCredentials: false,
131+
expectMaxAge: "60"
132+
)
133+
try self.runPreflightRequestTest(spec: spec)
134+
}
135+
96136
func testOptionsPreflightAllowSomeOrigins() throws {
97137
let spec = PreflightRequestSpec(
98138
configuration: .init(

0 commit comments

Comments
 (0)