Skip to content

Commit ce64704

Browse files
glbrnttWendellXY
authored andcommitted
Allow for more CORS configuration (grpc#1594)
Motivation: We added some level of CORS configuration support in grpc#1583. This change adds further flexibility. Modifications: - Add an 'originBased' mode where the value of the origin header is returned in the response head. - Add a custom fallback where the user can specify a callback which is passed the value of the origin header and returns the value to return in the 'access-control-allow-origin' response header (or nil, if the origin is not allowed). Result: More flexibility for CORS.
1 parent 1b93013 commit ce64704

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

Sources/GRPC/Server.swift

+58
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ extension Server.Configuration.CORS {
489489
public struct AllowedOrigins: Hashable, Sendable {
490490
enum Wrapped: Hashable, Sendable {
491491
case all
492+
case originBased
492493
case only([String])
494+
case custom(AnyCustomCORSAllowedOrigin)
493495
}
494496

495497
private(set) var wrapped: Wrapped
@@ -500,10 +502,23 @@ extension Server.Configuration.CORS {
500502
/// Allow all origin values.
501503
public static let all = Self(.all)
502504

505+
/// Allow all origin values; similar to `all` but returns the value of the origin header field
506+
/// in the 'access-control-allow-origin' response header (rather than "*").
507+
public static let originBased = Self(.originBased)
508+
503509
/// Allow only the given origin values.
504510
public static func only(_ allowed: [String]) -> Self {
505511
return Self(.only(allowed))
506512
}
513+
514+
/// Provide a custom CORS origin check.
515+
///
516+
/// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
517+
/// and returns the value to use in the 'access-control-allow-origin' response header,
518+
/// or `nil` if the origin is not allowed.
519+
public static func custom<C: GRPCCustomCORSAllowedOrigin>(_ custom: C) -> Self {
520+
return Self(.custom(AnyCustomCORSAllowedOrigin(custom)))
521+
}
507522
}
508523
}
509524

@@ -530,3 +545,46 @@ extension Comparable {
530545
return min(max(self, range.lowerBound), range.upperBound)
531546
}
532547
}
548+
549+
public protocol GRPCCustomCORSAllowedOrigin: Sendable, Hashable {
550+
/// Returns the value to use for the 'access-control-allow-origin' response header for the given
551+
/// value of the 'origin' request header.
552+
///
553+
/// - Parameter origin: The value of the 'origin' request header field.
554+
/// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
555+
/// CORS related headers should be returned.
556+
func check(origin: String) -> String?
557+
}
558+
559+
extension Server.Configuration.CORS.AllowedOrigins {
560+
struct AnyCustomCORSAllowedOrigin: GRPCCustomCORSAllowedOrigin {
561+
private var checkOrigin: @Sendable (String) -> String?
562+
private let hashInto: @Sendable (inout Hasher) -> Void
563+
#if swift(>=5.7)
564+
private let isEqualTo: @Sendable (any GRPCCustomCORSAllowedOrigin) -> Bool
565+
#else
566+
private let isEqualTo: @Sendable (Any) -> Bool
567+
#endif
568+
569+
init<W: GRPCCustomCORSAllowedOrigin>(_ wrap: W) {
570+
self.checkOrigin = { wrap.check(origin: $0) }
571+
self.hashInto = { wrap.hash(into: &$0) }
572+
self.isEqualTo = { wrap == ($0 as? W) }
573+
}
574+
575+
func check(origin: String) -> String? {
576+
return self.checkOrigin(origin)
577+
}
578+
579+
func hash(into hasher: inout Hasher) {
580+
self.hashInto(&hasher)
581+
}
582+
583+
static func == (
584+
lhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin,
585+
rhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin
586+
) -> Bool {
587+
return lhs.isEqualTo(rhs)
588+
}
589+
}
590+
}

Sources/GRPC/WebCORSHandler.swift

+4
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,12 @@ extension Server.Configuration.CORS.AllowedOrigins {
198198
switch self.wrapped {
199199
case .all:
200200
return "*"
201+
case .originBased:
202+
return origin
201203
case let .only(allowed):
202204
return allowed.contains(origin) ? origin : nil
205+
case let .custom(custom):
206+
return custom.check(origin: origin)
203207
}
204208
}
205209
}

Tests/GRPCTests/WebCORSHandlerTests.swift

+44
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,50 @@ internal final class WebCORSHandlerTests: XCTestCase {
9393
try self.runPreflightRequestTest(spec: spec)
9494
}
9595

96+
func testOptionsPreflightOriginBased() throws {
97+
let spec = PreflightRequestSpec(
98+
configuration: .init(
99+
allowedOrigins: .originBased,
100+
allowedHeaders: ["x-grpc-web"],
101+
allowCredentialedRequests: false,
102+
preflightCacheExpiration: 60
103+
),
104+
requestOrigin: "foo",
105+
expectOrigin: "foo",
106+
expectAllowedHeaders: ["x-grpc-web"],
107+
expectAllowCredentials: false,
108+
expectMaxAge: "60"
109+
)
110+
try self.runPreflightRequestTest(spec: spec)
111+
}
112+
113+
func testOptionsPreflightCustom() throws {
114+
struct Wrapper: GRPCCustomCORSAllowedOrigin {
115+
func check(origin: String) -> String? {
116+
if origin == "foo" {
117+
return "bar"
118+
} else {
119+
return nil
120+
}
121+
}
122+
}
123+
124+
let spec = PreflightRequestSpec(
125+
configuration: .init(
126+
allowedOrigins: .custom(Wrapper()),
127+
allowedHeaders: ["x-grpc-web"],
128+
allowCredentialedRequests: false,
129+
preflightCacheExpiration: 60
130+
),
131+
requestOrigin: "foo",
132+
expectOrigin: "bar",
133+
expectAllowedHeaders: ["x-grpc-web"],
134+
expectAllowCredentials: false,
135+
expectMaxAge: "60"
136+
)
137+
try self.runPreflightRequestTest(spec: spec)
138+
}
139+
96140
func testOptionsPreflightAllowSomeOrigins() throws {
97141
let spec = PreflightRequestSpec(
98142
configuration: .init(

0 commit comments

Comments
 (0)