Skip to content

Commit d8fe615

Browse files
glbrnttLukasa
authored andcommitted
Allow CORS to be configured for gRPC Web (grpc#1583)
Motivation: The WebCORS handler unconditionally sets "Access-Control-Allow-Origin" to "*" in response headers regardless of whether the request is a CORS request or whether the client sends credentials. Moreover we don't expose any knobs to control how CORS is configured. Modifications: - Add CORS configuration to the server and server builder - Let the allowed origins be '.any' (i.e. '*") or '.only' (limited to the provided origins) - Let the user configure what headers are permitted in responses. - Let the user configure whether credentialed requests are accepted. Result: More control over CORS Co-authored-by: Cory Benfield <[email protected]>
1 parent 23575f0 commit d8fe615

File tree

5 files changed

+559
-59
lines changed

5 files changed

+559
-59
lines changed

Sources/GRPC/GRPCServerPipelineConfigurator.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
191191
// we'll be on the right event loop and sync operations are fine.
192192
let sync = context.pipeline.syncOperations
193193
try sync.configureHTTPServerPipeline(withErrorHandling: true)
194-
try sync.addHandler(WebCORSHandler())
194+
try sync.addHandler(WebCORSHandler(configuration: self.configuration.webCORS))
195195
let scheme = self.configuration.tlsConfiguration == nil ? "http" : "https"
196196
try sync.addHandler(GRPCWebToHTTP2ServerCodec(scheme: scheme))
197197
// There's no need to normalize headers for HTTP/1.

Sources/GRPC/Server.swift

+55-5
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ extension Server {
367367
/// the need to recalculate this dictionary each time we receive an rpc.
368368
internal var serviceProvidersByName: [Substring: CallHandlerProvider]
369369

370+
/// CORS configuration for gRPC-Web support.
371+
public var webCORS = Configuration.CORS()
372+
370373
#if canImport(NIOSSL)
371374
/// Create a `Configuration` with some pre-defined defaults.
372375
///
@@ -401,11 +404,9 @@ extension Server {
401404
) {
402405
self.target = target
403406
self.eventLoopGroup = eventLoopGroup
404-
self
405-
.serviceProvidersByName = Dictionary(
406-
uniqueKeysWithValues: serviceProviders
407-
.map { ($0.serviceName, $0) }
408-
)
407+
self.serviceProvidersByName = Dictionary(
408+
uniqueKeysWithValues: serviceProviders.map { ($0.serviceName, $0) }
409+
)
409410
self.errorDelegate = errorDelegate
410411
self.tlsConfiguration = tls.map { GRPCTLSConfiguration(transforming: $0) }
411412
self.connectionKeepalive = connectionKeepalive
@@ -451,6 +452,55 @@ extension Server {
451452
}
452453
}
453454

455+
extension Server.Configuration {
456+
public struct CORS: Hashable, GRPCSendable {
457+
/// Determines which 'origin' header field values are permitted in a CORS request.
458+
public var allowedOrigins: AllowedOrigins
459+
/// Sets the headers which are permitted in a response to a CORS request.
460+
public var allowedHeaders: [String]
461+
/// Enabling this value allows sets the "access-control-allow-credentials" header field
462+
/// to "true" in respones to CORS requests. This must be enabled if the client intends to send
463+
/// credentials.
464+
public var allowCredentialedRequests: Bool
465+
/// The maximum age in seconds which pre-flight CORS requests may be cached for.
466+
public var preflightCacheExpiration: Int
467+
468+
public init(
469+
allowedOrigins: AllowedOrigins = .all,
470+
allowedHeaders: [String] = ["content-type", "x-grpc-web", "x-user-agent"],
471+
allowCredentialedRequests: Bool = false,
472+
preflightCacheExpiration: Int = 86400
473+
) {
474+
self.allowedOrigins = allowedOrigins
475+
self.allowedHeaders = allowedHeaders
476+
self.allowCredentialedRequests = allowCredentialedRequests
477+
self.preflightCacheExpiration = preflightCacheExpiration
478+
}
479+
}
480+
}
481+
482+
extension Server.Configuration.CORS {
483+
public struct AllowedOrigins: Hashable, Sendable {
484+
enum Wrapped: Hashable, Sendable {
485+
case all
486+
case only([String])
487+
}
488+
489+
private(set) var wrapped: Wrapped
490+
private init(_ wrapped: Wrapped) {
491+
self.wrapped = wrapped
492+
}
493+
494+
/// Allow all origin values.
495+
public static let all = Self(.all)
496+
497+
/// Allow only the given origin values.
498+
public static func only(_ allowed: [String]) -> Self {
499+
return Self(.only(allowed))
500+
}
501+
}
502+
}
503+
454504
extension ServerBootstrapProtocol {
455505
fileprivate func bind(to target: BindTarget) -> EventLoopFuture<Channel> {
456506
switch target.wrapped {

Sources/GRPC/ServerBuilder.swift

+9
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ extension Server.Builder {
165165
}
166166
}
167167

168+
extension Server.Builder {
169+
/// Set the CORS configuration for gRPC Web.
170+
@discardableResult
171+
public func withCORSConfiguration(_ configuration: Server.Configuration.CORS) -> Self {
172+
self.configuration.webCORS = configuration
173+
return self
174+
}
175+
}
176+
168177
extension Server.Builder {
169178
/// Sets the root server logger. Accepted connections will branch from this logger and RPCs on
170179
/// each connection will use a logger branched from the connections logger. This logger is made

Sources/GRPC/WebCORSHandler.swift

+160-53
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,110 @@ import NIOCore
1717
import NIOHTTP1
1818

1919
/// Handler that manages the CORS protocol for requests incoming from the browser.
20-
internal class WebCORSHandler {
21-
var requestMethod: HTTPMethod?
20+
internal final class WebCORSHandler {
21+
let configuration: Server.Configuration.CORS
22+
23+
private var state: State = .idle
24+
private enum State: Equatable {
25+
/// Starting state.
26+
case idle
27+
/// CORS preflight request is in progress.
28+
case processingPreflightRequest
29+
/// "Real" request is in progress.
30+
case processingRequest(origin: String?)
31+
}
32+
33+
init(configuration: Server.Configuration.CORS) {
34+
self.configuration = configuration
35+
}
2236
}
2337

2438
extension WebCORSHandler: ChannelInboundHandler {
2539
typealias InboundIn = HTTPServerRequestPart
40+
typealias InboundOut = HTTPServerRequestPart
2641
typealias OutboundOut = HTTPServerResponsePart
2742

2843
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
29-
// If the request is OPTIONS, the request is not propagated further.
3044
switch self.unwrapInboundIn(data) {
31-
case let .head(requestHead):
32-
self.requestMethod = requestHead.method
33-
if self.requestMethod == .OPTIONS {
34-
var headers = HTTPHeaders()
35-
headers.add(name: "Access-Control-Allow-Origin", value: "*")
36-
headers.add(name: "Access-Control-Allow-Methods", value: "POST")
37-
headers.add(
38-
name: "Access-Control-Allow-Headers",
39-
value: "content-type,x-grpc-web,x-user-agent"
40-
)
41-
headers.add(name: "Access-Control-Max-Age", value: "86400")
42-
context.write(
43-
self.wrapOutboundOut(.head(HTTPResponseHead(
44-
version: requestHead.version,
45-
status: .ok,
46-
headers: headers
47-
))),
48-
promise: nil
49-
)
50-
return
45+
case let .head(head):
46+
self.receivedRequestHead(context: context, head)
47+
48+
case let .body(body):
49+
self.receivedRequestBody(context: context, body)
50+
51+
case let .end(trailers):
52+
self.receivedRequestEnd(context: context, trailers)
53+
}
54+
}
55+
56+
private func receivedRequestHead(context: ChannelHandlerContext, _ head: HTTPRequestHead) {
57+
if head.method == .OPTIONS,
58+
head.headers.contains(.accessControlRequestMethod),
59+
let origin = head.headers.first(name: "origin") {
60+
// If the request is OPTIONS with a access-control-request-method header it's a CORS
61+
// preflight request and is not propagated further.
62+
self.state = .processingPreflightRequest
63+
self.handlePreflightRequest(context: context, head: head, origin: origin)
64+
} else {
65+
self.state = .processingRequest(origin: head.headers.first(name: "origin"))
66+
context.fireChannelRead(self.wrapInboundOut(.head(head)))
67+
}
68+
}
69+
70+
private func receivedRequestBody(context: ChannelHandlerContext, _ body: ByteBuffer) {
71+
// OPTIONS requests do not have a body, but still handle this case to be
72+
// cautious.
73+
if self.state == .processingPreflightRequest {
74+
return
75+
}
76+
77+
context.fireChannelRead(self.wrapInboundOut(.body(body)))
78+
}
79+
80+
private func receivedRequestEnd(context: ChannelHandlerContext, _ trailers: HTTPHeaders?) {
81+
if self.state == .processingPreflightRequest {
82+
// End of OPTIONS request; reset state and finish the response.
83+
self.state = .idle
84+
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
85+
} else {
86+
context.fireChannelRead(self.wrapInboundOut(.end(trailers)))
87+
}
88+
}
89+
90+
private func handlePreflightRequest(
91+
context: ChannelHandlerContext,
92+
head: HTTPRequestHead,
93+
origin: String
94+
) {
95+
let responseHead: HTTPResponseHead
96+
97+
if let allowedOrigin = self.configuration.allowedOrigins.header(origin) {
98+
var headers = HTTPHeaders()
99+
headers.reserveCapacity(4 + self.configuration.allowedHeaders.count)
100+
headers.add(name: .accessControlAllowOrigin, value: allowedOrigin)
101+
headers.add(name: .accessControlAllowMethods, value: "POST")
102+
103+
for value in self.configuration.allowedHeaders {
104+
headers.add(name: .accessControlAllowHeaders, value: value)
51105
}
52-
case .body:
53-
if self.requestMethod == .OPTIONS {
54-
// OPTIONS requests do not have a body, but still handle this case to be
55-
// cautious.
56-
return
106+
107+
if self.configuration.allowCredentialedRequests {
108+
headers.add(name: .accessControlAllowCredentials, value: "true")
57109
}
58110

59-
case .end:
60-
if self.requestMethod == .OPTIONS {
61-
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
62-
self.requestMethod = nil
63-
return
111+
if self.configuration.preflightCacheExpiration > 0 {
112+
headers.add(
113+
name: .accessControlMaxAge,
114+
value: "\(self.configuration.preflightCacheExpiration)"
115+
)
64116
}
117+
responseHead = HTTPResponseHead(version: head.version, status: .ok, headers: headers)
118+
} else {
119+
// Not allowed; respond with 403. This is okay in a pre-flight request.
120+
responseHead = HTTPResponseHead(version: head.version, status: .forbidden)
65121
}
66-
// The OPTIONS request should be fully handled at this point.
67-
context.fireChannelRead(data)
122+
123+
context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil)
68124
}
69125
}
70126

@@ -74,25 +130,76 @@ extension WebCORSHandler: ChannelOutboundHandler {
74130
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
75131
let responsePart = self.unwrapOutboundIn(data)
76132
switch responsePart {
77-
case let .head(responseHead):
78-
var headers = responseHead.headers
79-
// CORS requires all requests to have an Allow-Origin header.
80-
headers.add(name: "Access-Control-Allow-Origin", value: "*")
81-
//! FIXME: Check whether we can let browsers keep connections alive. It's not possible
82-
// now as the channel has a state that can't be reused since the pipeline is modified to
83-
// inject the gRPC call handler.
84-
headers.add(name: "Connection", value: "close")
85-
86-
context.write(
87-
self.wrapOutboundOut(.head(HTTPResponseHead(
88-
version: responseHead.version,
89-
status: responseHead.status,
90-
headers: headers
91-
))),
92-
promise: promise
93-
)
94-
default:
133+
case var .head(responseHead):
134+
switch self.state {
135+
case let .processingRequest(origin):
136+
self.prepareCORSResponseHead(&responseHead, origin: origin)
137+
context.write(self.wrapOutboundOut(.head(responseHead)), promise: promise)
138+
139+
case .idle, .processingPreflightRequest:
140+
assertionFailure("Writing response head when no request is in progress")
141+
context.close(promise: nil)
142+
}
143+
144+
case .body:
145+
context.write(data, promise: promise)
146+
147+
case .end:
148+
self.state = .idle
95149
context.write(data, promise: promise)
96150
}
97151
}
152+
153+
private func prepareCORSResponseHead(_ head: inout HTTPResponseHead, origin: String?) {
154+
guard let header = origin.flatMap({ self.configuration.allowedOrigins.header($0) }) else {
155+
// No origin or the origin is not allowed; don't treat it as a CORS request.
156+
return
157+
}
158+
159+
head.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: header)
160+
161+
if self.configuration.allowCredentialedRequests {
162+
head.headers.add(name: .accessControlAllowCredentials, value: "true")
163+
}
164+
165+
//! FIXME: Check whether we can let browsers keep connections alive. It's not possible
166+
// now as the channel has a state that can't be reused since the pipeline is modified to
167+
// inject the gRPC call handler.
168+
head.headers.replaceOrAdd(name: "Connection", value: "close")
169+
}
170+
}
171+
172+
extension HTTPHeaders {
173+
fileprivate enum CORSHeader: String {
174+
case accessControlRequestMethod = "access-control-request-method"
175+
case accessControlRequestHeaders = "access-control-request-headers"
176+
case accessControlAllowOrigin = "access-control-allow-origin"
177+
case accessControlAllowMethods = "access-control-allow-methods"
178+
case accessControlAllowHeaders = "access-control-allow-headers"
179+
case accessControlAllowCredentials = "access-control-allow-credentials"
180+
case accessControlMaxAge = "access-control-max-age"
181+
}
182+
183+
fileprivate func contains(_ name: CORSHeader) -> Bool {
184+
return self.contains(name: name.rawValue)
185+
}
186+
187+
fileprivate mutating func add(name: CORSHeader, value: String) {
188+
self.add(name: name.rawValue, value: value)
189+
}
190+
191+
fileprivate mutating func replaceOrAdd(name: CORSHeader, value: String) {
192+
self.replaceOrAdd(name: name.rawValue, value: value)
193+
}
194+
}
195+
196+
extension Server.Configuration.CORS.AllowedOrigins {
197+
internal func header(_ origin: String) -> String? {
198+
switch self.wrapped {
199+
case .all:
200+
return "*"
201+
case let .only(allowed):
202+
return allowed.contains(origin) ? origin : nil
203+
}
204+
}
98205
}

0 commit comments

Comments
 (0)